Преглед изворни кода

refactor: major refactor of all service

- Improve API by improving the wording of the methods.
- Utilize composition to provide abstraction.
- Group together related methods under the same struct.
Mustafa Arici пре 8 година
родитељ
комит
55acfb6b25
9 измењених фајлова са 336 додато и 230 уклоњено
  1. 9 10
      api/rpc.go
  2. 21 11
      db.go
  3. 5 5
      db_test.go
  4. 47 34
      net.go
  5. 31 24
      net_test.go
  6. 65 55
      user.go
  7. 27 35
      user_test.go
  8. 103 28
      vpn.go
  9. 28 28
      vpn_test.go

+ 9 - 10
api/rpc.go

@@ -2,7 +2,6 @@ package api
 
 import (
 	"os"
-	"time"
 
 	"google.golang.org/grpc"
 
@@ -178,16 +177,16 @@ func (s *VPNService) Status(ctx context.Context, req *pb.VPNStatusRequest) (*pb.
 	}
 
 	response := pb.VPNStatusResponse{
-		Name:         server.Name,
-		SerialNumber: server.SerialNumber,
-		Hostname:     server.Hostname,
-		Port:         server.Port,
+		Name:         server.GetServerName(),
+		SerialNumber: server.GetSerialNumber(),
+		Hostname:     server.GetHostname(),
+		Port:         server.GetPort(),
 		Proto:        server.GetProto(),
-		Cert:         server.Cert,
-		CACert:       server.CACert,
-		Net:          server.Net,
-		Mask:         server.Mask,
-		CreatedAt:    server.CreatedAt.Format(time.UnixDate),
+		Cert:         server.GetCert(),
+		CACert:       server.GetCACert(),
+		Net:          server.GetNet(),
+		Mask:         server.GetMask(),
+		CreatedAt:    server.GetCreatedAt(),
 	}
 	return &response, nil
 }

+ 21 - 11
db.go

@@ -8,30 +8,40 @@ import (
 	_ "github.com/jinzhu/gorm/dialects/sqlite"
 )
 
-var db *gorm.DB
+var db *DB
 
-// SetupDB prepares database for use.
+// DB represents a persistent storage.
+type DB struct {
+	*gorm.DB
+}
+
+// CreateDB prepares and returns new storage.
 //
 // It should be run at the start of the program.
