소스 검색

feat(vpn): add option to specify custom cidr to use during vpn init

- If custom cidr is specified as the ipblock, ovpm will initialize the
  database accordingly. If it's not specified (empty string) then it
  will fall back to default network defined in the const.go.
Mustafa Arici 8 년 전
부모
커밋
87dfa703d7
11개의 변경된 파일153개의 추가작업 그리고 73개의 파일을 삭제
  1. 2 1
      api/rpc.go
  2. 15 2
      cmd/ovpm/vpn.go
  3. 3 3
      const.go
  4. 7 7
      net_test.go
  5. 36 27
      pb/vpn.pb.go
  6. 1 0
      pb/vpn.proto
  7. 3 0
      pb/vpn.swagger.json
  8. 11 3
      user.go
  9. 10 10
      user_test.go
  10. 53 8
      vpn.go
  11. 12 12
      vpn_test.go

+ 2 - 1
api/rpc.go

@@ -187,7 +187,8 @@ func (s *VPNService) Init(ctx context.Context, req *pb.VPNInitRequest) (*pb.VPNI
 	case pb.VPNProto_NOPREF:
 		proto = ovpm.UDPProto
 	}
-	if err := ovpm.Init(req.Hostname, req.Port, proto); err != nil {
+
+	if err := ovpm.Init(req.Hostname, req.Port, proto, req.IPBlock); err != nil {
 		logrus.Errorf("server can not be created: %v", err)
 	}
 	return &pb.VPNInitResponse{}, nil

+ 15 - 2
cmd/ovpm/vpn.go

@@ -6,6 +6,7 @@ import (
 	"os"
 
 	"github.com/Sirupsen/logrus"
+	"github.com/asaskevich/govalidator"
 	"github.com/cad/ovpm"
 	"github.com/cad/ovpm/pb"
 	"github.com/olekukonko/tablewriter"
@@ -60,12 +61,16 @@ var vpnInitCommand = cli.Command{
 			Name:  "tcp, t",
 			Usage: "use TCP for vpn protocol, instead of UDP",
 		},
+		cli.StringFlag{
+			Name:  "net, n",
+			Usage: fmt.Sprintf("VPN network to give clients IP addresses from, in the CIDR form (default: %s)", ovpm.DefaultVPNNetwork),
+		},
 	},
 	Action: func(c *cli.Context) error {
 		action = "vpn:init"
 		hostname := c.String("hostname")
 		if hostname == "" {
-			logrus.Errorf("'hostname' is needed")
+			logrus.Errorf("'hostname' is required")
 			fmt.Println(cli.ShowSubcommandHelp(c))
 			os.Exit(1)
 
@@ -83,6 +88,14 @@ var vpnInitCommand = cli.Command{
 			proto = pb.VPNProto_TCP
 		}
 
+		ipblock := c.String("net")
+		if ipblock != "" && !govalidator.IsCIDR(ipblock) {
+			fmt.Println("--net takes an ip network in the CIDR form. e.g. 10.9.0.0/24")
+			fmt.Println()
+			fmt.Println(cli.ShowSubcommandHelp(c))
+			os.Exit(1)
+		}
+
 		conn := getConn(c.GlobalString("daemon-port"))
 		defer conn.Close()
 		vpnSvc := pb.NewVPNServiceClient(conn)
@@ -102,7 +115,7 @@ var vpnInitCommand = cli.Command{
 			okayResponses := []string{"y", "Y", "yes", "Yes", "YES"}
 			nokayResponses := []string{"n", "N", "no", "No", "NO"}
 			if stringInSlice(response, okayResponses) {
-				if _, err := vpnSvc.Init(context.Background(), &pb.VPNInitRequest{Hostname: hostname, Port: port, Protopref: proto}); err != nil {
+				if _, err := vpnSvc.Init(context.Background(), &pb.VPNInitRequest{Hostname: hostname, Port: port, Protopref: proto, IPBlock: ipblock}); err != nil {
 					logrus.Errorf("server can not be initialized: %v", err)
 					os.Exit(1)
 					return err

+ 3 - 3
const.go

@@ -10,6 +10,9 @@ const (
 	// DefaultVPNProto is the default OpenVPN protocol to use.
 	DefaultVPNProto = UDPProto
 
+	// DefaultVPNNetwork is the default OpenVPN network to use.
+	DefaultVPNNetwork = "10.9.0.0/24"
+
 	etcBasePath = "/etc/ovpm/"
 	varBasePath = "/var/db/ovpm/"
 
@@ -23,9 +26,6 @@ const (
 	_DefaultCAKeyPath    = varBasePath + "ca.key"
 	_DefaultDHParamsPath = varBasePath + "dh4096.pem"
 	_DefaultCRLPath      = varBasePath + "crl.pem"
-
-	_DefaultServerNetwork = "10.9.0.0"
-	_DefaultServerNetMask = "255.255.255.0"
 )
 
 // Testing is used to determine wether we are testing or running normally.

+ 7 - 7
net_test.go

@@ -9,7 +9,7 @@ func TestVPNCreateNewNetwork(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:
@@ -56,7 +56,7 @@ func TestVPNDeleteNetwork(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:
@@ -94,7 +94,7 @@ func TestVPNGetNetwork(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:
@@ -129,7 +129,7 @@ func TestVPNGetAllNetworks(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:
@@ -175,7 +175,7 @@ func TestNetAssociate(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:
@@ -213,7 +213,7 @@ func TestNetDissociate(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	err := Init("localhost", "", UDPProto)
+	err := Init("localhost", "", UDPProto, "")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -266,7 +266,7 @@ func TestNetGetAssociatedUsers(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	// Test:

+ 36 - 27
pb/vpn.pb.go

@@ -54,6 +54,7 @@ type VPNInitRequest struct {
 	Hostname  string   `protobuf:"bytes,1,opt,name=Hostname" json:"Hostname,omitempty"`
 	Port      string   `protobuf:"bytes,2,opt,name=Port" json:"Port,omitempty"`
 	Protopref VPNProto `protobuf:"varint,3,opt,name=Protopref,enum=pb.VPNProto" json:"Protopref,omitempty"`
+	IPBlock   string   `protobuf:"bytes,4,opt,name=IPBlock" json:"IPBlock,omitempty"`
 }
 
 func (m *VPNInitRequest) Reset()                    { *m = VPNInitRequest{} }
@@ -82,6 +83,13 @@ func (m *VPNInitRequest) GetProtopref() VPNProto {
 	return VPNProto_NOPREF
 }
 
+func (m *VPNInitRequest) GetIPBlock() string {
+	if m != nil {
+		return m.IPBlock
+	}
+	return ""
+}
+
 type VPNStatusResponse struct {
 	Name         string `protobuf:"bytes,1,opt,name=Name" json:"Name,omitempty"`
 	SerialNumber string `protobuf:"bytes,2,opt,name=SerialNumber" json:"SerialNumber,omitempty"`
@@ -294,31 +302,32 @@ var _VPNService_serviceDesc = grpc.ServiceDesc{
 func init() { proto.RegisterFile("vpn.proto", fileDescriptor1) }
 
 var fileDescriptor1 = []byte{
-	// 406 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x92, 0xc1, 0x72, 0xd3, 0x30,
-	0x10, 0x86, 0xb1, 0x93, 0xba, 0xf1, 0x4e, 0x26, 0x38, 0xdb, 0x02, 0x22, 0xd3, 0x43, 0x47, 0xa7,
-	0x4c, 0x0e, 0xf1, 0x50, 0x6e, 0xbd, 0x75, 0x0c, 0x4c, 0x39, 0x60, 0x34, 0x2e, 0xe4, 0x2e, 0x83,
-	0xe8, 0x78, 0x48, 0x25, 0x21, 0x29, 0xbe, 0xc3, 0x2b, 0xf0, 0x12, 0xbc, 0x0f, 0xaf, 0xc0, 0x83,
-	0x30, 0x92, 0xdd, 0x26, 0x66, 0x86, 0xdb, 0xbf, 0xdf, 0x66, 0x77, 0xff, 0xfc, 0x16, 0xa4, 0xad,
-	0x96, 0x6b, 0x6d, 0x94, 0x53, 0x18, 0xeb, 0x7a, 0x71, 0x76, 0xab, 0xd4, 0xed, 0x56, 0xe4, 0x5c,
-	0x37, 0x39, 0x97, 0x52, 0x39, 0xee, 0x1a, 0x25, 0x6d, 0xf7, 0x0b, 0x8a, 0x90, 0x6d, 0x58, 0x79,
-	0xe3, 0xb8, 0xdb, 0xd9, 0x4a, 0x7c, 0xdb, 0x09, 0xeb, 0xe8, 0x16, 0x66, 0x1b, 0x56, 0xbe, 0x95,
-	0x8d, 0xeb, 0x09, 0x2e, 0x60, 0x72, 0xad, 0xac, 0x93, 0xfc, 0x4e, 0x90, 0xe8, 0x3c, 0x5a, 0xa6,
-	0xd5, 0x43, 0x8d, 0x08, 0x63, 0xa6, 0x8c, 0x23, 0x71, 0xe0, 0x41, 0xe3, 0x0a, 0x52, 0xe6, 0xd7,
-	0x6b, 0x23, 0xbe, 0x90, 0xd1, 0x79, 0xb4, 0x9c, 0x5d, 0x4c, 0xd7, 0xba, 0x5e, 0x6f, 0x58, 0x19,
-	0x78, 0xb5, 0x6f, 0xd3, 0xef, 0x31, 0xcc, 0x0f, 0x2c, 0x58, 0xad, 0xa4, 0x0d, 0x5b, 0xcb, 0xfd,
-	0xb5, 0xa0, 0x91, 0xc2, 0xf4, 0x46, 0x98, 0x86, 0x6f, 0xcb, 0xdd, 0x5d, 0x2d, 0x4c, 0x7f, 0x71,
-	0xc0, 0x06, 0x4e, 0x47, 0xff, 0x71, 0x3a, 0x3e, 0x70, 0x8a, 0x30, 0x2e, 0x84, 0x71, 0xe4, 0xa8,
-	0x63, 0x5e, 0xe3, 0x53, 0x48, 0x8a, 0xab, 0x40, 0x93, 0x40, 0xfb, 0x0a, 0x33, 0x18, 0x95, 0xc2,
-	0x91, 0xe3, 0x00, 0xbd, 0xf4, 0xd3, 0xef, 0xb8, 0xfd, 0x4a, 0x26, 0xdd, 0xb4, 0xd7, 0x78, 0x06,
-	0x69, 0x61, 0x04, 0x77, 0xe2, 0xf3, 0x95, 0x23, 0x69, 0x68, 0xec, 0x01, 0x9e, 0xc2, 0x51, 0xf8,
-	0xeb, 0x04, 0x42, 0xa7, 0x2b, 0xe8, 0x1c, 0x1e, 0x3f, 0x24, 0xde, 0x05, 0xb0, 0x5a, 0xc2, 0xe4,
-	0x3e, 0x2d, 0x04, 0x48, 0xca, 0xf7, 0xac, 0x7a, 0xfd, 0x26, 0x7b, 0x84, 0xc7, 0x30, 0xfa, 0xf8,
-	0x8a, 0x65, 0x91, 0x17, 0x1f, 0x0a, 0x96, 0xc5, 0x17, 0xbf, 0x22, 0x00, 0x1f, 0xa0, 0x30, 0x6d,
-	0xf3, 0x49, 0x20, 0x83, 0xa4, 0xcb, 0x12, 0x4f, 0xfb, 0xc8, 0x07, 0x5f, 0x77, 0xf1, 0xe4, 0x1f,
-	0xda, 0xdd, 0xa3, 0xcf, 0x7f, 0xfc, 0xfe, 0xf3, 0x33, 0x3e, 0xa1, 0xb3, 0xbc, 0x7d, 0x91, 0xb7,
-	0x5a, 0xe6, 0x36, 0xf4, 0x2f, 0xa3, 0x15, 0x5e, 0xc3, 0xd8, 0x5b, 0x43, 0xec, 0x27, 0x0f, 0x5e,
-	0xc6, 0xe2, 0x64, 0xc0, 0xfa, 0x5d, 0xcf, 0xc2, 0xae, 0x39, 0x9d, 0xde, 0xef, 0x6a, 0x64, 0xe3,
-	0x2e, 0xa3, 0x55, 0x9d, 0x84, 0x47, 0xf7, 0xf2, 0x6f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xfc, 0x94,
-	0xe6, 0xba, 0xa3, 0x02, 0x00, 0x00,
+	// 423 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x92, 0xcf, 0x72, 0xd3, 0x30,
+	0x10, 0xc6, 0xb1, 0xe3, 0x3a, 0xf1, 0x4e, 0x26, 0x38, 0xdb, 0x02, 0x22, 0xd3, 0x43, 0x47, 0xa7,
+	0x4c, 0x0e, 0xf1, 0x50, 0x6e, 0xbd, 0x95, 0x00, 0xd3, 0x1e, 0x30, 0x9a, 0x14, 0x72, 0x57, 0x8a,
+	0xe8, 0x78, 0x9a, 0x4a, 0x42, 0x52, 0x72, 0x87, 0x03, 0x2f, 0xc0, 0x4b, 0xf0, 0x3e, 0xbc, 0x02,
+	0x0f, 0xc2, 0x48, 0x76, 0xfe, 0x98, 0x99, 0xde, 0xbe, 0xfd, 0xad, 0x77, 0xbd, 0xfb, 0xad, 0x20,
+	0xdb, 0x68, 0x39, 0xd5, 0x46, 0x39, 0x85, 0xb1, 0x5e, 0x8e, 0x4e, 0xef, 0x94, 0xba, 0x5b, 0x89,
+	0x82, 0xeb, 0xaa, 0xe0, 0x52, 0x2a, 0xc7, 0x5d, 0xa5, 0xa4, 0xad, 0xbf, 0xa0, 0x08, 0xf9, 0x82,
+	0x95, 0x37, 0x8e, 0xbb, 0xb5, 0x9d, 0x8b, 0x6f, 0x6b, 0x61, 0x1d, 0xfd, 0x19, 0xc1, 0x60, 0xc1,
+	0xca, 0x6b, 0x59, 0xb9, 0x06, 0xe1, 0x08, 0x7a, 0x57, 0xca, 0x3a, 0xc9, 0x1f, 0x04, 0x89, 0xce,
+	0xa2, 0x71, 0x36, 0xdf, 0xc5, 0x88, 0x90, 0x30, 0x65, 0x1c, 0x89, 0x03, 0x0f, 0x1a, 0x27, 0x90,
+	0x31, 0xdf, 0x5f, 0x1b, 0xf1, 0x95, 0x74, 0xce, 0xa2, 0xf1, 0xe0, 0xbc, 0x3f, 0xd5, 0xcb, 0xe9,
+	0x82, 0x95, 0x81, 0xcf, 0xf7, 0x69, 0x24, 0xd0, 0xbd, 0x66, 0x6f, 0x56, 0xea, 0xf6, 0x9e, 0x24,
+	0xa1, 0xc5, 0x36, 0xa4, 0xdf, 0x63, 0x18, 0x1e, 0x4c, 0x67, 0xb5, 0x92, 0x36, 0xfc, 0xaf, 0xdc,
+	0xcf, 0x11, 0x34, 0x52, 0xe8, 0xdf, 0x08, 0x53, 0xf1, 0x55, 0xb9, 0x7e, 0x58, 0x0a, 0xd3, 0xcc,
+	0xd2, 0x62, 0xad, 0x1d, 0x3a, 0x8f, 0xec, 0x90, 0x1c, 0xec, 0x80, 0x90, 0xcc, 0x84, 0x71, 0xe4,
+	0xa8, 0x66, 0x5e, 0xe3, 0x73, 0x48, 0x67, 0x97, 0x81, 0xa6, 0x81, 0x36, 0x11, 0xe6, 0xd0, 0x29,
+	0x85, 0x23, 0xdd, 0x00, 0xbd, 0xf4, 0xd5, 0x1f, 0xb8, 0xbd, 0x27, 0xbd, 0xba, 0xda, 0x6b, 0x3c,
+	0x85, 0x6c, 0x66, 0x04, 0x77, 0xe2, 0xcb, 0xa5, 0x23, 0x59, 0x48, 0xec, 0x01, 0x9e, 0xc0, 0x51,
+	0x30, 0x85, 0x40, 0xc8, 0xd4, 0x01, 0x1d, 0xc2, 0xd3, 0xdd, 0x2d, 0x6a, 0x03, 0x26, 0x63, 0xe8,
+	0x6d, 0x7d, 0x44, 0x80, 0xb4, 0xfc, 0xc8, 0xe6, 0xef, 0xde, 0xe7, 0x4f, 0xb0, 0x0b, 0x9d, 0xcf,
+	0x6f, 0x59, 0x1e, 0x79, 0xf1, 0x69, 0xc6, 0xf2, 0xf8, 0xfc, 0x77, 0x04, 0xe0, 0x0d, 0x14, 0x66,
+	0x53, 0xdd, 0x0a, 0x64, 0x90, 0xd6, 0x5e, 0xe2, 0x49, 0x73, 0x8c, 0xd6, 0xe1, 0x47, 0xcf, 0xfe,
+	0xa3, 0xf5, 0xff, 0xe8, 0xcb, 0x1f, 0x7f, 0xfe, 0xfe, 0x8a, 0x8f, 0xe9, 0xa0, 0xd8, 0xbc, 0x2a,
+	0x36, 0x5a, 0x16, 0x36, 0xe4, 0x2f, 0xa2, 0x09, 0x5e, 0x41, 0xe2, 0x47, 0x43, 0x6c, 0x2a, 0x0f,
+	0xde, 0xcc, 0xe8, 0xb8, 0xc5, 0x9a, 0x5e, 0x2f, 0x42, 0xaf, 0x21, 0xed, 0x6f, 0x7b, 0x55, 0xb2,
+	0x72, 0x17, 0xd1, 0x64, 0x99, 0x86, 0xf7, 0xf8, 0xfa, 0x5f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x93,
+	0x52, 0x74, 0x1b, 0xbe, 0x02, 0x00, 0x00,
 }

+ 1 - 0
pb/vpn.proto

@@ -15,6 +15,7 @@ message VPNInitRequest {
   string Hostname = 1;
   string Port = 2;
   VPNProto Protopref = 3;
+  string IPBlock = 4;
 }
 
 service VPNService {

+ 3 - 0
pb/vpn.swagger.json

@@ -80,6 +80,9 @@
         },
         "Protopref": {
           "$ref": "#/definitions/pbVPNProto"
+        },
+        "IPBlock": {
+          "type": "string"
         }
       }
     },

+ 11 - 3
user.go

@@ -301,8 +301,12 @@ func (u *DBUser) GetCreatedAt() string {
 func (u *DBUser) getIP() net.IP {
 	users := getNonStaticHostUsers()
 	staticHostIDs := getStaticHostIDs()
-	mask := net.IPMask(net.ParseIP(_DefaultServerNetMask).To4())
-	network := net.ParseIP(_DefaultServerNetwork).To4().Mask(mask)
+	server, err := GetServerInstance()
+	if err != nil {
+		logrus.Panicf("can not get server instance: %v", err)
+	}
+	mask := net.IPMask(net.ParseIP(server.Mask).To4())
+	network := net.ParseIP(server.Net).To4().Mask(mask)
 
 	// If the user has static ip address, return it immediately.
 	if u.HostID != 0 {
@@ -333,7 +337,11 @@ func (u *DBUser) getIP() net.IP {
 
 // GetIPNet returns user's vpn ip network. (e.g. 192.168.0.1/24)
 func (u *DBUser) GetIPNet() string {
-	mask := net.IPMask(net.ParseIP(_DefaultServerNetMask).To4())
+	server, err := GetServerInstance()
+	if err != nil {
+		logrus.Panicf("can not get user ipnet: %v", err)
+	}
+	mask := net.IPMask(net.ParseIP(server.Mask).To4())
 
 	ipn := net.IPNet{
 		IP:   u.getIP(),

+ 10 - 10
user_test.go

@@ -13,7 +13,7 @@ func TestCreateNewUser(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 	server, _ := ovpm.GetServerInstance()
 
 	// Prepare:
@@ -89,7 +89,7 @@ func TestUserUpdate(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	username := "testUser"
@@ -127,7 +127,7 @@ func TestUserPasswordCorrect(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	initialPassword := "g00dp@ssW0rd9"
@@ -144,7 +144,7 @@ func TestUserPasswordReset(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	initialPassword := "g00dp@ssW0rd9"
@@ -171,7 +171,7 @@ func TestUserDelete(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	username := "testUser"
@@ -209,7 +209,7 @@ func TestUserGet(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	username := "testUser"
@@ -233,7 +233,7 @@ func TestUserGetAll(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 	count := 5
 
 	// Prepare:
@@ -271,14 +271,14 @@ func TestUserRenew(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 	user, _ := ovpm.CreateNewUser("user", "1234", false, 0)
 
 	// Test:
 	// Re initialize the server.
-	ovpm.Init("example.com", "3333", ovpm.UDPProto) // This causes implicit Renew() on every user in the system.
+	ovpm.Init("example.com", "3333", ovpm.UDPProto, "") // This causes implicit Renew() on every user in the system.
 
 	// Fetch user back.
 	fetchedUser, _ := ovpm.GetUser(user.GetUsername())
@@ -293,7 +293,7 @@ func TestUserIPAllocator(t *testing.T) {
 	// Initialize:
 	ovpm.SetupDB("sqlite3", ":memory:")
 	defer ovpm.CeaseDB()
-	ovpm.Init("localhost", "", ovpm.UDPProto)
+	ovpm.Init("localhost", "", ovpm.UDPProto, "")
 
 	// Prepare:
 

+ 53 - 8
vpn.go

@@ -69,10 +69,17 @@ type _VPNServerConfig struct {
 	Proto        string
 }
 
-// Init regenerates keys and certs for a Root CA, and saves them in the database.
+// Init regenerates keys and certs for a Root CA, gets initial settings for the VPN server
+// and saves them in the database.
 //
-// proto can be either "udp" or "tcp" and if it's "" it defaults to "udp".
-func Init(hostname string, port string, proto string) error {
+// 'proto' can be either "udp" or "tcp" and if it's "" it defaults to "udp".
+//
+// 'ipblock' is a IP network in the CIDR form. VPN clients get their IP addresses from this network.
+// It defaults to const 'DefaultVPNNetwork'.
+//
+// Please note that, Init is potentially destructive procedure, it will cause invalidation of
+// existing .ovpn profiles of the current users. So it should be used carefully.
+func Init(hostname string, port string, proto string, ipblock string) error {
 	if port == "" {
 		port = DefaultVPNPort
 	}
@@ -88,6 +95,33 @@ func Init(hostname string, port string, proto string) error {
 		return fmt.Errorf("validation error: proto:`%s` should be either 'tcp' or 'udp'", proto)
 	}
 
+	// vpn network to use.
+	var ipnet *net.IPNet
+
+	// If user didn't specify, pick the vpn network from defaults.
+	if ipblock == "" {
+		var err error
+		_, ipnet, err = net.ParseCIDR(DefaultVPNNetwork)
+		if err != nil {
+			return fmt.Errorf("can not parse CIDR %s: %v", DefaultVPNNetwork, err)
+		}
+	}
+
+	// Check if the user specified vpn network is valid.
+	if ipblock != "" && !govalidator.IsCIDR(ipblock) {
+		return fmt.Errorf("validation error: ipblock:`%s` should be a CIDR network", ipblock)
+	}
+
+	// Use user specified vpn network.
+	if ipblock != "" {
+		var err error
+		_, ipnet, err = net.ParseCIDR(ipblock)
+		if err != nil {
+			return fmt.Errorf("can parse ipblock: %s", err)
+
+		}
+	}
+
 	if !govalidator.IsNumeric(port) {
 		return fmt.Errorf("validation error: port:`%s` should be numeric", port)
 	}
@@ -113,6 +147,7 @@ func Init(hostname string, port string, proto string) error {
 	if err != nil {
 		return fmt.Errorf("can not create server cert creds: %s", err)
 	}
+
 	serialNumber := uuid.New().String()
 	serverInstance := DBServer{
 		Name: serverName,
@@ -125,8 +160,8 @@ func Init(hostname string, port string, proto string) error {
 		Key:          srv.Key,
 		CACert:       ca.Cert,
 		CAKey:        ca.Key,
-		Net:          _DefaultServerNetwork,
-		Mask:         _DefaultServerNetMask,
+		Net:          ipnet.IP.To4().String(),
+		Mask:         net.IP(ipnet.Mask).To4().String(),
 	}
 
 	db.Create(&serverInstance)
@@ -375,6 +410,11 @@ func emitToFile(path, content string, mode uint) error {
 }
 
 func emitServerConf() error {
+	dbServer, err := GetServerInstance()
+	if err != nil {
+		return fmt.Errorf("can not get server instance: %v", err)
+	}
+
 	serverInstance, err := GetServerInstance()
 	if err != nil {
 		return fmt.Errorf("can not retrieve server: %v", err)
@@ -399,8 +439,8 @@ func emitServerConf() error {
 		CCDPath:      _DefaultVPNCCDPath,
 		CRLPath:      _DefaultCRLPath,
 		DHParamsPath: _DefaultDHParamsPath,
-		Net:          _DefaultServerNetwork,
-		Mask:         _DefaultServerNetMask,
+		Net:          dbServer.Net,
+		Mask:         dbServer.Mask,
 		Port:         port,
 		Proto:        proto,
 	}
@@ -531,6 +571,11 @@ func emitCCD() error {
 			}
 		}
 	}
+	server, err := GetServerInstance()
+	if err != nil {
+		return fmt.Errorf("can not get server instance: %v", err)
+	}
+
 	// Render ccd templates for the users.
 	for _, user := range users {
 		var associatedRoutes [][3]string
@@ -555,7 +600,7 @@ func emitCCD() error {
 			NetMask    string
 			Routes     [][3]string // [0] is IP, [1] is Netmask, [2] is Via
 			RedirectGW bool
-		}{IP: user.getIP().String(), NetMask: _DefaultServerNetMask, Routes: associatedRoutes, RedirectGW: !user.NoGW}
+		}{IP: user.getIP().String(), NetMask: server.Mask, Routes: associatedRoutes, RedirectGW: !user.NoGW}
 
 		data, err := bindata.Asset("template/ccd.file.tmpl")
 		if err != nil {

+ 12 - 12
vpn_test.go

@@ -35,13 +35,13 @@ func TestVPNInit(t *testing.T) {
 	}
 
 	// Wrongfully initialize server.
-	err := Init("localhost", "asdf", UDPProto)
+	err := Init("localhost", "asdf", UDPProto, "")
 	if err == nil {
 		t.Fatalf("error is expected to be not nil but it's nil instead")
 	}
 
 	// Initialize the server.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Check database if the database has no server.
 	var server2 DBServer
@@ -61,7 +61,7 @@ func TestVPNDeinit(t *testing.T) {
 
 	// Prepare:
 	// Initialize the server.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 	u, err := CreateNewUser("user", "p", false, 0)
 	if err != nil {
 		t.Fatal(err)
@@ -122,7 +122,7 @@ func TestVPNIsInitialized(t *testing.T) {
 	}
 
 	// Initialize the server.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Isn't initialized?
 	if !IsInitialized() {
@@ -152,7 +152,7 @@ func TestVPNGetServerInstance(t *testing.T) {
 	}
 
 	// Initialize server.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	server, err = GetServerInstance()
 
@@ -172,7 +172,7 @@ func TestVPNDumpsClientConfig(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	user, _ := CreateNewUser("user", "password", false, 0)
@@ -194,7 +194,7 @@ func TestVPNDumpClientConfig(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	noGW := false
@@ -262,7 +262,7 @@ func TestVPNGetSystemCA(t *testing.T) {
 	}
 
 	// Initialize system.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	ca, err = GetSystemCA()
 	if err != nil {
@@ -302,7 +302,7 @@ func TestVPNStartVPNProc(t *testing.T) {
 	}
 
 	// Initialize OVPM server.
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Call start again..
 	StartVPNProc()
@@ -318,7 +318,7 @@ func TestVPNStopVPNProc(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 	vpnProc.Start()
@@ -342,7 +342,7 @@ func TestVPNRestartVPNProc(t *testing.T) {
 	// Init:
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare:
 
@@ -371,7 +371,7 @@ func TestVPNEmit(t *testing.T) {
 	setupTestCase()
 	SetupDB("sqlite3", ":memory:")
 	defer CeaseDB()
-	Init("localhost", "", UDPProto)
+	Init("localhost", "", UDPProto, "")
 
 	// Prepare: