Browse Source

refactor(cmd/ovpm): add structure && impr. testablity

- Prevent running `os.Exit(1)` when running in tests.
- Move main cli app under it's own factory function `NewApp`.
Mustafa Arici 8 years ago
parent
commit
d5763d7be5
4 changed files with 60 additions and 47 deletions
  1. 14 1
      cmd/ovpm/main.go
  2. 14 14
      cmd/ovpm/net.go
  3. 19 19
      cmd/ovpm/user.go
  4. 13 13
      cmd/ovpm/vpn.go

+ 14 - 1
cmd/ovpm/main.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"flag"
 	"os"
 
 	"github.com/Sirupsen/logrus"
@@ -10,7 +11,7 @@ import (
 
 var action string
 
-func main() {
+func NewApp() *cli.App {
 	app := cli.NewApp()
 	app.Name = "ovpm"
 	app.Usage = "OpenVPN Manager"
@@ -72,6 +73,10 @@ func main() {
 			},
 		},
 	}
+	return app
+}
+func main() {
+	app := NewApp()
 	app.Run(os.Args)
 }
 
@@ -83,3 +88,11 @@ func stringInSlice(a string, list []string) bool {
 	}
 	return false
 }
+
+func exit(status int) {
+	if flag.Lookup("test.v") == nil {
+		os.Exit(status)
+	} else {
+
+	}
+}

+ 14 - 14
cmd/ovpm/net.go