-func SetupDB(dialect string, args ...interface{}) {
+func CreateDB(dialect string, args ...interface{}) *DB {
 	if len(args) > 0 && args[0] == "" {
 		args[0] = _DefaultDBPath
 	}
 	var err error
-	db, err = gorm.Open(dialect, args...)
+
+	dbase, err := gorm.Open(dialect, args...)
 	if err != nil {
 		logrus.Fatalf("couldn't open sqlite database %v: %v", args, err)
 	}
 
-	db.AutoMigrate(&DBUser{})
-	db.AutoMigrate(&DBServer{})
-	db.AutoMigrate(&DBRevoked{})
-	db.AutoMigrate(&DBNetwork{})
+	dbase.AutoMigrate(&dbUserModel{})
+	dbase.AutoMigrate(&dbServerModel{})
+	dbase.AutoMigrate(&dbRevokedModel{})
+	dbase.AutoMigrate(&dbNetworkModel{})
+
+	dbPTR := &DB{DB: dbase}
+	db = dbPTR
+	return dbPTR
 }
 
-// CeaseDB closes the database.
+// Cease closes the database.
 //
 // It should be run at the exit of the program.
-func CeaseDB() {
-	db.Close()
+func (db *DB) Cease() {
+	db.DB.Close()
 }

+ 5 - 5
db_test.go

@@ -10,7 +10,7 @@ func TestDBSetup(t *testing.T) {
 	// Test:
 
 	// Create database.
-	SetupDB("sqlite3", ":memory:")
+	CreateDB("sqlite3", ":memory:")
 
 	// Is database created?
 	if db == nil {
@@ -23,15 +23,15 @@ func TestDBCease(t *testing.T) {
 	Testing = true
 
 	// Prepare:
-	SetupDB("sqlite3", ":memory:")
-	user := DBUser{Username: "testUser"}
+	CreateDB("sqlite3", ":memory:")
+	user := dbUserModel{Username: "testUser"}
 	db.Save(&user)
 
 	// Test:
 	// Close database.
-	CeaseDB()
+	db.Cease()
 
-	var users []DBUser
+	var users []dbUserModel
 	db.Find(&users)
 
 	// Is length zero?

+ 47 - 34
net.go

@@ -63,6 +63,7 @@ func (nt NetworkType) String() string {
 	return "UNDEFINEDNET"
 }
 
+// Description gives description about the network type.
 func (nt NetworkType) Description() string {
 	for _, v := range networkTypes {
 		if v.Type == nt {
@@ -72,21 +73,26 @@ func (nt NetworkType) Description() string {
 	return "UNDEFINEDNET"
 }
 
-// DBNetwork is database model for external networks on the VPN server.
-type DBNetwork struct {
+// dbNetworkModel is database model for external networks on the VPN server.
+type dbNetworkModel struct {
 	gorm.Model
 	ServerID uint
-	Server   DBServer
+	Server   dbServerModel
 
 	Name  string `gorm:"unique_index"`
 	CIDR  string
 	Type  NetworkType
 	Via   string
-	Users []*DBUser `gorm:"many2many:network_users;"`
+	Users []*dbUserModel `gorm:"many2many:network_users;"`
+}
+
+// Network represents a VPN related network.
+type Network struct {
+	dbNetworkModel
 }
 
 // GetNetwork returns a network specified by its name.
-func GetNetwork(name string) (*DBNetwork, error) {
+func GetNetwork(name string) (*Network, error) {
 	if !IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
@@ -98,26 +104,29 @@ func GetNetwork(name string) (*DBNetwork, error) {
 		return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
 	}
 
-	var network DBNetwork
-	db.Preload("Users").Where(&DBNetwork{Name: name}).First(&network)
+	var network dbNetworkModel
+	db.Preload("Users").Where(&dbNetworkModel{Name: name}).First(&network)
 
 	if db.NewRecord(&network) {
 		return nil, fmt.Errorf("network not found %s", name)
 	}
 
-	return &network, nil
+	return &Network{dbNetworkModel: network}, nil
 }
 
 // GetAllNetworks returns all networks defined in the system.
-func GetAllNetworks() []*DBNetwork {
-	var networks []*DBNetwork
-	db.Preload("Users").Find(&networks)
-
+func GetAllNetworks() []*Network {
+	var networks []*Network
+	var dbNetworks []*dbNetworkModel
+	db.Preload("Users").Find(&dbNetworks)
+	for _, n := range dbNetworks {
+		networks = append(networks, &Network{dbNetworkModel: *n})
+	}
 	return networks
 }
 
 // CreateNewNetwork creates a new network definition in the system.
-func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*DBNetwork, error) {
+func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
 	if !IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
@@ -158,11 +167,11 @@ func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*DBNe
 		via = ""
 	}
 
-	network := DBNetwork{
+	network := dbNetworkModel{
 		Name:  name,
 		CIDR:  ipnet.String(),
 		Type:  nettype,
-		Users: []*DBUser{},
+		Users: []*dbUserModel{},
 		Via:   via,
 	}
 	db.Save(&network)
@@ -172,24 +181,24 @@ func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*DBNe
 	}
 	Emit()
 	logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
-	return &network, nil
+	return &Network{dbNetworkModel: network}, nil
 
 }
 
 // Delete deletes a network definition in the system.
-func (n *DBNetwork) Delete() error {
+func (n *Network) Delete() error {
 	if !IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
 
-	db.Unscoped().Delete(n)
+	db.Unscoped().Delete(n.dbNetworkModel)
 	Emit()
 	logrus.Infof("network deleted: %s", n.Name)
 	return nil
 }
 
 // Associate allows the given user access to this network.
-func (n *DBNetwork) Associate(username string) error {
+func (n *Network) Associate(username string) error {
 	if !IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
@@ -198,8 +207,8 @@ func (n *DBNetwork) Associate(username string) error {
 		return fmt.Errorf("user can not be fetched: %v", err)
 	}
 
-	var users []DBUser
-	userAssoc := db.Model(&n).Association("Users")
+	var users []dbUserModel
+	userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
 	userAssoc.Find(&users)
 	var found bool
 	for _, u := range users {
@@ -212,7 +221,7 @@ func (n *DBNetwork) Associate(username string) error {
 		return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
 	}
 
-	userAssoc.Append(user)
+	userAssoc.Append(user.dbUserModel)
 	if userAssoc.Error != nil {
 		return fmt.Errorf("association failed: %v", userAssoc.Error)
 	}
@@ -222,7 +231,7 @@ func (n *DBNetwork) Associate(username string) error {
 }
 
 // Dissociate breaks up the given users association to the said network.
-func (n *DBNetwork) Dissociate(username string) error {
+func (n *Network) Dissociate(username string) error {
 	if !IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
@@ -232,8 +241,8 @@ func (n *DBNetwork) Dissociate(username string) error {
 		return fmt.Errorf("user can not be fetched: %v", err)
 	}
 
-	var users []DBUser
-	userAssoc := db.Model(&n).Association("Users")
+	var users []dbUserModel
+	userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
 	userAssoc.Find(&users)
 	var found bool
 	for _, u := range users {
@@ -246,7 +255,7 @@ func (n *DBNetwork) Dissociate(username string) error {
 		return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
 	}
 
-	userAssoc.Delete(user)
+	userAssoc.Delete(user.dbUserModel)
 	if userAssoc.Error != nil {
 		return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
 	}
@@ -256,32 +265,36 @@ func (n *DBNetwork) Dissociate(username string) error {
 }
 
 // GetName returns network's name.
-func (n *DBNetwork) GetName() string {
+func (n *Network) GetName() string {
 	return n.Name
 }
 
 // GetCIDR returns network's CIDR.
-func (n *DBNetwork) GetCIDR() string {
+func (n *Network) GetCIDR() string {
 	return n.CIDR
 }
 
 // GetCreatedAt returns network's name.
-func (n *DBNetwork) GetCreatedAt() string {
+func (n *Network) GetCreatedAt() string {
 	return n.CreatedAt.Format(time.UnixDate)
 }
 
 // GetType returns network's network type.
-func (n *DBNetwork) GetType() NetworkType {
+func (n *Network) GetType() NetworkType {
 	return NetworkType(n.Type)
 }
 
 // GetAssociatedUsers returns network's associated users.
-func (n *DBNetwork) GetAssociatedUsers() []*DBUser {
-	return n.Users
+func (n *Network) GetAssociatedUsers() []*User {
+	var users []*User
+	for _, u := range n.Users {
+		users = append(users, &User{dbUserModel: *u})
+	}
+	return users
 }
 
 // GetAssociatedUsernames returns network's associated user names.
-func (n *DBNetwork) GetAssociatedUsernames() []string {
+func (n *Network) GetAssociatedUsernames() []string {
 	var usernames []string
 
 	for _, user := range n.GetAssociatedUsers() {
@@ -291,7 +304,7 @@ func (n *DBNetwork) GetAssociatedUsernames() []string {
 }
 
 // GetVia returns network' via.
-func (n *DBNetwork) GetVia() string {
+func (n *Network) GetVia() string {
 	return n.Via
 }
 

+ 31 - 24
net_test.go

@@ -7,8 +7,8 @@ import (
 func TestVPNCreateNewNetwork(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -30,7 +30,7 @@ func TestVPNCreateNewNetwork(t *testing.T) {
 		t.Fatalf("network CIDR is expected to be '%s' but it's '%s' instead", cidrStr, n.CIDR)
 	}
 
-	var network DBNetwork
+	var network dbNetworkModel
 	db.First(&network)
 
 	if db.NewRecord(&network) {
@@ -54,8 +54,8 @@ func TestVPNCreateNewNetwork(t *testing.T) {
 func TestVPNDeleteNetwork(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -69,7 +69,7 @@ func TestVPNDeleteNetwork(t *testing.T) {
 		t.Fatalf("unexpected error when creating a new network: %v", err)
 	}
 
-	var network DBNetwork
+	var network dbNetworkModel
 	db.First(&network)
 
 	if db.NewRecord(&network) {
@@ -82,7 +82,7 @@ func TestVPNDeleteNetwork(t *testing.T) {
 	}
 
 	// Empty the existing network object.
-	network = DBNetwork{}
+	network = dbNetworkModel{}
 	db.First(&network)
 	if !db.NewRecord(&network) {
 		t.Fatalf("network is not deleted from the database. %+v", network)
@@ -92,8 +92,8 @@ func TestVPNDeleteNetwork(t *testing.T) {
 func TestVPNGetNetwork(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -107,7 +107,7 @@ func TestVPNGetNetwork(t *testing.T) {
 		t.Fatalf("unexpected error when creating a new network: %v", err)
 	}
 
-	var network DBNetwork
+	var network dbNetworkModel
 	db.First(&network)
 
 	if db.NewRecord(&network) {
@@ -119,7 +119,7 @@ func TestVPNGetNetwork(t *testing.T) {
 		t.Fatalf("unexpected error: %v", err)
 	}
 
-	if db.NewRecord(&n) {
+	if db.NewRecord(&n.dbNetworkModel) {
 		t.Fatalf("network is not correctly returned from db.")
 	}
 }
@@ -127,8 +127,8 @@ func TestVPNGetNetwork(t *testing.T) {
 func TestVPNGetAllNetworks(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -173,8 +173,8 @@ func TestVPNGetAllNetworks(t *testing.T) {
 func TestNetAssociate(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -188,20 +188,27 @@ func TestNetAssociate(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	n, _ := CreateNewNetwork(netName, cidrStr, netType, "")
-	err = n.Associate(user.Username)
+	n, err := CreateNewNetwork(netName, cidrStr, netType, "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = n.Associate(user.dbUserModel.Username)
 	if err != nil {
 		t.Fatal(err)
 	}
 	n = nil
 
-	n, _ = GetNetwork(netName)
+	n, err = GetNetwork(netName)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	// Does number of associated users in the network object matches the number that we have created?
-	if count := len(n.Users); count != 1 {
+	if count := len(n.dbNetworkModel.Users); count != 1 {
 		t.Fatalf("network.Users count is expexted to be %d, but it's %d", 1, count)
 	}
-	err = n.Associate(user.Username)
+	err = n.Associate(user.dbUserModel.Username)
 	if err == nil {
 		t.Fatalf("expected to get error but got no error instead")
 	}
@@ -211,8 +218,8 @@ func TestNetAssociate(t *testing.T) {
 func TestNetDissociate(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	err := Init("localhost", "", UDPProto, "")
 	if err != nil {
 		t.Fatal(err)
@@ -264,8 +271,8 @@ func TestNetDissociate(t *testing.T) {
 func TestNetGetAssociatedUsers(t *testing.T) {
 	// Initialize:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:

+ 65 - 55
user.go

@@ -13,22 +13,17 @@ import (
 	"github.com/jinzhu/gorm"
 )
 
-// User represents the interface that is being used within the public api.
-type User interface {
-	GetUsername() string
-	GetServerSerialNumber() string
-	GetCert() string
-	GetIPNet() string
-	IsNoGW() bool
-	GetHostID() uint32
-	IsAdmin() bool
+// dbRevokedModel is a database model for revoked VPN users.
+type dbRevokedModel struct {
+	gorm.Model
+	SerialNumber string
 }
 
-// DBUser is database model for VPN users.
-type DBUser struct {
+// dbUserModel is database model for VPN users.
+type dbUserModel struct {
 	gorm.Model
 	ServerID uint
-	Server   DBServer
+	Server   dbServerModel
 
 	Username           string `gorm:"unique_index"`
 	Cert               string // not user writable
@@ -40,13 +35,12 @@ type DBUser struct {
 	Admin              bool
 }
 
-// DBRevoked is a database model for revoked VPN users.
-type DBRevoked struct {
-	gorm.Model
-	SerialNumber string
+// User represents a vpn user.
+type User struct {
+	dbUserModel
 }
 
-func (u *DBUser) setPassword(password string) error {
+func (u *dbUserModel) setPassword(password string) error {
 	hashedPassword, err := passlib.Hash(password)
 	if err != nil {
 		return fmt.Errorf("can not set password: %v", err)
@@ -57,7 +51,7 @@ func (u *DBUser) setPassword(password string) error {
 }
 
 // CheckPassword returns whether the given password is correct for the user.
-func (u *DBUser) CheckPassword(password string) bool {
+func (u *User) CheckPassword(password string) bool {
 	_, err := passlib.Verify(password, u.Hash)
 	if err != nil {
 		logrus.Error(err)
@@ -67,21 +61,24 @@ func (u *DBUser) CheckPassword(password string) bool {
 }
 
 // GetUser finds and returns the user with the given username from database.
-func GetUser(username string) (*DBUser, error) {
-	user := DBUser{}
-	db.Where(&DBUser{Username: username}).First(&user)
+func GetUser(username string) (*User, error) {
+	user := dbUserModel{}
+	db.Where(&dbUserModel{Username: username}).First(&user)
 	if db.NewRecord(&user) {
 		// user is not found
 		return nil, fmt.Errorf("user not found: %s", username)
 	}
-	return &user, nil
+	return &User{dbUserModel: user}, nil
 }
 
 // GetAllUsers returns all recorded users in the database.
-func GetAllUsers() ([]*DBUser, error) {
-	var users []*DBUser
-	db.Find(&users)
-
+func GetAllUsers() ([]*User, error) {
+	var users []*User
+	var dbUsers []*dbUserModel
+	db.Find(&dbUsers)
+	for _, u := range dbUsers {
+		users = append(users, &User{dbUserModel: *u})
+	}
 	return users, nil
 
 }
@@ -91,7 +88,7 @@ func GetAllUsers() ([]*DBUser, error) {
 //
 // It also generates the necessary client keys and signs certificates with the current
 // server's CA.
-func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bool) (*DBUser, error) {
+func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bool) (*User, error) {
 	if !IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
@@ -132,7 +129,7 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 			return nil, fmt.Errorf("ip %s is already allocated", ip)
 		}
 	}
-	user := DBUser{
+	user := dbUserModel{
 		Username:           username,
 		Cert:               clientCert.Cert,
 		Key:                clientCert.Key,
@@ -155,13 +152,13 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 	if err != nil {
 		return nil, err
 	}
-	return &user, nil
+	return &User{dbUserModel: user}, nil
 }
 
 // Update updates the user's attributes and writes them to the database.
 //
 // How this method works is similiar to PUT semantics of REST. It sets the user record fields to the provided function arguments.
-func (u *DBUser) Update(password string, nogw bool, hostid uint32, admin bool) error {
+func (u *User) Update(password string, nogw bool, hostid uint32, admin bool) error {
 	if !IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
@@ -195,7 +192,7 @@ func (u *DBUser) Update(password string, nogw bool, hostid uint32, admin bool) e
 			return fmt.Errorf("ip %s is already allocated", ip)
 		}
 	}
-	db.Save(u)
+	db.Save(u.dbUserModel)
 
 	err := Emit()
 	if err != nil {
@@ -205,8 +202,8 @@ func (u *DBUser) Update(password string, nogw bool, hostid uint32, admin bool) e
 }
 
 // Delete deletes a user by the given username from the database.
-func (u *DBUser) Delete() error {
-	if db.NewRecord(u) {
+func (u *User) Delete() error {
+	if db.NewRecord(u.dbUserModel) {
 		// user is not found
 		return fmt.Errorf("user is not initialized: %s", u.Username)
 	}
@@ -214,10 +211,10 @@ func (u *DBUser) Delete() error {
 	if err != nil {
 		return fmt.Errorf("can not get user's certificate: %v", err)
 	}
-	db.Create(&DBRevoked{
+	db.Create(&dbRevokedModel{
 		SerialNumber: crt.SerialNumber.Text(16),
 	})
-	db.Unscoped().Delete(u)
+	db.Unscoped().Delete(u.dbUserModel)
 	logrus.Infof("user deleted: %s", u.GetUsername())
 	err = Emit()
 	if err != nil {
@@ -228,13 +225,13 @@ func (u *DBUser) Delete() error {
 }
 
 // ResetPassword resets the users password into the provided password.
-func (u *DBUser) ResetPassword(password string) error {
-	err := u.setPassword(password)
+func (u *User) ResetPassword(password string) error {
+	err := u.dbUserModel.setPassword(password)
 	if err != nil {
 		// user password can not be updated
 		return fmt.Errorf("user password can not be updated %s: %v", u.Username, err)
 	}
-	db.Save(u)
+	db.Save(u.dbUserModel)
 	err = Emit()
 	if err != nil {
 		return err
@@ -250,7 +247,7 @@ func (u *DBUser) ResetPassword(password string) error {
 // still  existing users in the database.
 //
 // Also it can be used when a user cert is expired or user's private key stolen, missing etc.
-func (u *DBUser) Renew() error {
+func (u *User) Renew() error {
 	if !IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
@@ -273,7 +270,7 @@ func (u *DBUser) Renew() error {
 	u.Key = clientCert.Key
 	u.ServerSerialNumber = server.SerialNumber
 
-	db.Save(u)
+	db.Save(u.dbUserModel)
 	err = Emit()
 	if err != nil {
 		return err
@@ -284,27 +281,27 @@ func (u *DBUser) Renew() error {
 }
 
 // GetUsername returns user's username.
-func (u *DBUser) GetUsername() string {
+func (u *User) GetUsername() string {
 	return u.Username
 }
 
 // GetCert returns user's public certificate.
-func (u *DBUser) GetCert() string {
+func (u *User) GetCert() string {
 	return u.Cert
 }
 
 // GetServerSerialNumber returns user's server serial number.
-func (u *DBUser) GetServerSerialNumber() string {
+func (u *User) GetServerSerialNumber() string {
 	return u.ServerSerialNumber
 }
 
 // GetCreatedAt returns user's creation time.
-func (u *DBUser) GetCreatedAt() string {
+func (u *User) GetCreatedAt() string {
 	return u.CreatedAt.Format(time.UnixDate)
 }
 
 // getIP returns user's vpn ip addr.
-func (u *DBUser) getIP() net.IP {
+func (u *User) getIP() net.IP {
 	users := getNonStaticHostUsers()
 	staticHostIDs := getStaticHostIDs()
 	server, err := GetServerInstance()
@@ -342,7 +339,7 @@ func (u *DBUser) getIP() net.IP {
 }
 
 // GetIPNet returns user's vpn ip network. (e.g. 192.168.0.1/24)
-func (u *DBUser) GetIPNet() string {
+func (u *User) GetIPNet() string {
 	server, err := GetServerInstance()
 	if err != nil {
 		logrus.Panicf("can not get user ipnet: %v", err)
@@ -357,29 +354,42 @@ func (u *DBUser) GetIPNet() string {
 }
 
 // IsNoGW returns whether user is set to get the vpn server as their default gateway.
-func (u *DBUser) IsNoGW() bool {
+func (u *User) IsNoGW() bool {
 	return u.NoGW
 }
 
 // GetHostID returns user's Host ID.
-func (u *DBUser) GetHostID() uint32 {
+func (u *User) GetHostID() uint32 {
 	return u.HostID
 }
 
 // IsAdmin returns whether user is admin or not.
-func (u *DBUser) IsAdmin() bool {
+func (u *User) IsAdmin() bool {
 	return u.Admin
 }
 
-func getStaticHostUsers() []*DBUser {
-	var users []*DBUser
-	db.Unscoped().Not(DBUser{HostID: 0}).Find(&users)
+func (u *User) getKey() string {
+	return u.Key
+}
+
+func getStaticHostUsers() []*User {
+	var users []*User
+	var dbUsers []*dbUserModel
+	db.Unscoped().Not(dbUserModel{HostID: 0}).Find(&dbUsers)
+	for _, u := range dbUsers {
+		users = append(users, &User{dbUserModel: *u})
+	}
 	return users
 }
 
-func getNonStaticHostUsers() []*DBUser {
-	var users []*DBUser
-	db.Unscoped().Where(DBUser{HostID: 0}).Find(&users)
+func getNonStaticHostUsers() []*User {
+	var users []*User
+	var dbUsers []*dbUserModel
+	db.Unscoped().Where(dbUserModel{HostID: 0}).Find(&dbUsers)
+	for _, u := range dbUsers {
+		users = append(users, &User{dbUserModel: *u})
+	}
+
 	return users
 }
 

+ 27 - 35
user_test.go

@@ -11,8 +11,8 @@ import (
 
 func TestCreateNewUser(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 	server, _ := ovpm.GetServerInstance()
 
@@ -32,11 +32,6 @@ func TestCreateNewUser(t *testing.T) {
 		t.Fatalf("user is expected to be 'NOT nil' but it is 'nil' %+v", user)
 	}
 
-	// Is user empty?
-	if *user == (ovpm.DBUser{}) {
-		t.Fatalf("user is expected to be 'NOT EMPTY' but it is 'EMPTY' %+v", user)
-	}
-
 	// Is user acutally exist in the system?
 	user2, err := ovpm.GetUser(username)
 	if err != nil {
@@ -87,8 +82,8 @@ func TestCreateNewUser(t *testing.T) {
 
 func TestUserUpdate(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -125,8 +120,8 @@ func TestUserUpdate(t *testing.T) {
 
 func TestUserPasswordCorrect(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -142,8 +137,8 @@ func TestUserPasswordCorrect(t *testing.T) {
 
 func TestUserPasswordReset(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -169,8 +164,8 @@ func TestUserPasswordReset(t *testing.T) {
 
 func TestUserDelete(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -207,8 +202,8 @@ func TestUserDelete(t *testing.T) {
 
 func TestUserGet(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -231,13 +226,13 @@ func TestUserGet(t *testing.T) {
 
 func TestUserGetAll(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 	count := 5
 
 	// Prepare:
-	var users []*ovpm.DBUser
+	var users []*ovpm.User
 	for i := 0; i < count; i++ {
 		username := fmt.Sprintf("user%d", i)
 		password := fmt.Sprintf("password%d", i)
@@ -269,8 +264,8 @@ func TestUserGetAll(t *testing.T) {
 
 func TestUserRenew(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -291,8 +286,8 @@ func TestUserRenew(t *testing.T) {
 
 func TestUserIPAllocator(t *testing.T) {
 	// Initialize:
-	ovpm.SetupDB("sqlite3", ":memory:")
-	defer ovpm.CeaseDB()
+	db := ovpm.CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
@@ -326,21 +321,18 @@ func TestUserIPAllocator(t *testing.T) {
 }
 
 // areUsersEqual compares given users and returns true if they are the same.
-func areUsersEqual(user1, user2 *ovpm.DBUser) bool {
-	if user1.Cert != user2.Cert {
-		logrus.Info("Cert %v != %v", user1.Cert, user2.Cert)
+func areUsersEqual(user1, user2 *ovpm.User) bool {
+	if user1.GetCert() != user2.GetCert() {
+		logrus.Info("Cert %v != %v", user1.GetCert(), user2.GetCert())
 		return false
 	}
-	if user1.Username != user2.Username {
-		logrus.Infof("Username %v != %v", user1.Username, user2.Username)
+	if user1.GetUsername() != user2.GetUsername() {
+		logrus.Infof("Username %v != %v", user1.GetUsername(), user2.GetUsername())
 		return false
 	}
-	if user1.Hash != user2.Hash {
-		logrus.Infof("Password %v != %v", user1.Hash, user2.Hash)
-		return false
-	}
-	if user1.ServerSerialNumber != user2.ServerSerialNumber {
-		logrus.Infof("ServerSerialNumber %v != %v", user1.ServerSerialNumber, user2.ServerSerialNumber)
+
+	if user1.GetServerSerialNumber() != user2.GetServerSerialNumber() {
+		logrus.Infof("ServerSerialNumber %v != %v", user1.GetServerSerialNumber(), user2.GetServerSerialNumber())
 		return false
 	}
 	logrus.Infof("users are the same!")

+ 103 - 28
vpn.go

@@ -3,6 +3,9 @@
 //go:generate protoc -I pb/ -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis pb/user.proto pb/vpn.proto pb/network.proto --grpc-gateway_out=logtostderr=true:pb
 //go:generate protoc -I pb/ -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis pb/user.proto pb/vpn.proto pb/network.proto --swagger_out=logtostderr=true:pb
 
+// Package ovpm provides the implementation of core API.
+//
+// ovpm can create and destroy OpenVPN servers, manage vpn users, handle certificates etc...
 package ovpm
 
 import (
@@ -33,8 +36,8 @@ const (
 	UDPProto string = "udp"
 )
 
-// DBServer is database model for storing VPN server related stuff.
-type DBServer struct {
+// serverModel is database model for storing VPN server related stuff.
+type dbServerModel struct {
 	gorm.Model
 	Name         string `gorm:"unique_index"` // Server name.
 	SerialNumber string
@@ -51,11 +54,91 @@ type DBServer struct {
 	CRL      string // Certificate Revocation List
 }
 
+// Server represents VPN server.
+type Server struct {
+	dbServerModel
+}
+
 // CheckSerial takes a serial number and checks it against the current server's serial number.
-func (s *DBServer) CheckSerial(serial string) bool {
+func (s *Server) CheckSerial(serial string) bool {
 	return serial == s.SerialNumber
 }
 
+// GetSerialNumber returns server's serial number.
+func (s *Server) GetSerialNumber() string {
+	return s.SerialNumber
+}
+
+// GetServerName returns server's name.
+func (s *Server) GetServerName() string {
+	if s.Name != "" {
+		return s.Name
+	}
+	return "default"
+}
+
+// GetHostname returns vpn server's hostname.
+func (s *Server) GetHostname() string {
+	return s.Hostname
+}
+
+// GetPort returns vpn server's port.
+func (s *Server) GetPort() string {
+	if s.Port != "" {
+		return s.Port
+	}
+	return DefaultVPNPort
+
+}
+
+// GetProto returns vpn server's proto.
+func (s *Server) GetProto() string {
+	if s.Proto != "" {
+		return s.Proto
+	}
+	return DefaultVPNProto
+}
+
+// GetCert returns vpn server's cert.
+func (s *Server) GetCert() string {
+	return s.Cert
+}
+
+// GetKey returns vpn server's key.
+func (s *Server) GetKey() string {
+	return s.Key
+}
+
+// GetCACert returns vpn server's cacert.
+func (s *Server) GetCACert() string {
+	return s.CACert
+}
+
+// GetCAKey returns vpn server's cakey.
+func (s *Server) GetCAKey() string {
+	return s.CAKey
+}
+
+// GetNet returns vpn server's net.
+func (s *Server) GetNet() string {
+	return s.Net
+}
+
+// GetMask returns vpn server's mask.
+func (s *Server) GetMask() string {
+	return s.Mask
+}
+
+// GetCRL returns vpn server's crl.
+func (s *Server) GetCRL() string {
+	return s.CRL
+}
+
+// GetCreatedAt returns server's created at.
+func (s *Server) GetCreatedAt() string {
+	return s.CreatedAt.Format(time.UnixDate)
+}
+
 type _VPNServerConfig struct {
 	CertPath     string
 	KeyPath      string
@@ -150,7 +233,7 @@ func Init(hostname string, port string, proto string, ipblock string) error {
 	}
 
 	serialNumber := uuid.New().String()
-	serverInstance := DBServer{
+	serverInstance := dbServerModel{
 		Name: serverName,
 
 		SerialNumber: serialNumber,
@@ -189,14 +272,14 @@ func Init(hostname string, port string, proto string, ipblock string) error {
 	return nil
 }
 
-// Deinit deletes the server with the given serverName from the database and frees the allocated resources.
+// Deinit deletes the VPN server from the database and frees the allocated resources.
 func Deinit() error {
 	if !IsInitialized() {
 		return fmt.Errorf("server not found")
 	}
 
-	db.Unscoped().Delete(&DBServer{})
-	db.Unscoped().Delete(&DBRevoked{})
+	db.Unscoped().Delete(&dbServerModel{})
+	db.Unscoped().Delete(&dbRevokedModel{})
 	Emit()
 	return nil
 }
@@ -223,12 +306,12 @@ func DumpsClientConfig(username string) (string, error) {
 		NoGW     bool
 		Proto    string
 	}{
-		Hostname: server.Hostname,
-		Port:     server.Port,
-		CA:       server.CACert,
-		Key:      user.Key,
-		Cert:     user.Cert,
-		NoGW:     user.NoGW,
+		Hostname: server.GetHostname(),
+		Port:     server.GetPort(),
+		CA:       server.GetCACert(),
+		Key:      user.getKey(),
+		Cert:     user.GetCert(),
+		NoGW:     user.IsNoGW(),
 		Proto:    server.GetProto(),
 	}
 	data, err := bindata.Asset("template/client.ovpn.tmpl")
@@ -262,7 +345,7 @@ func DumpClientConfig(username, path string) error {
 
 // GetSystemCA returns the system CA from the database if available.
 func GetSystemCA() (*pki.CA, error) {
-	server := DBServer{}
+	server := dbServerModel{}
 	db.First(&server)
 	if db.NewRecord(&server) {
 		return nil, fmt.Errorf("server record does not exists in db")
@@ -467,26 +550,18 @@ func emitServerConf() error {
 }
 
 // GetServerInstance returns the default server from the database.
-func GetServerInstance() (*DBServer, error) {
-	var server DBServer
+func GetServerInstance() (*Server, error) {
+	var server dbServerModel
 	db.First(&server)
 	if db.NewRecord(server) {
 		return nil, fmt.Errorf("can not retrieve server from db")
 	}
-	return &server, nil
-}
-
-// GetProto returns the current VPN proto.
-func (s *DBServer) GetProto() string {
-	if s.Proto != "" {
-		return s.Proto
-	}
-	return UDPProto
+	return &Server{dbServerModel: server}, nil
 }
 
-// IsInitialized checks if there is a default server in the database or not.
+// IsInitialized checks if there is a default VPN server configured in the database or not.
 func IsInitialized() bool {
-	var server DBServer
+	var server dbServerModel
 	db.First(&server)
 	if db.NewRecord(server) {
 		return false
@@ -515,7 +590,7 @@ func emitServerCert() error {
 }
 
 func emitCRL() error {
-	var revokedDBItems []*DBRevoked
+	var revokedDBItems []*dbRevokedModel
 	db.Find(&revokedDBItems)
 	var revokedCertSerials []*big.Int
 	for _, item := range revokedDBItems {

+ 28 - 28
vpn_test.go

@@ -20,13 +20,13 @@ func setupTestCase() {
 func TestVPNInit(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	// Prepare:
 	// Test:
 
 	// Check database if the database has no server.
-	var server DBServer
+	var server dbServerModel
 	db.First(&server)
 
 	// Isn't server empty struct?
@@ -44,7 +44,7 @@ func TestVPNInit(t *testing.T) {
 	Init("localhost", "", UDPProto, "")
 
 	// Check database if the database has no server.
-	var server2 DBServer
+	var server2 dbServerModel
 	db.First(&server2)
 
 	// Is server empty struct?
@@ -56,8 +56,8 @@ func TestVPNInit(t *testing.T) {
 func TestVPNDeinit(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 
 	// Prepare:
 	// Initialize the server.
@@ -69,7 +69,7 @@ func TestVPNDeinit(t *testing.T) {
 	u.Delete()
 
 	// Test:
-	var server DBServer
+	var server dbServerModel
 	db.First(&server)
 
 	// Isn't server empty struct?
@@ -78,7 +78,7 @@ func TestVPNDeinit(t *testing.T) {
 	}
 
 	// Test if Revoked table contains the removed user's entries.
-	var revoked DBRevoked
+	var revoked dbRevokedModel
 	db.First(&revoked)
 
 	if db.NewRecord(&revoked) {
@@ -89,7 +89,7 @@ func TestVPNDeinit(t *testing.T) {
 	Deinit()
 
 	// Get server from db.
-	var server2 DBServer
+	var server2 dbServerModel
 	db.First(&server2)
 
 	// Isn't server empty struct?
@@ -98,7 +98,7 @@ func TestVPNDeinit(t *testing.T) {
 	}
 
 	// Test if Revoked table contains the removed user's entries.
-	var revoked2 DBRevoked
+	var revoked2 dbRevokedModel
 	db.First(&revoked2)
 
 	// Is revoked empty?
@@ -110,8 +110,8 @@ func TestVPNDeinit(t *testing.T) {
 func TestVPNIsInitialized(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 
 	// Prepare:
 
@@ -133,8 +133,8 @@ func TestVPNIsInitialized(t *testing.T) {
 func TestVPNGetServerInstance(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 
 	// Prepare:
 
@@ -170,8 +170,8 @@ func TestVPNGetServerInstance(t *testing.T) {
 func TestVPNDumpsClientConfig(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -192,8 +192,8 @@ func TestVPNDumpsClientConfig(t *testing.T) {
 func TestVPNDumpClientConfig(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -250,8 +250,8 @@ func TestVPNDumpClientConfig(t *testing.T) {
 func TestVPNGetSystemCA(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 
 	// Prepare:
 
@@ -282,8 +282,8 @@ func TestVPNGetSystemCA(t *testing.T) {
 func TestVPNStartVPNProc(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 
 	// Prepare:
 
@@ -316,8 +316,8 @@ func TestVPNStartVPNProc(t *testing.T) {
 func TestVPNStopVPNProc(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -340,8 +340,8 @@ func TestVPNStopVPNProc(t *testing.T) {
 
 func TestVPNRestartVPNProc(t *testing.T) {
 	// Init:
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
@@ -369,8 +369,8 @@ func TestVPNRestartVPNProc(t *testing.T) {
 func TestVPNEmit(t *testing.T) {
 	// Init:
 	setupTestCase()
-	SetupDB("sqlite3", ":memory:")
-	defer CeaseDB()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
 	Init("localhost", "", UDPProto, "")
 
 	// Prepare: