user.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package ovpm
  2. import (
  3. "fmt"
  4. "time"
  5. "github.com/Sirupsen/logrus"
  6. "github.com/asaskevich/govalidator"
  7. "github.com/jinzhu/gorm"
  8. )
  9. // User represents the interface that is being used within the public api.
  10. type User interface {
  11. GetUsername() string
  12. GetServerSerialNumber() string
  13. GetCert() string
  14. }
  15. // DBUser is database model for VPN users.
  16. type DBUser struct {
  17. gorm.Model
  18. ServerID uint
  19. Server DBServer
  20. Username string `gorm:"unique_index"`
  21. Cert string
  22. ServerSerialNumber string
  23. Password string
  24. Key string
  25. }
  26. // DBRevoked is a database model for revoked VPN users.
  27. type DBRevoked struct {
  28. gorm.Model
  29. SerialNumber string
  30. }
  31. func (u *DBUser) setPassword(newPassword string) error {
  32. // TODO(cad): Use a proper password hashing algorithm here.
  33. u.Password = newPassword
  34. return nil
  35. }
  36. // GetUser finds and returns the user with the given username from database.
  37. func GetUser(username string) (*DBUser, error) {
  38. user := DBUser{}
  39. db.Where(&DBUser{Username: username}).First(&user)
  40. if db.NewRecord(&user) {
  41. // user is not found
  42. return nil, fmt.Errorf("user not found: %s", username)
  43. }
  44. return &user, nil
  45. }
  46. // GetAllUsers returns all recorded users in the database.
  47. func GetAllUsers() ([]*DBUser, error) {
  48. var users []*DBUser
  49. db.Find(&users)
  50. return users, nil
  51. }
  52. // CreateNewUser creates a new user with the given username and password in the database.
  53. // It also generates the necessary client keys and signs certificates with the current
  54. // server's CA.
  55. func CreateNewUser(username, password string) (*DBUser, error) {
  56. if !CheckBootstrapped() {
  57. return nil, fmt.Errorf("you first need to create server")
  58. }
  59. // Validate user input.
  60. if govalidator.IsNull(username) {
  61. return nil, fmt.Errorf("validation error: %s can not be null", username)
  62. }
  63. if !govalidator.IsAlphanumeric(username) {
  64. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", username)
  65. }
  66. ca, err := GetSystemCA()
  67. if err != nil {
  68. return nil, err
  69. }
  70. clientCert, err := NewClientCertHolder(username, ca)
  71. if err != nil {
  72. return nil, fmt.Errorf("can not create client cert %s: %v", username, err)
  73. }
  74. server, err := GetServerInstance()
  75. if err != nil {
  76. return nil, fmt.Errorf("can not get server: %v", err)
  77. }
  78. user := DBUser{
  79. Username: username,
  80. Password: password,
  81. Cert: clientCert.Cert,
  82. Key: clientCert.Key,
  83. ServerSerialNumber: server.SerialNumber,
  84. }
  85. db.Create(&user)
  86. if db.NewRecord(&user) {
  87. // user is still not created
  88. return nil, fmt.Errorf("can not create user in database: %s", user.Username)
  89. }
  90. logrus.Infof("user created: %s", username)
  91. // Emit server config
  92. err = Emit()
  93. if err != nil {
  94. return nil, err
  95. }
  96. return &user, nil
  97. }
  98. // Delete deletes a user by the given username from the database.
  99. func (u *DBUser) Delete() error {
  100. if db.NewRecord(&u) {
  101. // user is not found
  102. return fmt.Errorf("user is not initialized: %s", u.Username)
  103. }
  104. crt, err := readCertFromPEM(u.Cert)
  105. if err != nil {
  106. return fmt.Errorf("can not get user's certificate: %v", err)
  107. }
  108. db.Create(&DBRevoked{
  109. SerialNumber: crt.SerialNumber.Text(16),
  110. })
  111. db.Unscoped().Delete(&u)
  112. err = Emit()
  113. if err != nil {
  114. return err
  115. }
  116. u = nil // delete the existing user struct
  117. return nil
  118. }
  119. // ResetPassword resets the users password into the provided password.
  120. func (u *DBUser) ResetPassword(newPassword string) error {
  121. err := u.setPassword(newPassword)
  122. if err != nil {
  123. // user password can not be updated
  124. return fmt.Errorf("user password can not be updated %s: %v", u.Username, err)
  125. }
  126. db.Save(u)
  127. return nil
  128. }
  129. // Sign creates a key and a ceritificate signed by the current server's CA.
  130. //
  131. // This is often used to sign users when the current CA is changed while there are
  132. // still existing users in the database.
  133. func (u *DBUser) Sign() error {
  134. if !CheckBootstrapped() {
  135. return fmt.Errorf("you first need to create server")
  136. }
  137. ca, err := GetSystemCA()
  138. if err != nil {
  139. return err
  140. }
  141. clientCert, err := NewClientCertHolder(u.Username, ca)
  142. if err != nil {
  143. return fmt.Errorf("can not create client cert %s: %v", u.Username, err)
  144. }
  145. server, err := GetServerInstance()
  146. if err != nil {
  147. return err
  148. }
  149. u.Cert = clientCert.Cert
  150. u.Key = clientCert.Key
  151. u.ServerSerialNumber = server.SerialNumber
  152. db.Save(&u)
  153. return nil
  154. }
  155. // GetUsername returns user's username.
  156. func (u *DBUser) GetUsername() string {
  157. return u.Username
  158. }
  159. // GetCert returns user's public certificate.
  160. func (u *DBUser) GetCert() string {
  161. return u.Cert
  162. }
  163. // GetServerSerialNumber returns user's server serial number.
  164. func (u *DBUser) GetServerSerialNumber() string {
  165. return u.ServerSerialNumber
  166. }
  167. // GetCreatedAt returns user's creation time.
  168. func (u *DBUser) GetCreatedAt() string {
  169. return u.CreatedAt.Format(time.UnixDate)
  170. }