@@ -44,7 +44,7 @@ var netDefineCommand = cli.Command{
 
 		if name == "" || cidr == "" || typ == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		switch ovpm.NetworkTypeFromString(typ) {
@@ -53,7 +53,7 @@ var netDefineCommand = cli.Command{
 				fmt.Printf("validation error: `%s` must be a network in the IPv4 form", via)
 				fmt.Println()
 				fmt.Println(cli.ShowSubcommandHelp(c))
-				os.Exit(1)
+				exit(1)
 			}
 
 		case ovpm.SERVERNET:
@@ -61,7 +61,7 @@ var netDefineCommand = cli.Command{
 				fmt.Println("--via flag can only be used with --type ROUTE")
 				fmt.Println()
 				fmt.Println(cli.ShowSubcommandHelp(c))
-				os.Exit(1)
+				exit(1)
 			}
 		default: // Means UNDEFINEDNET
 			fmt.Printf("undefined network type %s", typ)
@@ -70,7 +70,7 @@ var netDefineCommand = cli.Command{
 			fmt.Println("    ", ovpm.GetAllNetworkTypes())
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -80,7 +80,7 @@ var netDefineCommand = cli.Command{
 		response, err := netSvc.Create(context.Background(), &pb.NetworkCreateRequest{Name: name, Cidr: cidr, Type: typ, Via: via})
 		if err != nil {
 			logrus.Errorf("network can not be created '%s': %v", name, err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("network created: %s (%s)", response.Network.Name, response.Network.Cidr)
@@ -101,7 +101,7 @@ var netListCommand = cli.Command{
 		resp, err := netSvc.List(context.Background(), &pb.NetworkListRequest{})
 		if err != nil {
 			logrus.Errorf("networks can not be fetched: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 
@@ -114,7 +114,7 @@ var netListCommand = cli.Command{
 			assocUsers, err := netSvc.GetAssociatedUsers(context.Background(), &pb.NetworkGetAssociatedUsersRequest{Name: network.Name})
 			if err != nil {
 				logrus.Errorf("assoc users can not be fetched: %v", err)
-				os.Exit(1)
+				exit(1)
 				return err
 			}
 
@@ -157,7 +157,7 @@ var netTypesCommand = cli.Command{
 		resp, err := netSvc.GetAllTypes(context.Background(), &pb.NetworkGetAllTypesRequest{})
 		if err != nil {
 			logrus.Errorf("networks can not be fetched: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		table := tablewriter.NewWriter(os.Stdout)
@@ -189,7 +189,7 @@ var netUndefineCommand = cli.Command{
 
 		if name == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -199,7 +199,7 @@ var netUndefineCommand = cli.Command{
 		resp, err := netSvc.Delete(context.Background(), &pb.NetworkDeleteRequest{Name: name})
 		if err != nil {
 			logrus.Errorf("networks can not be deleted: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("network deleted: %s (%s)", resp.Network.Name, resp.Network.Cidr)
@@ -230,7 +230,7 @@ var netAssociateCommand = cli.Command{
 
 		if netName == "" || userName == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -240,7 +240,7 @@ var netAssociateCommand = cli.Command{
 		_, err := netSvc.Associate(context.Background(), &pb.NetworkAssociateRequest{Name: netName, Username: userName})
 		if err != nil {
 			logrus.Errorf("networks can not be associated: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("network associated: user:%s <-> network:%s", userName, netName)
@@ -271,7 +271,7 @@ var netDissociateCommand = cli.Command{
 
 		if netName == "" || userName == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -281,7 +281,7 @@ var netDissociateCommand = cli.Command{
 		_, err := netSvc.Dissociate(context.Background(), &pb.NetworkDissociateRequest{Name: netName, Username: userName})
 		if err != nil {
 			logrus.Errorf("networks can not be dissociated: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("network dissociated: user:%s <-> network:%s", userName, netName)

+ 19 - 19
cmd/ovpm/user.go

@@ -28,14 +28,14 @@ var userListCommand = cli.Command{
 		server, err := vpnSvc.Status(context.Background(), &pb.VPNStatusRequest{})
 		if err != nil {
 			logrus.Errorf("can not get server status: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 
 		resp, err := userSvc.List(context.Background(), &pb.UserListRequest{})
 		if err != nil {
 			logrus.Errorf("users can not be fetched: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		table := tablewriter.NewWriter(os.Stdout)
@@ -95,13 +95,13 @@ var userCreateCommand = cli.Command{
 
 		if username == "" || password == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 		if static != "" && !govalidator.IsIPv4(static) {
 			fmt.Println("--static flag takes a valid ipv4 address")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 		var hostid uint32
 		if static != "" {
@@ -110,7 +110,7 @@ var userCreateCommand = cli.Command{
 				fmt.Printf("can not parse %s as IPv4", static)
 				fmt.Println()
 				fmt.Println(cli.ShowSubcommandHelp(c))
-				os.Exit(1)
+				exit(1)
 			}
 
 			hostid = h
@@ -126,7 +126,7 @@ var userCreateCommand = cli.Command{
 		)
 		if err != nil {
 			logrus.Errorf("user can not be created '%s': %v", username, err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("user created: %s", response.Users[0].Username)
@@ -185,7 +185,7 @@ var userUpdateCommand = cli.Command{
 
 		if username == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		// Check whether if all flags are are empty.
@@ -193,7 +193,7 @@ var userUpdateCommand = cli.Command{
 			fmt.Println("nothing is updated!")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		// Given that static is set, check whether it's IPv4.
@@ -201,7 +201,7 @@ var userUpdateCommand = cli.Command{
 			fmt.Println("--static flag takes a valid ipv4 address")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 		var staticPref pb.UserUpdateRequest_StaticPref
 		staticPref = pb.UserUpdateRequest_NOPREFSTATIC
@@ -216,7 +216,7 @@ var userUpdateCommand = cli.Command{
 					fmt.Printf("can't parse %s as IPv4", static)
 					fmt.Println()
 					fmt.Println(cli.ShowSubcommandHelp(c))
-					os.Exit(1)
+					exit(1)
 				}
 
 				hostid = h
@@ -232,7 +232,7 @@ var userUpdateCommand = cli.Command{
 			fmt.Println("--static flag and --no-static flag cannot be used together")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		case static == "" && !noStatic:
 		default:
 			// means no pref
@@ -252,7 +252,7 @@ var userUpdateCommand = cli.Command{
 			fmt.Println("you can't use --gw together with --no-gw")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		default:
 			gwPref = pb.UserUpdateRequest_NOPREF
 
@@ -272,7 +272,7 @@ var userUpdateCommand = cli.Command{
 			fmt.Println("you can't use --admin together with --no-admin")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		//conn := getConn(c.String("port"))
@@ -291,7 +291,7 @@ var userUpdateCommand = cli.Command{
 
 		if err != nil {
 			logrus.Errorf("user can not be updated '%s': %v", username, err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("user updated: %s", response.Users[0].Username)
@@ -315,7 +315,7 @@ var userDeleteCommand = cli.Command{
 
 		if username == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		//conn := getConn(c.String("port"))
@@ -326,7 +326,7 @@ var userDeleteCommand = cli.Command{
 		_, err := userSvc.Delete(context.Background(), &pb.UserDeleteRequest{Username: username})
 		if err != nil {
 			logrus.Errorf("user can not be deleted '%s': %v", username, err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("user deleted: %s", username)
@@ -350,7 +350,7 @@ var userRenewCommand = cli.Command{
 
 		if username == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		//conn := getConn(c.String("port"))
@@ -362,7 +362,7 @@ var userRenewCommand = cli.Command{
 		_, err := userSvc.Renew(context.Background(), &pb.UserRenewRequest{Username: username})
 		if err != nil {
 			logrus.Errorf("can't renew user cert '%s': %v", username, err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Infof("user cert renewed: '%s'", username)
@@ -391,7 +391,7 @@ var userGenconfigCommand = cli.Command{
 
 		if username == "" {
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 		if output == "" {
 			output = username + ".ovpn"

+ 13 - 13
cmd/ovpm/vpn.go

@@ -24,7 +24,7 @@ var vpnStatusCommand = cli.Command{
 
 		res, err := vpnSvc.Status(context.Background(), &pb.VPNStatusRequest{})
 		if err != nil {
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 
@@ -77,7 +77,7 @@ var vpnInitCommand = cli.Command{
 		if hostname == "" {
 			logrus.Errorf("'hostname' is required")
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 
 		}
 
@@ -98,7 +98,7 @@ var vpnInitCommand = cli.Command{
 			fmt.Println("--net takes an ip network in the CIDR form. e.g. 10.9.0.0/24")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		dns := c.String("dns")
@@ -106,7 +106,7 @@ var vpnInitCommand = cli.Command{
 			fmt.Println("--dns takes an IPv4 address. e.g. 8.8.8.8")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -122,7 +122,7 @@ var vpnInitCommand = cli.Command{
 			_, err := fmt.Scanln(&response)
 			if err != nil {
 				logrus.Fatal(err)
-				os.Exit(1)
+				exit(1)
 				return err
 			}
 			okayResponses := []string{"y", "Y", "yes", "Yes", "YES"}
@@ -130,7 +130,7 @@ var vpnInitCommand = cli.Command{
 			if stringInSlice(response, okayResponses) {
 				if _, err := vpnSvc.Init(context.Background(), &pb.VPNInitRequest{Hostname: hostname, Port: port, ProtoPref: proto, IpBlock: ipblock, Dns: dns}); err != nil {
 					logrus.Errorf("server can not be initialized: %v", err)
-					os.Exit(1)
+					exit(1)
 					return err
 				}
 				logrus.Info("ovpm server initialized")
@@ -166,7 +166,7 @@ var vpnUpdateCommand = cli.Command{
 			fmt.Println("--net takes an ip network in the CIDR form. e.g. 10.9.0.0/24")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		if ipblock != "" {
@@ -179,7 +179,7 @@ var vpnUpdateCommand = cli.Command{
 				_, err := fmt.Scanln(&response)
 				if err != nil {
 					logrus.Fatal(err)
-					os.Exit(1)
+					exit(1)
 					return err
 				}
 				okayResponses := []string{"y", "Y", "yes", "Yes", "YES"}
@@ -198,13 +198,13 @@ var vpnUpdateCommand = cli.Command{
 			fmt.Println("--dns takes an IPv4 address. e.g. 8.8.8.8")
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		if !(ipblock != "" || dns != "") {
 			fmt.Println()
 			fmt.Println(cli.ShowSubcommandHelp(c))
-			os.Exit(1)
+			exit(1)
 		}
 
 		conn := getConn(c.GlobalString("daemon-port"))
@@ -213,7 +213,7 @@ var vpnUpdateCommand = cli.Command{
 
 		if _, err := vpnSvc.Update(context.Background(), &pb.VPNUpdateRequest{IpBlock: ipblock, Dns: dns}); err != nil {
 			logrus.Errorf("server can not be updated: %v", err)
-			os.Exit(1)
+			exit(1)
 			return err
 		}
 		logrus.Info("ovpm server updated")
@@ -224,7 +224,7 @@ var vpnUpdateCommand = cli.Command{
 var vpnRestartCommand = cli.Command{
 	Name:    "restart",
 	Usage:   "Restart VPN server.",
-	Aliases: []string{"s"},
+	Aliases: []string{"r"},
 	Action: func(c *cli.Context) error {
 		conn := getConn(c.GlobalString("daemon-port"))
 		defer conn.Close()
@@ -232,7 +232,7 @@ var vpnRestartCommand = cli.Command{
 
 		_, err := vpnSvc.Restart(context.Background(), &pb.VPNRestartRequest{})
 		if err != nil {
-			os.Exit(1)
+			exit(1)
 			return err
 		}