Browse Source

feat(vpn): ensure nat is enabled when starting vpn

Closes #13
Mustafa Arici 8 years ago
parent
commit
b6e48777d0
4 changed files with 165 additions and 5 deletions
  1. 5 5
      bindata/bindata.go
  2. 16 0
      cmd/ovpmd/main.go
  3. 139 0
      net.go
  4. 5 0
      supervisor/supervisor.go

+ 5 - 5
bindata/bindata.go

@@ -87,7 +87,7 @@ func templateCcdFileTmpl() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/ccd.file.tmpl", size: 74, mode: os.FileMode(436), modTime: time.Unix(1500555693, 0)}
+	info := bindataFileInfo{name: "template/ccd.file.tmpl", size: 74, mode: os.FileMode(420), modTime: time.Unix(1501822328, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -107,7 +107,7 @@ func templateClientOvpnTmpl() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/client.ovpn.tmpl", size: 306, mode: os.FileMode(436), modTime: time.Unix(1502657766, 0)}
+	info := bindataFileInfo{name: "template/client.ovpn.tmpl", size: 306, mode: os.FileMode(420), modTime: time.Unix(1502656204, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -127,7 +127,7 @@ func templateDh4096PemTmpl() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/dh4096.pem.tmpl", size: 1468, mode: os.FileMode(436), modTime: time.Unix(1500555693, 0)}
+	info := bindataFileInfo{name: "template/dh4096.pem.tmpl", size: 1468, mode: os.FileMode(420), modTime: time.Unix(1501822328, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -147,7 +147,7 @@ func templateIptablesTmpl() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/iptables.tmpl", size: 0, mode: os.FileMode(436), modTime: time.Unix(1500555693, 0)}
+	info := bindataFileInfo{name: "template/iptables.tmpl", size: 0, mode: os.FileMode(420), modTime: time.Unix(1501822328, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }
@@ -167,7 +167,7 @@ func templateServerConfTmpl() (*asset, error) {
 		return nil, err
 	}
 
-	info := bindataFileInfo{name: "template/server.conf.tmpl", size: 9585, mode: os.FileMode(436), modTime: time.Unix(1502270416, 0)}
+	info := bindataFileInfo{name: "template/server.conf.tmpl", size: 9585, mode: os.FileMode(420), modTime: time.Unix(1502173553, 0)}
 	a := &asset{bytes: bytes, info: info}
 	return a, nil
 }

+ 16 - 0
cmd/ovpmd/main.go

@@ -98,6 +98,22 @@ func (s *server) start() {
 	logrus.Infof("OVPM is running :%s ...", s.port)
 	go s.grpcServer.Serve(s.lis)
 	ovpm.StartVPNProc()
+
+	// Nat enablerer
+	go func() {
+		for {
+			err := ovpm.EnsureNatEnabled()
+			if err == nil {
+				logrus.Debug("nat is enabled")
+				return
+			}
+			logrus.Debugf("can not enable nat: %v", err)
+			// TODO(cad): employ a exponential back-off approach here
+			// instead of sleeping for the constant duration.
+			time.Sleep(1 * time.Second)
+		}
+
+	}()
 }
 
 func (s *server) stop() {

+ 139 - 0
net.go

@@ -0,0 +1,139 @@
+package ovpm
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/coreos/go-iptables/iptables"
+)
+
+// routedInterface returns a network interface that can route IP
+// traffic and satisfies flags. It returns nil when an appropriate
+// network interface is not found. Network must be "ip", "ip4" or
+// "ip6".
+func routedInterface(network string, flags net.Flags) *net.Interface {
+	switch network {
+	case "ip", "ip4", "ip6":
+	default:
+		return nil
+	}
+	ift, err := net.Interfaces()
+	if err != nil {
+		return nil
+	}
+	for _, ifi := range ift {
+		if ifi.Flags&flags != flags {
+			continue
+		}
+		if _, ok := hasRoutableIP(network, &ifi); !ok {
+			continue
+		}
+		return &ifi
+	}
+	return nil
+}
+
+func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
+	ifat, err := ifi.Addrs()
+	if err != nil {
+		return nil, false
+	}
+	for _, ifa := range ifat {
+		switch ifa := ifa.(type) {
+		case *net.IPAddr:
+			if ip := routableIP(network, ifa.IP); ip != nil {
+				return ip, true
+			}
+		case *net.IPNet:
+			if ip := routableIP(network, ifa.IP); ip != nil {
+				return ip, true
+			}
+		}
+	}
+	return nil, false
+}
+
+func vpnInterface() *net.Interface {
+	mask := net.IPMask(net.ParseIP(_DefaultServerNetMask))
+	prefix := net.ParseIP(_DefaultServerNetwork)
+	netw := prefix.Mask(mask).To4()
+	netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
+	ipnet := net.IPNet{IP: netw, Mask: mask}
+
+	ifs, err := net.Interfaces()
+	if err != nil {
+		logrus.Errorf("can not get system network interfaces: %v", err)
+		return nil
+	}
+
+	for _, ifc := range ifs {
+		addrs, err := ifc.Addrs()
+		if err != nil {
+			logrus.Errorf("can not get interface addresses: %v", err)
+			return nil
+		}
+		for _, addr := range addrs {
+			//logrus.Debugf("addr: %s == %s", addr.String(), ipnet.String())
+			if addr.String() == ipnet.String() {
+				return &ifc
+			}
+		}
+	}
+	return nil
+}
+
+func routableIP(network string, ip net.IP) net.IP {
+	if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
+		return nil
+	}
+	switch network {
+	case "ip4":
+		if ip := ip.To4(); ip != nil {
+			return ip
+		}
+	case "ip6":
+		if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
+			return nil
+		}
+		if ip := ip.To16(); ip != nil && ip.To4() == nil {
+			return ip
+		}
+	default:
+		if ip := ip.To4(); ip != nil {
+			return ip
+		}
+		if ip := ip.To16(); ip != nil {
+			return ip
+		}
+	}
+	return nil
+}
+
+// EnsureNatEnabled is an idempotent command that ensures nat is enabled for the vpn server.
+func EnsureNatEnabled() error {
+	rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
+	if rif == nil {
+		return fmt.Errorf("can not get routable network interface")
+	}
+
+	vpnIfc := vpnInterface()
+	if vpnIfc == nil {
+		return fmt.Errorf("can not get vpn network interface on the system")
+	}
+
+	// Enable ip forwarding.
+	emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
+	ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+	if err != nil {
+		return fmt.Errorf("can not create new iptables object: %v", err)
+	}
+
+	// Append iptables nat rules.
+	ipt.AppendUnique("nat", "POSTROUTING", "-o", rif.Name, "-j", "MASQUERADE")
+	// TODO(cad): we should use the interface name that we get when we query the system
+	// with the vpn server's internal ip address, instead of default "tun0".
+	ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED, ESTABLISHED", "-j", "ACCEPT")
+	ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT")
+	return nil
+}

+ 5 - 0
supervisor/supervisor.go

@@ -131,6 +131,11 @@ func (p *Process) waitFor(state State) {
 	}
 }
 
+// WaitFor blocks until the FSM transitions to the given state.
+func WaitFor(process *Process, state State) {
+	process.waitFor(state)
+}
+
 // Start will run the process.
 func (p *Process) Start() {
 	p.transitionTo(STARTING)