浏览代码

refactor: split pki into it's own package

Mustafa Arici 8 年之前
父节点
当前提交
c2891189f9
共有 9 个文件被更改,包括 127 次插入108 次删除
  1. 1 1
      api/rpc.go
  2. 1 1
      cmd/ovpm/main.go
  3. 6 1
      cmd/ovpmd/main.go
  4. 0 26
      conf.go
  5. 25 0
      const.go
  6. 12 8
      db.go
  7. 31 23
      pki/pki.go
  8. 8 7
      user.go
  9. 43 41
      vpn.go

+ 1 - 1
api/rpc.go

@@ -128,7 +128,7 @@ func (s *VPNService) Status(ctx context.Context, req *pb.VPNStatusRequest) (*pb.
 }
 
 func (s *VPNService) Init(ctx context.Context, req *pb.VPNInitRequest) (*pb.VPNInitResponse, error) {
-	if err := ovpm.Initialize("default", req.Hostname, req.Port); err != nil {
+	if err := ovpm.Init(req.Hostname, req.Port); err != nil {
 		logrus.Errorf("server can not be created: %v", err)
 	}
 	return &pb.VPNInitResponse{}, nil

+ 1 - 1
cmd/ovpm/main.go

@@ -1,4 +1,5 @@
 //go:generate go-bindata template/
+
 package main
 
 import (
@@ -359,7 +360,6 @@ func main() {
 		},
 	}
 	app.Run(os.Args)
-	ovpm.CloseDB()
 }
 
 func stringInSlice(a string, list []string) bool {

+ 6 - 1
cmd/ovpmd/main.go

@@ -1,4 +1,5 @@
 //go:generate go-bindata template/
+
 package main
 
 import (
@@ -37,6 +38,11 @@ func main() {
 		if c.GlobalBool("verbose") {
 			logrus.SetLevel(logrus.DebugLevel)
 		}
+		ovpm.SetupDB()
+		return nil
+	}
+	app.After = func(c *cli.Context) error {
+		ovpm.CeaseDB()
 		return nil
 	}
 	app.Action = func(c *cli.Context) error {
@@ -56,7 +62,6 @@ func main() {
 		return nil
 	}
 	app.Run(os.Args)
-	ovpm.CloseDB()
 }
 
 func stringInSlice(a string, list []string) bool {

+ 0 - 26
conf.go

@@ -1,26 +0,0 @@
-package ovpm
-
-const (
-	// Version defines the version of ovpm.
-	Version = "0.0.0"
-
-	etcBasePath         = "/etc/ovpm/"
-	varBasePath         = "/var/db/ovpm/"
-	DefaultConfigPath   = etcBasePath + "ovpm.ini"
-	DefaultDBPath       = varBasePath + "db.sqlite3"
-	DefaultVPNConfPath  = varBasePath + "server.conf"
-	DefaultVPNPort      = "1197"
-	DefaultVPNCCDPath   = varBasePath + "ccd/"
-	DefaultCertPath     = varBasePath + "server.crt"
-	DefaultKeyPath      = varBasePath + "server.key"
-	DefaultCACertPath   = varBasePath + "ca.crt"
-	DefaultCAKeyPath    = varBasePath + "ca.key"
-	DefaultDHParamsPath = varBasePath + "dh4096.pem"
-	DefaultCRLPath      = varBasePath + "crl.pem"
-
-	CrtExpireYears = 10
-	CrtKeyLength   = 2024
-
-	DefaultServerNetwork = "10.9.0.0"
-	DefaultServerNetMask = "255.255.255.0"
-)

+ 25 - 0
const.go

@@ -0,0 +1,25 @@
+package ovpm
+
+const (
+	// Version defines the version of ovpm.
+	Version = "0.0.0"
+
+	// DefaultVPNPort is the default OpenVPN port to listen.
+	DefaultVPNPort = "1197"
+	etcBasePath    = "/etc/ovpm/"
+	varBasePath    = "/var/db/ovpm/"
+
+	_DefaultConfigPath   = etcBasePath + "ovpm.ini"
+	_DefaultDBPath       = varBasePath + "db.sqlite3"
+	_DefaultVPNConfPath  = varBasePath + "server.conf"
+	_DefaultVPNCCDPath   = varBasePath + "ccd/"
+	_DefaultCertPath     = varBasePath + "server.crt"
+	_DefaultKeyPath      = varBasePath + "server.key"
+	_DefaultCACertPath   = varBasePath + "ca.crt"
+	_DefaultCAKeyPath    = varBasePath + "ca.key"
+	_DefaultDHParamsPath = varBasePath + "dh4096.pem"
+	_DefaultCRLPath      = varBasePath + "crl.pem"
+
+	_DefaultServerNetwork = "10.9.0.0"
+	_DefaultServerNetMask = "255.255.255.0"
+)

+ 12 - 8
db.go

@@ -10,21 +10,25 @@ import (
 
 var db *gorm.DB
 
-// CloseDB closes the database.
-func CloseDB() {
-	db.Close()
-}
-
-func init() {
+// SetupDB prepares database for use.
+//
+// It should be run at the start of the program.
+func SetupDB() {
 	var err error
-	db, err = gorm.Open("sqlite3", DefaultDBPath)
+	db, err = gorm.Open("sqlite3", _DefaultDBPath)
 	if err != nil {
-		logrus.Fatalf("couldn't open sqlite database %s: %v", DefaultDBPath, err)
+		logrus.Fatalf("couldn't open sqlite database %s: %v", _DefaultDBPath, err)
 	}
 
 	db.AutoMigrate(&DBUser{})
 	db.AutoMigrate(&DBNetwork{})
 	db.AutoMigrate(&DBServer{})
 	db.AutoMigrate(&DBRevoked{})
+}
 
+// CeaseDB closes the database.
+//
+// It should be run at the exit of the program.
+func CeaseDB() {
+	db.Close()
 }

+ 31 - 23
pki.go → pki/pki.go

@@ -1,4 +1,5 @@
-package ovpm
+// Package pki contains bits and pieces to work with OpenVPN PKI related operations.
+package pki
 
 import (
 	"crypto/rand"
@@ -14,10 +15,15 @@ import (
 	"time"
 )
 
+const (
+	_CrtExpireYears = 10
+	_CrtKeyLength   = 2024
+)
+
 // CertHolder encapsulates a public certificate and the corresponding private key.
 type CertHolder struct {
-	Cert string
-	Key  string // Private Key
+	Cert string // PEM Encoded Certificate
+	Key  string // PEM Encoded Private Key
 }
 
 // CA is a special type of CertHolder that also has a CSR in it.
@@ -30,7 +36,12 @@ type CA struct {
 //
 // This will generate a public/private RSA keypair and a authority certificate signed by itself.
 func NewCA() (*CA, error) {
-	key, err := rsa.GenerateKey(rand.Reader, CrtKeyLength)
+	type basicConstraints struct {
+		IsCA       bool `asn1:"optional"`
+		MaxPathLen int  `asn1:"optional,default:-1"`
+	}
+
+	key, err := rsa.GenerateKey(rand.Reader, _CrtKeyLength)
 	if err != nil {
 		return nil, fmt.Errorf("private key cannot be created: %s", err)
 	}
@@ -116,9 +127,9 @@ func NewClientCertHolder(username string, ca *CA) (*CertHolder, error) {
 	return newCert(username, ca, false)
 }
 
-// NewCRL takes in a list of certificate serial numbers and a CA then makes a PEM encoded CRL and returns it as a string.
-func NewCRL(revokedCertificateSerials []*big.Int, ca *CA) (string, error) {
-	caCrt, err := readCertFromPEM(ca.Cert)
+// NewCRL takes in a list of certificate serial numbers to-be-revoked and a CA then makes a PEM encoded CRL and returns it as a string.
+func NewCRL(serials []*big.Int, ca *CA) (string, error) {
+	caCrt, err := ReadCertFromPEM(ca.Cert)
 	if err != nil {
 		return "", err
 	}
@@ -133,7 +144,7 @@ func NewCRL(revokedCertificateSerials []*big.Int, ca *CA) (string, error) {
 		return "", fmt.Errorf("failed to parse ca private key: %s", err)
 	}
 	var revokedCertList []pkix.RevokedCertificate
-	for _, serial := range revokedCertificateSerials {
+	for _, serial := range serials {
 		revokedCert := pkix.RevokedCertificate{
 			SerialNumber:   serial,
 			RevocationTime: time.Now().UTC(),
@@ -154,7 +165,16 @@ func NewCRL(revokedCertificateSerials []*big.Int, ca *CA) (string, error) {
 
 }
 
-func newCert(commonName string, ca *CA, server bool) (*CertHolder, error) {
+// ReadCertFromPEM decodes a PEM encoded string into a x509.Certificate.
+func ReadCertFromPEM(s string) (*x509.Certificate, error) {
+	block, _ := pem.Decode([]byte(s))
+	var cert *x509.Certificate
+	cert, _ = x509.ParseCertificate(block.Bytes)
+	return cert, nil
+}
+
+// newCert generates a private key and a certificate, that is signed by the given CA.
+func newCert(cn string, ca *CA, server bool) (*CertHolder, error) {
 	// Get CA private key
 	block, _ := pem.Decode([]byte(ca.Key))
 	if block == nil {
@@ -166,7 +186,7 @@ func newCert(commonName string, ca *CA, server bool) (*CertHolder, error) {
 		return nil, fmt.Errorf("failed to parse ca private key: %s", err)
 	}
 
-	caCert, err := readCertFromPEM(ca.Cert)
+	caCert, err := ReadCertFromPEM(ca.Cert)
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse ca cert: %v", err)
 	}
@@ -187,7 +207,7 @@ func newCert(commonName string, ca *CA, server bool) (*CertHolder, error) {
 		NotAfter:     time.Now().AddDate(5, 0, 0),
 		SerialNumber: serial,
 		Subject: pkix.Name{
-			CommonName:   commonName,
+			CommonName:   cn,
 			Organization: []string{"Innovation"},
 		},
 		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
@@ -219,15 +239,3 @@ func newCert(commonName string, ca *CA, server bool) (*CertHolder, error) {
 		Cert: string(certPem[:]),
 	}, nil
 }
-
-type basicConstraints struct {
-	IsCA       bool `asn1:"optional"`
-	MaxPathLen int  `asn1:"optional,default:-1"`
-}
-
-func readCertFromPEM(pemCert string) (*x509.Certificate, error) {
-	block, _ := pem.Decode([]byte(pemCert))
-	var cert *x509.Certificate
-	cert, _ = x509.ParseCertificate(block.Bytes)
-	return cert, nil
-}

+ 8 - 7
user.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/Sirupsen/logrus"
 	"github.com/asaskevich/govalidator"
+	"github.com/cad/ovpm/pki"
 	"github.com/jinzhu/gorm"
 )
 
@@ -35,9 +36,9 @@ type DBRevoked struct {
 	SerialNumber string
 }
 
-func (u *DBUser) setPassword(newPassword string) error {
+func (u *DBUser) setPassword(password string) error {
 	// TODO(cad): Use a proper password hashing algorithm here.
-	u.Password = newPassword
+	u.Password = password
 	return nil
 }
 
@@ -80,7 +81,7 @@ func CreateNewUser(username, password string) (*DBUser, error) {
 		return nil, err
 	}
 
-	clientCert, err := NewClientCertHolder(username, ca)
+	clientCert, err := pki.NewClientCertHolder(username, ca)
 	if err != nil {
 		return nil, fmt.Errorf("can not create client cert %s: %v", username, err)
 	}
@@ -118,7 +119,7 @@ func (u *DBUser) Delete() error {
 		// user is not found
 		return fmt.Errorf("user is not initialized: %s", u.Username)
 	}
-	crt, err := readCertFromPEM(u.Cert)
+	crt, err := pki.ReadCertFromPEM(u.Cert)
 	if err != nil {
 		return fmt.Errorf("can not get user's certificate: %v", err)
 	}
@@ -136,8 +137,8 @@ func (u *DBUser) Delete() error {
 }
 
 // ResetPassword resets the users password into the provided password.
-func (u *DBUser) ResetPassword(newPassword string) error {
-	err := u.setPassword(newPassword)
+func (u *DBUser) ResetPassword(password string) error {
+	err := u.setPassword(password)
 	if err != nil {
 		// user password can not be updated
 		return fmt.Errorf("user password can not be updated %s: %v", u.Username, err)
@@ -159,7 +160,7 @@ func (u *DBUser) Sign() error {
 		return err
 	}
 
-	clientCert, err := NewClientCertHolder(u.Username, ca)
+	clientCert, err := pki.NewClientCertHolder(u.Username, ca)
 	if err != nil {
 		return fmt.Errorf("can not create client cert %s: %v", u.Username, err)
 	}

+ 43 - 41
vpn.go

@@ -15,6 +15,7 @@ import (
 	"github.com/Sirupsen/logrus"
 	"github.com/asaskevich/govalidator"
 	"github.com/cad/ovpm/bindata"
+	"github.com/cad/ovpm/pki"
 	"github.com/google/uuid"
 	"github.com/jinzhu/gorm"
 )
@@ -47,8 +48,8 @@ type DBServer struct {
 }
 
 // CheckSerial takes a serial number and checks it against the current server's serial number.
-func (s *DBServer) CheckSerial(serialNo string) bool {
-	return serialNo == s.SerialNumber
+func (s *DBServer) CheckSerial(serial string) bool {
+	return serial == s.SerialNumber
 }
 
 type _VPNServerConfig struct {
@@ -64,10 +65,11 @@ type _VPNServerConfig struct {
 	Port         string
 }
 
-// Initialize regenerates keys and certs for a Root CA, and saves them in the database.
-func Initialize(serverName string, hostname string, port string) error {
+// Init regenerates keys and certs for a Root CA, and saves them in the database.
+func Init(hostname string, port string) error {
+	serverName := "default"
 	if IsInitialized() {
-		if err := DeInitialize("default"); err != nil {
+		if err := Deinit(); err != nil {
 			logrus.Errorf("server can not be deleted: %v", err)
 			return err
 		}
@@ -77,12 +79,12 @@ func Initialize(serverName string, hostname string, port string) error {
 		return fmt.Errorf("validation error: hostname:`%s` should be either an ip address or a FQDN", hostname)
 	}
 
-	ca, err := NewCA()
+	ca, err := pki.NewCA()
 	if err != nil {
 		return fmt.Errorf("can not create ca creds: %s", err)
 	}
 
-	srv, err := NewServerCertHolder(ca)
+	srv, err := pki.NewServerCertHolder(ca)
 	if err != nil {
 		return fmt.Errorf("can not create server cert creds: %s", err)
 	}
@@ -97,8 +99,8 @@ func Initialize(serverName string, hostname string, port string) error {
 		Key:          srv.Key,
 		CACert:       ca.Cert,
 		CAKey:        ca.Key,
-		Net:          DefaultServerNetwork,
-		Mask:         DefaultServerNetMask,
+		Net:          _DefaultServerNetwork,
+		Mask:         _DefaultServerNetMask,
 	}
 
 	db.Create(&serverInstance)
@@ -123,8 +125,8 @@ func Initialize(serverName string, hostname string, port string) error {
 	return nil
 }
 
-// DeInitialize deletes the server with the given serverName from the database and frees the allocated resources.
-func DeInitialize(serverName string) error {
+// Deinit deletes the server with the given serverName from the database and frees the allocated resources.
+func Deinit() error {
 	if !IsInitialized() {
 		return fmt.Errorf("server not found")
 	}
@@ -179,25 +181,25 @@ func DumpsClientConfig(username string) (string, error) {
 }
 
 // DumpClientConfig generates .ovpn file for the given vpn user and dumps it to outPath.
-func DumpClientConfig(username, outPath string) error {
+func DumpClientConfig(username, path string) error {
 	result, err := DumpsClientConfig(username)
 	if err != nil {
 		return err
 	}
 	// Wite rendered content into openvpn server conf.
-	return emitToFile(outPath, result, 0)
+	return emitToFile(path, result, 0)
 
 }
 
 // GetSystemCA returns the system CA from the database if available.
-func GetSystemCA() (*CA, error) {
+func GetSystemCA() (*pki.CA, error) {
 	server := DBServer{}
 	db.First(&server)
 	if db.NewRecord(&server) {
 		return nil, fmt.Errorf("server record does not exists in db")
 	}
-	return &CA{
-		CertHolder: CertHolder{
+	return &pki.CA{
+		CertHolder: pki.CertHolder{
 			Cert: server.CACert,
 			Key:  server.CAKey,
 		},
@@ -266,10 +268,10 @@ func Emit() error {
 	return nil
 }
 
-func emitToFile(filePath, content string, mode uint) error {
-	file, err := os.Create(filePath)
+func emitToFile(path, content string, mode uint) error {
+	file, err := os.Create(path)
 	if err != nil {
-		return fmt.Errorf("Cannot create file %s: %v", filePath, err)
+		return fmt.Errorf("Cannot create file %s: %v", path, err)
 
 	}
 	if mode != 0 {
@@ -293,15 +295,15 @@ func emitServerConf() error {
 	var result bytes.Buffer
 
 	server := _VPNServerConfig{
-		CertPath:     DefaultCertPath,
-		KeyPath:      DefaultKeyPath,
-		CACertPath:   DefaultCACertPath,
-		CAKeyPath:    DefaultCAKeyPath,
-		CCDPath:      DefaultVPNCCDPath,
-		CRLPath:      DefaultCRLPath,
-		DHParamsPath: DefaultDHParamsPath,
-		Net:          DefaultServerNetwork,
-		Mask:         DefaultServerNetMask,
+		CertPath:     _DefaultCertPath,
+		KeyPath:      _DefaultKeyPath,
+		CACertPath:   _DefaultCACertPath,
+		CAKeyPath:    _DefaultCAKeyPath,
+		CCDPath:      _DefaultVPNCCDPath,
+		CRLPath:      _DefaultCRLPath,
+		DHParamsPath: _DefaultDHParamsPath,
+		Net:          _DefaultServerNetwork,
+		Mask:         _DefaultServerNetMask,
 		Port:         port,
 	}
 	data, err := bindata.Asset("template/server.conf.tmpl")
@@ -320,7 +322,7 @@ func emitServerConf() error {
 	}
 
 	// Wite rendered content into openvpn server conf.
-	return emitToFile(DefaultVPNConfPath, result.String(), 0)
+	return emitToFile(_DefaultVPNConfPath, result.String(), 0)
 }
 
 // GetServerInstance returns the default server from the database.
@@ -350,7 +352,7 @@ func emitServerKey() error {
 	}
 
 	// Write rendered content into key file.
-	return emitToFile(DefaultKeyPath, server.Key, 0600)
+	return emitToFile(_DefaultKeyPath, server.Key, 0600)
 }
 
 func emitServerCert() error {
@@ -360,7 +362,7 @@ func emitServerCert() error {
 	}
 
 	// Write rendered content into the cert file.
-	return emitToFile(DefaultCertPath, server.Cert, 0)
+	return emitToFile(_DefaultCertPath, server.Cert, 0)
 }
 
 func emitCRL() error {
@@ -376,12 +378,12 @@ func emitCRL() error {
 	if err != nil {
 		return fmt.Errorf("can not emit CRL: %v", err)
 	}
-	crl, err := NewCRL(revokedCertSerials, systemCA)
+	crl, err := pki.NewCRL(revokedCertSerials, systemCA)
 	if err != nil {
 		return fmt.Errorf("can not emit crl: %v", err)
 	}
 
-	return emitToFile(DefaultCRLPath, crl, 0)
+	return emitToFile(_DefaultCRLPath, crl, 0)
 }
 
 func emitCACert() error {
@@ -391,7 +393,7 @@ func emitCACert() error {
 	}
 
 	// Write rendered content into the ca cert file.
-	return emitToFile(DefaultCACertPath, server.CACert, 0)
+	return emitToFile(_DefaultCACertPath, server.CACert, 0)
 }
 
 func emitCAKey() error {
@@ -401,7 +403,7 @@ func emitCAKey() error {
 	}
 
 	// Write rendered content into the ca key file.
-	return emitToFile(DefaultCAKeyPath, server.CAKey, 0600)
+	return emitToFile(_DefaultCAKeyPath, server.CAKey, 0600)
 }
 
 func emitCCD() error {
@@ -411,9 +413,9 @@ func emitCCD() error {
 	}
 
 	// Create and write rendered ccd data.
-	os.Mkdir(DefaultVPNCCDPath, 0755)
-	clientsNetMask := net.IPMask(net.ParseIP(DefaultServerNetMask))
-	clientsNetPrefix := net.ParseIP(DefaultServerNetwork)
+	os.Mkdir(_DefaultVPNCCDPath, 0755)
+	clientsNetMask := net.IPMask(net.ParseIP(_DefaultServerNetMask))
+	clientsNetPrefix := net.ParseIP(_DefaultServerNetwork)
 	clientNet := clientsNetPrefix.Mask(clientsNetMask).To4()
 
 	counter := 2
@@ -423,7 +425,7 @@ func emitCCD() error {
 		params := struct {
 			IP      string
 			NetMask string
-		}{IP: clientNet.String(), NetMask: DefaultServerNetMask}
+		}{IP: clientNet.String(), NetMask: _DefaultServerNetMask}
 
 		data, err := bindata.Asset("template/ccd.file.tmpl")
 		if err != nil {
@@ -439,7 +441,7 @@ func emitCCD() error {
 			return fmt.Errorf("can not render ccd file %s: %s", user.Username, err)
 		}
 
-		err = emitToFile(DefaultVPNCCDPath+user.Username, result.String(), 0)
+		err = emitToFile(_DefaultVPNCCDPath+user.Username, result.String(), 0)
 		if err != nil {
 			return err
 		}
@@ -465,7 +467,7 @@ func emitDHParams() error {
 		return fmt.Errorf("can not render dh4096.pem file: %s", err)
 	}
 
-	err = emitToFile(DefaultDHParamsPath, result.String(), 0)
+	err = emitToFile(_DefaultDHParamsPath, result.String(), 0)
 	if err != nil {
 		return err
 	}