Browse Source

refactor(vpn): move all funcs under Server struct

All funcs related to vpn server within the global scope are now relocated
as methods of Server struct.

TheServer() func always returns a singleton instance to the server.

This work increases testablity of the vpn management functions without
polluting the global scope.
It also improves structure and readablity.
Mustafa Arici 7 years ago
parent
commit
9ce7db14b1
9 changed files with 401 additions and 347 deletions
  1. 6 9
      api/rpc.go
  2. 6 6
      bindata/bindata.go
  3. 2 2
      cmd/ovpmd/main.go
  4. 19 24
      net.go
  5. 48 13
      net_test.go
  6. 26 49
      user.go
  7. 15 13
      user_test.go
  8. 158 165
      vpn.go
  9. 121 66
      vpn_test.go

+ 6 - 9
api/rpc.go

@@ -299,7 +299,7 @@ func (s *UserService) GenConfig(ctx context.Context, req *pb.UserGenConfigReques
 	}
 
 	if perms.Contains(ovpm.GenConfigAnyUserPerm) {
-		configBlob, err := ovpm.DumpsClientConfig(user.GetUsername())
+		configBlob, err := ovpm.TheServer().DumpsClientConfig(user.GetUsername())
 		if err != nil {
 			return nil, err
 		}
@@ -310,7 +310,7 @@ func (s *UserService) GenConfig(ctx context.Context, req *pb.UserGenConfigReques
 		if user.GetUsername() != username {
 			return nil, grpc.Errorf(codes.PermissionDenied, "Caller can only genconfig for their user.")
 		}
-		configBlob, err := ovpm.DumpsClientConfig(user.GetUsername())
+		configBlob, err := ovpm.TheServer().DumpsClientConfig(user.GetUsername())
 		if err != nil {
 			return nil, err
 		}
@@ -324,10 +324,7 @@ type VPNService struct{}
 
 func (s *VPNService) Status(ctx context.Context, req *pb.VPNStatusRequest) (*pb.VPNStatusResponse, error) {
 	logrus.Debugf("rpc call: vpn status")
-	server, err := ovpm.GetServerInstance()
-	if err != nil {
-		return nil, err
-	}
+	server := ovpm.TheServer()
 
 	perms, err := permset.FromContext(ctx)
 	if err != nil {
@@ -375,7 +372,7 @@ func (s *VPNService) Init(ctx context.Context, req *pb.VPNInitRequest) (*pb.VPNI
 		return nil, grpc.Errorf(codes.PermissionDenied, "ovpm.InitVPNPerm is required for this operation.")
 	}
 
-	if err := ovpm.Init(req.Hostname, req.Port, proto, req.IpBlock, req.Dns); err != nil {
+	if err := ovpm.TheServer().Init(req.Hostname, req.Port, proto, req.IpBlock, req.Dns); err != nil {
 		logrus.Errorf("server can not be created: %v", err)
 	}
 	return &pb.VPNInitResponse{}, nil
@@ -392,7 +389,7 @@ func (s *VPNService) Update(ctx context.Context, req *pb.VPNUpdateRequest) (*pb.
 		return nil, grpc.Errorf(codes.PermissionDenied, "ovpm.UpdateVPNPerm is required for this operation.")
 	}
 
-	if err := ovpm.Update(req.IpBlock, req.Dns); err != nil {
+	if err := ovpm.TheServer().Update(req.IpBlock, req.Dns); err != nil {
 		logrus.Errorf("server can not be updated: %v", err)
 	}
 	return &pb.VPNUpdateResponse{}, nil
@@ -409,7 +406,7 @@ func (s *VPNService) Restart(ctx context.Context, req *pb.VPNRestartRequest) (*p
 		return nil, grpc.Errorf(codes.PermissionDenied, "ovpm.UpdateVPNPerm is required for this operation.")
 	}
 
-	ovpm.RestartVPNProc()
+	ovpm.TheServer().RestartVPNProc()
 	return &pb.VPNRestartResponse{}, nil
 }
 

+ 6 - 6
bindata/bindata.go

@@ -93,7 +93,7 @@ func templateAuthSwaggerJson() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/auth.swagger.json", size: 2854, mode: os.FileMode(420), modTime: time.Unix(1522505931, 0)}
+	info := bindataFileInfo{name: "template/auth.swagger.json", size: 2854, mode: os.FileMode(420), modTime: time.Unix(1522534025, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -113,7 +113,7 @@ func templateBundleJs() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/bundle.js", size: 559458, mode: os.FileMode(420), modTime: time.Unix(1522505947, 0)}
+	info := bindataFileInfo{name: "template/bundle.js", size: 559458, mode: os.FileMode(420), modTime: time.Unix(1522534037, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -193,7 +193,7 @@ func templateIndexHtml() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/index.html", size: 577, mode: os.FileMode(420), modTime: time.Unix(1522505947, 0)}
+	info := bindataFileInfo{name: "template/index.html", size: 577, mode: os.FileMode(420), modTime: time.Unix(1522534037, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -233,7 +233,7 @@ func templateNetworkSwaggerJson() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/network.swagger.json", size: 6669, mode: os.FileMode(420), modTime: time.Unix(1522505931, 0)}
+	info := bindataFileInfo{name: "template/network.swagger.json", size: 6669, mode: os.FileMode(420), modTime: time.Unix(1522534025, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -273,7 +273,7 @@ func templateUserSwaggerJson() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/user.swagger.json", size: 6907, mode: os.FileMode(420), modTime: time.Unix(1522505931, 0)}
+	info := bindataFileInfo{name: "template/user.swagger.json", size: 6907, mode: os.FileMode(420), modTime: time.Unix(1522534025, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -293,7 +293,7 @@ func templateVpnSwaggerJson() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/vpn.swagger.json", size: 3732, mode: os.FileMode(420), modTime: time.Unix(1522505931, 0)}
+	info := bindataFileInfo{name: "template/vpn.swagger.json", size: 3732, mode: os.FileMode(420), modTime: time.Unix(1522534025, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }

+ 2 - 2
cmd/ovpmd/main.go

@@ -134,14 +134,14 @@ func (s *server) start() {
 	logrus.Infof("OVPM %s is running gRPC:%s, REST:%s ...", ovpm.Version, s.grpcPort, s.restPort)
 	go s.grpcServer.Serve(s.lis)
 	go http.ListenAndServe(":"+s.restPort, s.restServer)
-	ovpm.StartVPNProc()
+	ovpm.TheServer().StartVPNProc()
 }
 
 func (s *server) stop() {
 	logrus.Info("OVPM is shutting down ...")
 	s.grpcServer.Stop()
 	s.restCancel()
-	ovpm.StopVPNProc()
+	ovpm.TheServer().StopVPNProc()
 
 }
 

+ 19 - 24
net.go

@@ -103,7 +103,7 @@ type Network struct {
 
 // GetNetwork returns a network specified by its name.
 func GetNetwork(name string) (*Network, error) {
-	if !IsInitialized() {
+	if svr := TheServer(); !svr.IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
 	// Validate user input.
@@ -134,9 +134,10 @@ func GetAllNetworks() []*Network {
 
 // CreateNewNetwork creates a new network definition in the system.
 func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
-	if !IsInitialized() {
+	if svr := TheServer(); !svr.IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
+
 	// Validate user input.
 	if govalidator.IsNull(name) {
 		return nil, fmt.Errorf("validation error: %s can not be null", name)
@@ -185,7 +186,7 @@ func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Netw
 	if db.NewRecord(&network) {
 		return nil, fmt.Errorf("can not create network in the db")
 	}
-	EmitWithRestart()
+	TheServer().EmitWithRestart()
 	logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
 	return &Network{dbNetworkModel: network}, nil
 
@@ -193,19 +194,20 @@ func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Netw
 
 // Delete deletes a network definition in the system.
 func (n *Network) Delete() error {
-	if !IsInitialized() {
+	svr := TheServer()
+	if !svr.IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
 
 	db.Unscoped().Delete(n.dbNetworkModel)
-	EmitWithRestart()
+	svr.EmitWithRestart()
 	logrus.Infof("network deleted: %s", n.Name)
 	return nil
 }
 
 // Associate allows the given user access to this network.
 func (n *Network) Associate(username string) error {
-	if !IsInitialized() {
+	if svr := TheServer(); !svr.IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
 	user, err := GetUser(username)
@@ -231,14 +233,15 @@ func (n *Network) Associate(username string) error {
 	if userAssoc.Error != nil {
 		return fmt.Errorf("association failed: %v", userAssoc.Error)
 	}
-	EmitWithRestart()
+	TheServer().EmitWithRestart()
 	logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
 	return nil
 }
 
 // Dissociate breaks up the given users association to the said network.
 func (n *Network) Dissociate(username string) error {
-	if !IsInitialized() {
+	svr := TheServer()
+	if !svr.IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
 
@@ -265,7 +268,7 @@ func (n *Network) Dissociate(username string) error {
 	if userAssoc.Error != nil {
 		return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
 	}
-	EmitWithRestart()
+	svr.EmitWithRestart()
 	logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
 	return nil
 }
@@ -410,14 +413,10 @@ func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
 
 // vpnInterface returns the interface which belongs to the VPN server.
 func vpnInterface() *net.Interface {
-	server, err := GetServerInstance()
-	if err != nil {
-		logrus.Errorf("can't get server instance: %v", err)
-		return nil
-	}
+	svr := TheServer()
 
-	mask := net.IPMask(net.ParseIP(server.Mask))
-	prefix := net.ParseIP(server.Net)
+	mask := net.IPMask(net.ParseIP(svr.Mask))
+	prefix := net.ParseIP(svr.Net)
 	netw := prefix.Mask(mask).To4()
 	netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
 	ipnet := net.IPNet{IP: netw, Mask: mask}
@@ -491,20 +490,16 @@ func enableNat() error {
 	}
 
 	// Enable ip forwarding.
-	emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
+	TheServer().emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
 	ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	if err != nil {
 		return fmt.Errorf("can not create new iptables object: %v", err)
 	}
 
-	server, err := GetServerInstance()
-	if err != nil {
-		logrus.Errorf("can't get server instance: %v", err)
-		return nil
-	}
+	svr := TheServer()
 
-	mask := net.IPMask(net.ParseIP(server.Mask))
-	prefix := net.ParseIP(server.Net)
+	mask := net.IPMask(net.ParseIP(svr.Mask))
+	prefix := net.ParseIP(svr.Net)
 	netw := prefix.Mask(mask).To4()
 	netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
 	ipnet := net.IPNet{IP: netw, Mask: mask}

+ 48 - 13
net_test.go

@@ -12,7 +12,7 @@ func TestVPNCreateNewNetwork(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -80,7 +80,7 @@ func TestVPNDeleteNetwork(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -118,7 +118,7 @@ func TestVPNGetNetwork(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -153,7 +153,7 @@ func TestVPNGetAllNetworks(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -199,7 +199,7 @@ func TestNetAssociate(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -244,8 +244,8 @@ func TestNetDissociate(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	err := Init("localhost", "", UDPProto, "", "")
-	if err != nil {
+
+	if err := TheServer().Init("localhost", "", UDPProto, "", ""); err != nil {
 		t.Fatal(err)
 	}
 
@@ -297,7 +297,7 @@ func TestNetGetAssociatedUsers(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	// Test:
@@ -318,12 +318,14 @@ func TestNetGetAssociatedUsers(t *testing.T) {
 	}
 }
 
-func init() {
-	// Init
-	Testing = true
-}
-
 func TestNetworkTypeFromString(t *testing.T) {
+	// Initialize:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	TheServer().Init("localhost", "", UDPProto, "", "")
+
+	// Test
 	type args struct {
 		typ string
 	}
@@ -346,6 +348,13 @@ func TestNetworkTypeFromString(t *testing.T) {
 }
 
 func TestGetAllNetworkTypes(t *testing.T) {
+	// Initialize:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	TheServer().Init("localhost", "", UDPProto, "", "")
+
+	// Test
 	tests := []struct {
 		name string
 		want []NetworkType
@@ -362,6 +371,13 @@ func TestGetAllNetworkTypes(t *testing.T) {
 }
 
 func TestIsNetworkType(t *testing.T) {
+	// Initialize:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	TheServer().Init("localhost", "", UDPProto, "", "")
+
+	// Test
 	type args struct {
 		s string
 	}
@@ -384,6 +400,13 @@ func TestIsNetworkType(t *testing.T) {
 }
 
 func TestIncrementIP(t *testing.T) {
+	// Initialize:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	TheServer().Init("localhost", "", UDPProto, "", "")
+
+	// Test
 	type args struct {
 		ip   string
 		mask string
@@ -411,6 +434,13 @@ func TestIncrementIP(t *testing.T) {
 }
 
 func Test_routableIP(t *testing.T) {
+	// Initialize:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	TheServer().Init("localhost", "", UDPProto, "", "")
+
+	// Test
 	type args struct {
 		network string
 		ip      net.IP
@@ -433,3 +463,8 @@ func Test_routableIP(t *testing.T) {
 		})
 	}
 }
+
+func init() {
+	// Init
+	Testing = true
+}

+ 26 - 49
user.go

@@ -131,7 +131,8 @@ func GetAllUsers() ([]*User, error) {
 // It also generates the necessary client keys and signs certificates with the current
 // server's CA.
 func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bool) (*User, error) {
-	if !IsInitialized() {
+	svr := TheServer()
+	if !svr.IsInitialized() {
 		return nil, fmt.Errorf("you first need to create server")
 	}
 	// Validate user input.
@@ -145,7 +146,7 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 		return nil, fmt.Errorf("forbidden: username root is reserved and can not be used")
 	}
 
-	ca, err := GetSystemCA()
+	ca, err := svr.GetSystemCA()
 	if err != nil {
 		return nil, err
 	}
@@ -154,10 +155,6 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 	if err != nil {
 		return nil, fmt.Errorf("can not create client cert %s: %v", username, err)
 	}
-	server, err := GetServerInstance()
-	if err != nil {
-		return nil, fmt.Errorf("can not get server: %v", err)
-	}
 
 	if hostid != 0 {
 		ip := HostID2IP(hostid)
@@ -165,7 +162,7 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 			return nil, fmt.Errorf("host id doesn't represent an ip %d", hostid)
 		}
 
-		network := net.IPNet{IP: net.ParseIP(server.Net).To4(), Mask: net.IPMask(net.ParseIP(server.Mask).To4())}
+		network := net.IPNet{IP: net.ParseIP(svr.Net).To4(), Mask: net.IPMask(net.ParseIP(svr.Mask).To4())}
 		if !network.Contains(ip) {
 			return nil, fmt.Errorf("ip %s, is out of vpn network %s", ip, network.String())
 		}
@@ -176,8 +173,8 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 
 		// Check if requested ip is allocated to the VPN server itself.
 		serverNet := net.IPNet{
-			IP:   net.ParseIP(server.Net).To4(),
-			Mask: net.IPMask(net.ParseIP(server.Mask).To4()),
+			IP:   net.ParseIP(svr.Net).To4(),
+			Mask: net.IPMask(net.ParseIP(svr.Mask).To4()),
 		}
 
 		ip, ipnet, err := net.ParseCIDR(serverNet.String())
@@ -192,7 +189,7 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 		Username:           username,
 		Cert:               clientCert.Cert,
 		Key:                clientCert.Key,
-		ServerSerialNumber: server.SerialNumber,
+		ServerSerialNumber: svr.SerialNumber,
 		NoGW:               nogw,
 		HostID:             hostid,
 		Admin:              admin,
@@ -207,8 +204,7 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 	logrus.Infof("user created: %s", username)
 
 	// EmitWithRestart server config
-	err = EmitWithRestart()
-	if err != nil {
+	if err = svr.EmitWithRestart(); err != nil {
 		return nil, err
 	}
 	return &User{dbUserModel: user}, nil
@@ -218,7 +214,8 @@ func CreateNewUser(username, password string, nogw bool, hostid uint32, admin bo
 //
 // How this method works is similiar to PUT semantics of REST. It sets the user record fields to the provided function arguments.
 func (u *User) Update(password string, nogw bool, hostid uint32, admin bool) error {
-	if !IsInitialized() {
+	svr := TheServer()
+	if !svr.IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
 
@@ -232,17 +229,12 @@ func (u *User) Update(password string, nogw bool, hostid uint32, admin bool) err
 	u.Admin = admin
 
 	if hostid != 0 {
-		server, err := GetServerInstance()
-		if err != nil {
-			return fmt.Errorf("can not get server: %v", err)
-		}
-
 		ip := HostID2IP(hostid)
 		if ip == nil {
 			return fmt.Errorf("host id doesn't represent an ip %d", hostid)
 		}
 
-		network := net.IPNet{IP: net.ParseIP(server.Net).To4(), Mask: net.IPMask(net.ParseIP(server.Mask).To4())}
+		network := net.IPNet{IP: net.ParseIP(svr.Net).To4(), Mask: net.IPMask(net.ParseIP(svr.Mask).To4())}
 		if !network.Contains(ip) {
 			return fmt.Errorf("ip %s, is out of vpn network %s", ip, network.String())
 		}
@@ -253,11 +245,7 @@ func (u *User) Update(password string, nogw bool, hostid uint32, admin bool) err
 	}
 	db.Save(u.dbUserModel)
 
-	err := EmitWithRestart()
-	if err != nil {
-		return err
-	}
-	return nil
+	return svr.EmitWithRestart()
 }
 
 // Delete deletes a user by the given username from the database.
@@ -275,8 +263,8 @@ func (u *User) Delete() error {
 	})
 	db.Unscoped().Delete(u.dbUserModel)
 	logrus.Infof("user deleted: %s", u.GetUsername())
-	err = EmitWithRestart()
-	if err != nil {
+
+	if err = TheServer().EmitWithRestart(); err != nil {
 		return err
 	}
 	u = nil // delete the existing user struct
@@ -291,8 +279,7 @@ func (u *User) ResetPassword(password string) error {
 		return fmt.Errorf("user password can not be updated %s: %v", u.Username, err)
 	}
 	db.Save(u.dbUserModel)
-	err = EmitWithRestart()
-	if err != nil {
+	if err = TheServer().EmitWithRestart(); err != nil {
 		return err
 	}
 
@@ -307,10 +294,11 @@ func (u *User) ResetPassword(password string) error {
 //
 // Also it can be used when a user cert is expired or user's private key stolen, missing etc.
 func (u *User) Renew() error {
-	if !IsInitialized() {
+	svr := TheServer()
+	if !svr.IsInitialized() {
 		return fmt.Errorf("you first need to create server")
 	}
-	ca, err := GetSystemCA()
+	ca, err := svr.GetSystemCA()
 	if err != nil {
 		return err
 	}
@@ -320,18 +308,12 @@ func (u *User) Renew() error {
 		return fmt.Errorf("can not create client cert %s: %v", u.Username, err)
 	}
 
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
 	u.Cert = clientCert.Cert
 	u.Key = clientCert.Key
-	u.ServerSerialNumber = server.SerialNumber
+	u.ServerSerialNumber = svr.SerialNumber
 
 	db.Save(u.dbUserModel)
-	err = EmitWithRestart()
-	if err != nil {
+	if err = svr.EmitWithRestart(); err != nil {
 		return err
 	}
 
@@ -363,12 +345,9 @@ func (u *User) GetCreatedAt() string {
 func (u *User) getIP() net.IP {
 	users := getNonStaticHostUsers()
 	staticHostIDs := getStaticHostIDs()
-	server, err := GetServerInstance()
-	if err != nil {
-		logrus.Panicf("can not get server instance: %v", err)
-	}
-	mask := net.IPMask(net.ParseIP(server.Mask).To4())
-	network := net.ParseIP(server.Net).To4().Mask(mask)
+	svr := TheServer()
+	mask := net.IPMask(net.ParseIP(svr.Mask).To4())
+	network := net.ParseIP(svr.Net).To4().Mask(mask)
 
 	// If the user has static ip address, return it immediately.
 	if u.HostID != 0 {
@@ -399,11 +378,9 @@ func (u *User) getIP() net.IP {
 
 // GetIPNet returns user's vpn ip network. (e.g. 192.168.0.1/24)
 func (u *User) GetIPNet() string {
-	server, err := GetServerInstance()
-	if err != nil {
-		logrus.Panicf("can not get user ipnet: %v", err)
-	}
-	mask := net.IPMask(net.ParseIP(server.Mask).To4())
+	svr := TheServer()
+
+	mask := net.IPMask(net.ParseIP(svr.Mask).To4())
 
 	ipn := net.IPNet{
 		IP:   u.getIP(),

+ 15 - 13
user_test.go

@@ -13,8 +13,8 @@ func TestCreateNewUser(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
-	server, _ := ovpm.GetServerInstance()
+	svr := ovpm.TheServer()
+	svr.Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	username := "test.User"
@@ -44,7 +44,7 @@ func TestCreateNewUser(t *testing.T) {
 	}
 
 	// Is user's server serial number correct?
-	if !server.CheckSerial(user.ServerSerialNumber) {
+	if !svr.CheckSerial(user.ServerSerialNumber) {
 		t.Fatalf("user's ServerSerialNumber is expected to be 'CORRECT' but it is 'INCORRECT' instead %+v", user)
 	}
 
@@ -103,7 +103,7 @@ func TestUserUpdate(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	username := "testUser"
@@ -143,7 +143,7 @@ func TestUserPasswordCorrect(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	initialPassword := "g00dp@ssW0rd9"
@@ -160,7 +160,7 @@ func TestUserPasswordReset(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	initialPassword := "g00dp@ssW0rd9"
@@ -187,7 +187,7 @@ func TestUserDelete(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	username := "testUser"
@@ -225,7 +225,7 @@ func TestUserGet(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	username := "testUser"
@@ -249,7 +249,7 @@ func TestUserGetAll(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	ovpm.TheServer().Init("localhost", "", ovpm.UDPProto, "", "")
 	count := 5
 
 	// Prepare:
@@ -287,14 +287,15 @@ func TestUserRenew(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	svr := ovpm.TheServer()
+	svr.Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 	user, _ := ovpm.CreateNewUser("user", "1234", false, 0, true)
 
 	// Test:
 	// Re initialize the server.
-	ovpm.Init("example.com", "3333", ovpm.UDPProto, "", "") // This causes implicit Renew() on every user in the system.
+	svr.Init("example.com", "3333", ovpm.UDPProto, "", "") // This causes implicit Renew() on every user in the system.
 
 	// Fetch user back.
 	fetchedUser, _ := ovpm.GetUser(user.GetUsername())
@@ -309,7 +310,8 @@ func TestUserIPAllocator(t *testing.T) {
 	// Initialize:
 	db := ovpm.CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	ovpm.Init("localhost", "", ovpm.UDPProto, "", "")
+	svr := ovpm.TheServer()
+	svr.Init("localhost", "", ovpm.UDPProto, "", "")
 
 	// Prepare:
 
@@ -336,7 +338,7 @@ func TestUserIPAllocator(t *testing.T) {
 		}
 		if user != nil {
 			if user.GetIPNet() != tt.expectedIP {
-				t.Fatalf("user %s ip %s is expected to be %s", user.GetUsername(), user.GetIPNet(), tt.expectedIP)
+				t.Fatalf("user %s ip %s(%d) is expected to be %s", user.GetUsername(), user.GetIPNet(), user.GetHostID(), tt.expectedIP)
 			}
 		}
 	}

+ 158 - 165
vpn.go

@@ -3,12 +3,14 @@ package ovpm
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"math/big"
 	"net"
 	"os"
 	"os/exec"
 	"path/filepath"
 	"strings"
+	"sync"
 	"text/template"
 
 	"time"
@@ -48,98 +50,125 @@ type dbServerModel struct {
 	DNS      string // DNS servers to push to the clients.
 }
 
+var serverInstance *Server
+var once sync.Once
+
 // Server represents VPN server.
 type Server struct {
+	sync.Mutex
+
 	dbServerModel
 	webPort string
+
+	emitToFileFunc func(path, content string, mode uint) error
+	openFunc       func(path string) (io.Reader, error)
+}
+
+// TheServer returns a pointer to the server instance.
+func TheServer() *Server {
+	once.Do(func() {
+		// Initialize the server instance.
+		serverInstance = &Server{
+			emitToFileFunc: emitToFile,
+			openFunc: func(path string) (io.Reader, error) {
+				return os.Open(path)
+			},
+		}
+	})
+	if db != nil {
+		serverInstance.Refresh()
+	} else {
+		logrus.Warn("database is not connected yet. skipping server instance refresh")
+	}
+	return serverInstance
 }
 
 // CheckSerial takes a serial number and checks it against the current server's serial number.
-func (s *Server) CheckSerial(serial string) bool {
-	return serial == s.SerialNumber
+func (svr *Server) CheckSerial(serial string) bool {
+	return serial == svr.SerialNumber
 }
 
 // GetSerialNumber returns server's serial number.
-func (s *Server) GetSerialNumber() string {
-	return s.SerialNumber
+func (svr *Server) GetSerialNumber() string {
+	return svr.SerialNumber
 }
 
 // GetServerName returns server's name.
-func (s *Server) GetServerName() string {
-	if s.Name != "" {
-		return s.Name
+func (svr *Server) GetServerName() string {
+	if svr.Name != "" {
+		return svr.Name
 	}
 	return "default"
 }
 
 // GetHostname returns vpn server's hostname.
-func (s *Server) GetHostname() string {
-	return s.Hostname
+func (svr *Server) GetHostname() string {
+	return svr.Hostname
 }
 
 // GetPort returns vpn server's port.
-func (s *Server) GetPort() string {
-	if s.Port != "" {
-		return s.Port
+func (svr *Server) GetPort() string {
+	if svr.Port != "" {
+		return svr.Port
 	}
 	return DefaultVPNPort
 
 }
 
 // GetProto returns vpn server's proto.
-func (s *Server) GetProto() string {
-	if s.Proto != "" {
-		return s.Proto
+func (svr *Server) GetProto() string {
+	if svr.Proto != "" {
+		return svr.Proto
 	}
 	return DefaultVPNProto
 }
 
 // GetCert returns vpn server's cert.
-func (s *Server) GetCert() string {
-	return s.Cert
+func (svr *Server) GetCert() string {
+	return svr.Cert
 }
 
 // GetKey returns vpn server's key.
-func (s *Server) GetKey() string {
-	return s.Key
+func (svr *Server) GetKey() string {
+	return svr.Key
 }
 
 // GetCACert returns vpn server's cacert.
-func (s *Server) GetCACert() string {
-	return s.CACert
+func (svr *Server) GetCACert() string {
+	return svr.CACert
 }
 
 // GetCAKey returns vpn server's cakey.
-func (s *Server) GetCAKey() string {
-	return s.CAKey
+func (svr *Server) GetCAKey() string {
+	return svr.CAKey
 }
 
 // GetNet returns vpn server's net.
-func (s *Server) GetNet() string {
-	return s.Net
+func (svr *Server) GetNet() string {
+	return svr.Net
 }
 
 // GetMask returns vpn server's mask.
-func (s *Server) GetMask() string {
-	return s.Mask
+func (svr *Server) GetMask() string {
+	return svr.Mask
 }
 
 // GetCRL returns vpn server's crl.
-func (s *Server) GetCRL() string {
-	return s.CRL
+func (svr *Server) GetCRL() string {
+	return svr.CRL
 }
 
 // GetDNS returns vpn server's dns.
-func (s *Server) GetDNS() string {
-	if s.DNS != "" {
-		return s.DNS
+func (svr *Server) GetDNS() string {
+	if svr.DNS != "" {
+		return svr.DNS
 	}
 	return DefaultVPNDNS
 }
 
 // GetCreatedAt returns server's created at.
-func (s *Server) GetCreatedAt() string {
-	return s.CreatedAt.Format(time.UnixDate)
+func (svr *Server) GetCreatedAt() string {
+	return svr.CreatedAt.Format(time.UnixDate)
 }
 
 // Init regenerates keys and certs for a Root CA, gets initial settings for the VPN server
@@ -152,7 +181,7 @@ func (s *Server) GetCreatedAt() string {
 //
 // Please note that, Init is potentially destructive procedure, it will cause invalidation of
 // existing .ovpn profiles of the current users. So it should be used carefully.
-func Init(hostname string, port string, proto string, ipblock string, dns string) error {
+func (svr *Server) Init(hostname string, port string, proto string, ipblock string, dns string) error {
 	if port == "" {
 		port = DefaultVPNPort
 	}
@@ -204,8 +233,8 @@ func Init(hostname string, port string, proto string, ipblock string, dns string
 	}
 
 	serverName := "default"
-	if IsInitialized() {
-		if err := Deinit(); err != nil {
+	if svr := TheServer(); svr.IsInitialized() {
+		if err := svr.Deinit(); err != nil {
 			logrus.Errorf("server can not be deleted: %v", err)
 			return err
 		}
@@ -268,40 +297,35 @@ func Init(hostname string, port string, proto string, ipblock string, dns string
 		user.HostID = 0
 		db.Save(&user.dbUserModel)
 	}
-	EmitWithRestart()
+	TheServer().EmitWithRestart()
 	logrus.Infof("server initialized")
 	return nil
 }
 
 // Update updates VPN server attributes.
-func Update(ipblock string, dns string) error {
-	if !IsInitialized() {
+func (svr *Server) Update(ipblock string, dns string) error {
+	if !svr.IsInitialized() {
 		return fmt.Errorf("server is not initialized")
 	}
 
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
 	var changed bool
 	if ipblock != "" && govalidator.IsCIDR(ipblock) {
 		var ipnet *net.IPNet
-		_, ipnet, err = net.ParseCIDR(ipblock)
+		_, ipnet, err := net.ParseCIDR(ipblock)
 		if err != nil {
 			return fmt.Errorf("can not parse CIDR %s: %v", ipblock, err)
 		}
-		server.dbServerModel.Net = ipnet.IP.To4().String()
-		server.dbServerModel.Mask = net.IP(ipnet.Mask).To4().String()
+		svr.dbServerModel.Net = ipnet.IP.To4().String()
+		svr.dbServerModel.Mask = net.IP(ipnet.Mask).To4().String()
 		changed = true
 	}
 
 	if dns != "" && govalidator.IsIPv4(dns) {
-		server.dbServerModel.DNS = dns
+		svr.dbServerModel.DNS = dns
 		changed = true
 	}
 	if changed {
-		db.Save(server.dbServerModel)
+		db.Save(svr.dbServerModel)
 		users, err := GetAllUsers()
 		if err != nil {
 			return err
@@ -314,37 +338,32 @@ func Update(ipblock string, dns string) error {
 			db.Save(user.dbUserModel)
 		}
 
-		EmitWithRestart()
+		svr.EmitWithRestart()
 		logrus.Infof("server updated")
 	}
 	return nil
 }
 
 // Deinit deletes the VPN server from the database and frees the allocated resources.
-func Deinit() error {
-	if !IsInitialized() {
+func (svr *Server) Deinit() error {
+	if !svr.IsInitialized() {
 		return fmt.Errorf("server not found")
 	}
 
 	db.Unscoped().Delete(&dbServerModel{})
 	db.Unscoped().Delete(&dbRevokedModel{})
-	EmitWithRestart()
+	svr.EmitWithRestart()
 	return nil
 }
 
 // DumpsClientConfig generates .ovpn file for the given vpn user and returns it as a string.
-func DumpsClientConfig(username string) (string, error) {
+func (svr *Server) DumpsClientConfig(username string) (string, error) {
 	var result bytes.Buffer
 	user, err := GetUser(username)
 	if err != nil {
 		return "", err
 	}
 
-	server, err := GetServerInstance()
-	if err != nil {
-		return "", err
-	}
-
 	params := struct {
 		Hostname string
 		Port     string
@@ -354,13 +373,13 @@ func DumpsClientConfig(username string) (string, error) {
 		NoGW     bool
 		Proto    string
 	}{
-		Hostname: server.GetHostname(),
-		Port:     server.GetPort(),
-		CA:       server.GetCACert(),
+		Hostname: svr.GetHostname(),
+		Port:     svr.GetPort(),
+		CA:       svr.GetCACert(),
 		Key:      user.getKey(),
 		Cert:     user.GetCert(),
 		NoGW:     user.IsNoGW(),
-		Proto:    server.GetProto(),
+		Proto:    svr.GetProto(),
 	}
 	data, err := bindata.Asset("template/client.ovpn.tmpl")
 	if err != nil {
@@ -381,18 +400,18 @@ func DumpsClientConfig(username string) (string, error) {
 }
 
 // DumpClientConfig generates .ovpn file for the given vpn user and dumps it to outPath.
-func DumpClientConfig(username, path string) error {
-	result, err := DumpsClientConfig(username)
+func (svr *Server) DumpClientConfig(username, path string) error {
+	result, err := svr.DumpsClientConfig(username)
 	if err != nil {
 		return err
 	}
 	// Wite rendered content into openvpn server conf.
-	return emitToFile(path, result, 0)
+	return svr.emitToFile(path, result, 0)
 
 }
 
 // GetSystemCA returns the system CA from the database if available.
-func GetSystemCA() (*pki.CA, error) {
+func (svr *Server) GetSystemCA() (*pki.CA, error) {
 	server := dbServerModel{}
 	db.First(&server)
 	if db.NewRecord(&server) {
@@ -411,8 +430,8 @@ func GetSystemCA() (*pki.CA, error) {
 var vpnProc supervisor.Supervisable
 
 // StartVPNProc starts the OpenVPN process.
-func StartVPNProc() {
-	if !IsInitialized() {
+func (svr *Server) StartVPNProc() {
+	if !svr.IsInitialized() {
 		logrus.Error("can not launch OpenVPN because system is not initialized")
 		return
 	}
@@ -423,27 +442,27 @@ func StartVPNProc() {
 		logrus.Error("OpenVPN is already started")
 		return
 	}
-	Emit()
+	svr.Emit()
 	vpnProc.Start()
 	ensureNatEnabled()
 }
 
 // RestartVPNProc restarts the OpenVPN process.
-func RestartVPNProc() {
-	if !IsInitialized() {
+func (svr *Server) RestartVPNProc() {
+	if !svr.IsInitialized() {
 		logrus.Error("can not launch OpenVPN because system is not initialized")
 		return
 	}
 	if vpnProc == nil {
 		panic(fmt.Sprintf("vpnProc is not initialized!"))
 	}
-	Emit()
+	svr.Emit()
 	vpnProc.Restart()
 	ensureNatEnabled()
 }
 
 // StopVPNProc stops the OpenVPN process.
-func StopVPNProc() {
+func (svr *Server) StopVPNProc() {
 	if vpnProc == nil {
 		panic(fmt.Sprintf("vpnProc is not initialized!"))
 	}
@@ -455,7 +474,7 @@ func StopVPNProc() {
 }
 
 // Emit generates all needed files for the OpenVPN server and dumps them to their corresponding paths defined in the config.
-func Emit() error {
+func (svr *Server) Emit() error {
 	// Check dependencies
 	if !checkOpenVPNExecutable() {
 		return fmt.Errorf("openvpn executable can not be found! you should install OpenVPN on this machine")
@@ -470,43 +489,43 @@ func Emit() error {
 		return fmt.Errorf("iptables executable can not be found")
 	}
 
-	if !IsInitialized() {
+	if !svr.IsInitialized() {
 		return fmt.Errorf("you should create a server first. e.g. $ ovpm vpn create-server")
 	}
 
-	if err := emitServerConf(); err != nil {
+	if err := svr.emitServerConf(); err != nil {
 		return fmt.Errorf("can not emit server conf: %s", err)
 	}
 
-	if err := emitServerCert(); err != nil {
+	if err := svr.emitServerCert(); err != nil {
 		return fmt.Errorf("can not emit server cert: %s", err)
 	}
 
-	if err := emitServerKey(); err != nil {
+	if err := svr.emitServerKey(); err != nil {
 		return fmt.Errorf("can not emit server key: %s", err)
 	}
 
-	if err := emitCACert(); err != nil {
+	if err := svr.emitCACert(); err != nil {
 		return fmt.Errorf("can not emit ca cert : %s", err)
 	}
 
-	if err := emitCAKey(); err != nil {
+	if err := svr.emitCAKey(); err != nil {
 		return fmt.Errorf("can not emit ca key: %s", err)
 	}
 
-	if err := emitDHParams(); err != nil {
+	if err := svr.emitDHParams(); err != nil {
 		return fmt.Errorf("can not emit dhparams: %s", err)
 	}
 
-	if err := emitCCD(); err != nil {
+	if err := svr.emitCCD(); err != nil {
 		return fmt.Errorf("can not emit ccd: %s", err)
 	}
 
-	if err := emitIptables(); err != nil {
+	if err := svr.emitIptables(); err != nil {
 		return fmt.Errorf("can not emit iptables: %s", err)
 	}
 
-	if err := emitCRL(); err != nil {
+	if err := svr.emitCRL(); err != nil {
 		return fmt.Errorf("can not emit crl: %s", err)
 	}
 
@@ -515,16 +534,15 @@ func Emit() error {
 }
 
 // EmitWithRestart restarts vpnProc after calling EmitWithRestart().
-func EmitWithRestart() error {
-	err := Emit()
-	if err != nil {
+func (svr *Server) EmitWithRestart() error {
+	if err := svr.Emit(); err != nil {
 		return err
 	}
-	if IsInitialized() {
+	if svr.IsInitialized() {
 		for {
 			if vpnProc.Status() == supervisor.RUNNING || vpnProc.Status() == supervisor.STOPPED {
 				logrus.Info("OpenVPN process is restarting")
-				RestartVPNProc()
+				svr.RestartVPNProc()
 				break
 			}
 			time.Sleep(1 * time.Second)
@@ -535,6 +553,12 @@ func EmitWithRestart() error {
 
 }
 
+// emitToFile is a proxy that calls svr.emitToFileFunc.
+func (svr *Server) emitToFile(path, content string, mode uint) error {
+	return svr.emitToFileFunc(path, content, mode)
+}
+
+// emitToFile is an implementation for svr.emitToFileFunc.
 func emitToFile(path, content string, mode uint) error {
 	// When testing don't emit files to the filesystem. Just pretend you did.
 	if Testing {
@@ -553,16 +577,7 @@ func emitToFile(path, content string, mode uint) error {
 	return nil
 }
 
-func emitServerConf() error {
-	dbServer, err := GetServerInstance()
-	if err != nil {
-		return fmt.Errorf("can not get server instance: %v", err)
-	}
-
-	serverInstance, err := GetServerInstance()
-	if err != nil {
-		return fmt.Errorf("can not retrieve server: %v", err)
-	}
+func (svr *Server) emitServerConf() error {
 	port := DefaultVPNPort
 	if serverInstance.Port != "" {
 		port = serverInstance.Port
@@ -601,8 +616,8 @@ func emitServerConf() error {
 		CCDPath:      _DefaultVPNCCDPath,
 		CRLPath:      _DefaultCRLPath,
 		DHParamsPath: _DefaultDHParamsPath,
-		Net:          dbServer.Net,
-		Mask:         dbServer.Mask,
+		Net:          svr.Net,
+		Mask:         svr.Mask,
 		Port:         port,
 		Proto:        proto,
 		DNS:          dns,
@@ -623,26 +638,32 @@ func emitServerConf() error {
 	}
 
 	// Wite rendered content into openvpn server conf.
-	return emitToFile(_DefaultVPNConfPath, result.String(), 0)
+	return svr.emitToFile(_DefaultVPNConfPath, result.String(), 0)
 }
 
-// GetServerInstance returns the default server from the database.
-func GetServerInstance() (*Server, error) {
-	var server dbServerModel
-	db.First(&server)
-	if db.NewRecord(server) {
-		return nil, fmt.Errorf("can not retrieve server from db")
+// Refresh synchronizes the server instance from db.
+func (svr *Server) Refresh() error {
+	//db = CreateDB("sqlite3", "")
+	var dbServer dbServerModel
+	fmt.Println(db)
+	q := db.First(&dbServer)
+	if err := q.Error; err != nil {
+		return fmt.Errorf("can't get server from db: %v", err)
 	}
-	return &Server{dbServerModel: server}, nil
+	if q.RecordNotFound() {
+		return fmt.Errorf("server is not initialized")
+	}
+	svr.dbServerModel = dbServer
+	return nil
 }
 
 // GetConnectedUsers will return a list of users who are currently connected
 // to the VPN service.
-func GetConnectedUsers() ([]User, error) {
+func (svr *Server) GetConnectedUsers() ([]User, error) {
 	var users []User
 
 	// Open the status log file.
-	f, err := os.Open(_DefaultStatusLogPath)
+	f, err := svr.openFunc(_DefaultStatusLogPath)
 	if err != nil {
 		panic(err)
 	}
@@ -673,36 +694,29 @@ func GetConnectedUsers() ([]User, error) {
 }
 
 // IsInitialized checks if there is a default VPN server configured in the database or not.
-func IsInitialized() bool {
-	var server dbServerModel
-	db.First(&server)
-	if db.NewRecord(server) {
+func (svr *Server) IsInitialized() bool {
+	var serverModel dbServerModel
+	q := db.First(&serverModel)
+	if err := q.Error; err != nil {
+		logrus.Errorf("can't retrieve server from db: %v", err)
+	}
+	if q.RecordNotFound() {
 		return false
 	}
 	return true
 }
 
-func emitServerKey() error {
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
+func (svr *Server) emitServerKey() error {
 	// Write rendered content into key file.
-	return emitToFile(_DefaultKeyPath, server.Key, 0600)
+	return svr.emitToFile(_DefaultKeyPath, svr.Key, 0600)
 }
 
-func emitServerCert() error {
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
+func (svr *Server) emitServerCert() error {
 	// Write rendered content into the cert file.
-	return emitToFile(_DefaultCertPath, server.Cert, 0)
+	return svr.emitToFile(_DefaultCertPath, svr.Cert, 0)
 }
 
-func emitCRL() error {
+func (svr *Server) emitCRL() error {
 	var revokedDBItems []*dbRevokedModel
 	db.Find(&revokedDBItems)
 	var revokedCertSerials []*big.Int
@@ -711,7 +725,7 @@ func emitCRL() error {
 		bi.SetString(item.SerialNumber, 16)
 		revokedCertSerials = append(revokedCertSerials, bi)
 	}
-	systemCA, err := GetSystemCA()
+	systemCA, err := svr.GetSystemCA()
 	if err != nil {
 		return fmt.Errorf("can not emit CRL: %v", err)
 	}
@@ -720,30 +734,20 @@ func emitCRL() error {
 		return fmt.Errorf("can not emit crl: %v", err)
 	}
 
-	return emitToFile(_DefaultCRLPath, crl, 0)
+	return svr.emitToFile(_DefaultCRLPath, crl, 0)
 }
 
-func emitCACert() error {
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
+func (svr *Server) emitCACert() error {
 	// Write rendered content into the ca cert file.
-	return emitToFile(_DefaultCACertPath, server.CACert, 0)
+	return svr.emitToFile(_DefaultCACertPath, svr.CACert, 0)
 }
 
-func emitCAKey() error {
-	server, err := GetServerInstance()
-	if err != nil {
-		return err
-	}
-
+func (svr *Server) emitCAKey() error {
 	// Write rendered content into the ca key file.
-	return emitToFile(_DefaultCAKeyPath, server.CAKey, 0600)
+	return svr.emitToFile(_DefaultCAKeyPath, svr.CAKey, 0600)
 }
 
-func emitCCD() error {
+func (svr *Server) emitCCD() error {
 	users, err := GetAllUsers()
 	if err != nil {
 		return err
@@ -770,11 +774,6 @@ func emitCCD() error {
 			}
 		}
 	}
-	server, err := GetServerInstance()
-	if err != nil {
-		return fmt.Errorf("can not get server instance: %v", err)
-	}
-
 	// Render ccd templates for the users.
 	for _, user := range users {
 		var associatedRoutes [][3]string
@@ -814,7 +813,7 @@ func emitCCD() error {
 			Routes     [][3]string // [0] is IP, [1] is Netmask, [2] is Via
 			Servernets [][2]string // [0] is IP, [1] is Netmask
 			RedirectGW bool
-		}{IP: user.getIP().String(), NetMask: server.Mask, Routes: associatedRoutes, Servernets: serverNets, RedirectGW: !user.NoGW}
+		}{IP: user.getIP().String(), NetMask: svr.Mask, Routes: associatedRoutes, Servernets: serverNets, RedirectGW: !user.NoGW}
 
 		data, err := bindata.Asset("template/ccd.file.tmpl")
 		if err != nil {
@@ -829,16 +828,14 @@ func emitCCD() error {
 		if err != nil {
 			return fmt.Errorf("can not render ccd file %s: %s", user.Username, err)
 		}
-
-		err = emitToFile(filepath.Join(_DefaultVPNCCDPath, user.Username), result.String(), 0)
-		if err != nil {
+		if err = svr.emitToFile(filepath.Join(_DefaultVPNCCDPath, user.Username), result.String(), 0); err != nil {
 			return err
 		}
 	}
 	return nil
 }
 
-func emitDHParams() error {
+func (svr *Server) emitDHParams() error {
 	var result bytes.Buffer
 	data, err := bindata.Asset("template/dh4096.pem.tmpl")
 	if err != nil {
@@ -855,14 +852,10 @@ func emitDHParams() error {
 		return fmt.Errorf("can not render dh4096.pem file: %s", err)
 	}
 
-	err = emitToFile(_DefaultDHParamsPath, result.String(), 0)
-	if err != nil {
-		return err
-	}
-	return nil
+	return svr.emitToFile(_DefaultDHParamsPath, result.String(), 0)
 }
 
-func emitIptables() error {
+func (svr *Server) emitIptables() error {
 	if Testing {
 		return nil
 	}

+ 121 - 66
vpn_test.go

@@ -1,11 +1,14 @@
 package ovpm
 
 import (
+	"bytes"
+	"fmt"
+	"io"
+	"reflect"
 	"strings"
 	"testing"
 
 	"github.com/Sirupsen/logrus"
-	"github.com/bouk/monkey"
 	"github.com/cad/ovpm/supervisor"
 )
 
@@ -35,13 +38,13 @@ func TestVPNInit(t *testing.T) {
 	}
 
 	// Wrongfully initialize server.
-	err := Init("localhost", "asdf", UDPProto, "", "")
-	if err == nil {
+
+	if err := TheServer().Init("localhost", "asdf", UDPProto, "", ""); err == nil {
 		t.Fatalf("error is expected to be not nil but it's nil instead")
 	}
 
 	// Initialize the server.
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Check database if the database has no server.
 	var server2 dbServerModel
@@ -61,7 +64,7 @@ func TestVPNDeinit(t *testing.T) {
 
 	// Prepare:
 	// Initialize the server.
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 	u, err := CreateNewUser("user", "p", false, 0, true)
 	if err != nil {
 		t.Fatal(err)
@@ -86,7 +89,7 @@ func TestVPNDeinit(t *testing.T) {
 	}
 
 	// Deinitialize.
-	Deinit()
+	TheServer().Deinit()
 
 	// Get server from db.
 	var server2 dbServerModel
@@ -112,7 +115,7 @@ func TestVPNUpdate(t *testing.T) {
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
 	// Prepare:
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 	// Test:
 
 	var updatetests = []struct {
@@ -127,20 +130,17 @@ func TestVPNUpdate(t *testing.T) {
 		{"9.9.9.0/24", "1.1.1.1", true, true},
 	}
 	for _, tt := range updatetests {
-		server, err := GetServerInstance()
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		oldIP := server.Net
-		oldDNS := server.DNS
-		Update(tt.vpnnet, tt.dns)
-		server = nil
-		server, err = GetServerInstance()
-		if (server.Net != oldIP) != tt.vpnChanged {
+		svr := TheServer()
+
+		oldIP := svr.Net
+		oldDNS := svr.DNS
+		svr.Update(tt.vpnnet, tt.dns)
+		svr = nil
+		svr = TheServer()
+		if (svr.Net != oldIP) != tt.vpnChanged {
 			t.Fatalf("expected vpn change: %t but opposite happened", tt.vpnChanged)
 		}
-		if (server.DNS != oldDNS) != tt.dnsChanged {
+		if (svr.DNS != oldDNS) != tt.dnsChanged {
 			t.Fatalf("expected vpn change: %t but opposite happened", tt.dnsChanged)
 		}
 	}
@@ -157,20 +157,20 @@ func TestVPNIsInitialized(t *testing.T) {
 
 	// Test:
 	// Is initialized?
-	if IsInitialized() {
+	if TheServer().IsInitialized() {
 		t.Fatalf("IsInitialized() is expected to return false but it returned true")
 	}
 
 	// Initialize the server.
-	Init("localhost", "", UDPProto, "", "")
+	TheServer().Init("localhost", "", UDPProto, "", "")
 
 	// Isn't initialized?
-	if !IsInitialized() {
+	if !TheServer().IsInitialized() {
 		t.Fatalf("IsInitialized() is expected to return true but it returned false")
 	}
 }
 
-func TestVPNGetServerInstance(t *testing.T) {
+func TestVPNTheServer(t *testing.T) {
 	// Init:
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
@@ -179,31 +179,21 @@ func TestVPNGetServerInstance(t *testing.T) {
 	// Prepare:
 
 	// Test:
-	server, err := GetServerInstance()
-
-	// Is it nil?
-	if err == nil {
-		t.Fatalf("GetServerInstance() is expected to give error since server is not initialized yet, but it gave no error instead")
-	}
+	svr := TheServer()
 
 	// Isn't server nil?
-	if server != nil {
-		t.Fatal("server is expected to be nil but it's not")
+	if svr.IsInitialized() {
+		t.Fatal("server is expected to be not initialized it is")
 	}
 
 	// Initialize server.
-	Init("localhost", "", UDPProto, "", "")
+	svr.Init("localhost", "", UDPProto, "", "")
 
-	server, err = GetServerInstance()
-
-	// Isn't it nil?
-	if err != nil {
-		t.Fatalf("GetServerInstance() is expected to give no error since server is initialized yet, but it gave error instead")
-	}
+	svr = TheServer()
 
 	// Is server nil?
-	if server == nil {
-		t.Fatal("server is expected to be not nil but it is")
+	if !svr.IsInitialized() {
+		t.Fatal("server is expected to be initialized but it's not")
 	}
 }
 
@@ -212,13 +202,14 @@ func TestVPNDumpsClientConfig(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	user, _ := CreateNewUser("user", "password", false, 0, true)
 
 	// Test:
-	clientConfigBlob, err := DumpsClientConfig(user.GetUsername())
+	clientConfigBlob, err := svr.DumpsClientConfig(user.GetUsername())
 	if err != nil {
 		t.Fatalf("expected to dump client config but we got error instead: %v", err)
 	}
@@ -234,7 +225,8 @@ func TestVPNDumpClientConfig(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	noGW := false
@@ -244,8 +236,7 @@ func TestVPNDumpClientConfig(t *testing.T) {
 	}
 
 	// Test:
-	err = DumpClientConfig(user.GetUsername(), "/tmp/user.ovpn")
-	if err != nil {
+	if err = svr.DumpClientConfig(user.GetUsername(), "/tmp/user.ovpn"); err != nil {
 		t.Fatalf("expected to dump client config but we got error instead: %v", err)
 	}
 
@@ -271,8 +262,7 @@ func TestVPNDumpClientConfig(t *testing.T) {
 		t.Fatalf("can not create user: %v", err)
 	}
 
-	err = DumpClientConfig(user.GetUsername(), "/tmp/user.ovpn")
-	if err != nil {
+	if err = TheServer().DumpClientConfig(user.GetUsername(), "/tmp/user.ovpn"); err != nil {
 		t.Fatalf("expected to dump client config but we got error instead: %v", err)
 	}
 
@@ -288,17 +278,18 @@ func TestVPNGetSystemCA(t *testing.T) {
 	defer db.Cease()
 
 	// Prepare:
+	svr := TheServer()
 
 	// Test:
-	ca, err := GetSystemCA()
+	ca, err := svr.GetSystemCA()
 	if err == nil {
 		t.Fatalf("GetSystemCA() is expected to give error but it didn't instead")
 	}
 
 	// Initialize system.
-	Init("localhost", "", UDPProto, "", "")
+	svr.Init("localhost", "", UDPProto, "", "")
 
-	ca, err = GetSystemCA()
+	ca, err = svr.GetSystemCA()
 	if err != nil {
 		t.Fatalf("GetSystemCA() is expected to get system ca, but it gave us an error instead: %v", err)
 	}
@@ -318,6 +309,7 @@ func TestVPNStartVPNProc(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
+	svr := TheServer()
 
 	// Prepare:
 
@@ -328,7 +320,7 @@ func TestVPNStartVPNProc(t *testing.T) {
 	}
 
 	// Call start without server initialization.
-	StartVPNProc()
+	svr.StartVPNProc()
 
 	// Isn't it still stopped?
 	if vpnProc.Status() != supervisor.STOPPED {
@@ -336,10 +328,10 @@ func TestVPNStartVPNProc(t *testing.T) {
 	}
 
 	// Initialize OVPM server.
-	Init("localhost", "", UDPProto, "", "")
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Call start again..
-	StartVPNProc()
+	svr.StartVPNProc()
 
 	// Isn't it RUNNING?
 	if vpnProc.Status() != supervisor.RUNNING {
@@ -352,7 +344,8 @@ func TestVPNStopVPNProc(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 	vpnProc.Start()
@@ -364,7 +357,7 @@ func TestVPNStopVPNProc(t *testing.T) {
 	}
 
 	// Call stop.
-	StopVPNProc()
+	svr.StopVPNProc()
 
 	// Isn't it stopped?
 	if vpnProc.Status() != supervisor.STOPPED {
@@ -376,15 +369,15 @@ func TestVPNRestartVPNProc(t *testing.T) {
 	// Init:
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 
 	// Test:
 
 	// Call restart.
-
-	RestartVPNProc()
+	svr.RestartVPNProc()
 
 	// Isn't it running?
 	if vpnProc.Status() != supervisor.RUNNING {
@@ -392,7 +385,7 @@ func TestVPNRestartVPNProc(t *testing.T) {
 	}
 
 	// Call restart again.
-	RestartVPNProc()
+	svr.RestartVPNProc()
 
 	// Isn't it running?
 	if vpnProc.Status() != supervisor.RUNNING {
@@ -405,12 +398,13 @@ func TestVPNEmit(t *testing.T) {
 	setupTestCase()
 	CreateDB("sqlite3", ":memory:")
 	defer db.Cease()
-	Init("localhost", "", UDPProto, "", "")
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
 
 	// Prepare:
 
 	// Test:
-	Emit()
+	svr.Emit()
 
 	var emittests = []string{
 		_DefaultVPNConfPath,
@@ -433,6 +427,7 @@ func TestVPNEmit(t *testing.T) {
 
 func TestVPNemitToFile(t *testing.T) {
 	// Initialize:
+
 	// Prepare:
 	path := "/test/file"
 	content := "blah blah blah"
@@ -444,8 +439,8 @@ func TestVPNemitToFile(t *testing.T) {
 	}
 
 	// Emit the contents.
-	err := emitToFile(path, content, 0)
-	if err != nil {
+
+	if err := TheServer().emitToFile(path, content, 0); err != nil {
 		t.Fatalf("expected  to be able to emit to the filesystem but we got this error instead: %v", err)
 	}
 
@@ -475,15 +470,75 @@ func (f *fakeProcess) Status() supervisor.State {
 	return f.state
 }
 
+func TestGetConnectedUsers(t *testing.T) {
+	// Init:
+	setupTestCase()
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+	svr := TheServer()
+	svr.Init("localhost", "", UDPProto, "", "")
+	// Change openr from os.Open to our mock file injecter.
+	svr.openFunc = func(path string) (io.Reader, error) {
+		const testLog = `
+OpenVPN CLIENT LIST
+Updated,Mon Mar 26 13:26:10 2018
+Common Name,Real Address,Bytes Received,Bytes Sent,Connected Since
+google.DNS,8.8.8.8:53246,527914279,3204562859,Sat Mar 17 16:26:38 2018
+google1.DNS,8.8.4.4:33974,42727443,291595456,Mon Mar 26 08:24:08 2018
+ROUTING TABLE
+Virtual Address,Common Name,Real Address,Last Ref
+10.20.30.6,google.DNS,8.8.8.8:33974,Mon Mar 26 13:26:04 2018
+10.20.30.5,google1.DNS,8.8.4.4:53246,Mon Mar 26 13:25:57 2018
+GLOBAL STATS
+Max bcast/mcast queue length,4
+END
+`
+		return bytes.NewBufferString(testLog), nil
+	}
+
+	// Create the corresponding users for test.
+	if _, err := CreateNewUser("usr1", "1234", true, 0, false); err != nil {
+		t.Fatalf("user creation failed: %v", err)
+	}
+	if _, err := CreateNewUser("usr2", "1234", true, 0, false); err != nil {
+		t.Fatalf("user creation failed: %v", err)
+	}
+
+	// Test:
+	tests := []struct {
+		name    string
+		want    []User
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := TheServer().GetConnectedUsers()
+			if (err != nil) != tt.wantErr {
+				t.Errorf("GetConnectedUsers() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("GetConnectedUsers() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
 func init() {
 	// Init
 	Testing = true
 	fs = make(map[string]string)
+
+	CreateDB("sqlite3", ":memory:")
+	defer db.Cease()
+
 	// Monkeypatch emitToFile()
-	monkey.Patch(emitToFile, func(path, content string, mode uint) error {
+	fmt.Println(TheServer())
+	TheServer().emitToFileFunc = func(path, content string, mode uint) error {
 		fs[path] = content
 		return nil
-	})
-
+	}
 	vpnProc = &fakeProcess{state: supervisor.STOPPED}
 }