1
0

net.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package ovpm
  2. import (
  3. "fmt"
  4. "net"
  5. "github.com/Sirupsen/logrus"
  6. "github.com/coreos/go-iptables/iptables"
  7. "time"
  8. )
  9. // routedInterface returns a network interface that can route IP
  10. // traffic and satisfies flags. It returns nil when an appropriate
  11. // network interface is not found. Network must be "ip", "ip4" or
  12. // "ip6".
  13. func routedInterface(network string, flags net.Flags) *net.Interface {
  14. switch network {
  15. case "ip", "ip4", "ip6":
  16. default:
  17. return nil
  18. }
  19. ift, err := net.Interfaces()
  20. if err != nil {
  21. return nil
  22. }
  23. for _, ifi := range ift {
  24. if ifi.Flags&flags != flags {
  25. continue
  26. }
  27. if _, ok := hasRoutableIP(network, &ifi); !ok {
  28. continue
  29. }
  30. return &ifi
  31. }
  32. return nil
  33. }
  34. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  35. ifat, err := ifi.Addrs()
  36. if err != nil {
  37. return nil, false
  38. }
  39. for _, ifa := range ifat {
  40. switch ifa := ifa.(type) {
  41. case *net.IPAddr:
  42. if ip := routableIP(network, ifa.IP); ip != nil {
  43. return ip, true
  44. }
  45. case *net.IPNet:
  46. if ip := routableIP(network, ifa.IP); ip != nil {
  47. return ip, true
  48. }
  49. }
  50. }
  51. return nil, false
  52. }
  53. func vpnInterface() *net.Interface {
  54. mask := net.IPMask(net.ParseIP(_DefaultServerNetMask))
  55. prefix := net.ParseIP(_DefaultServerNetwork)
  56. netw := prefix.Mask(mask).To4()
  57. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  58. ipnet := net.IPNet{IP: netw, Mask: mask}
  59. ifs, err := net.Interfaces()
  60. if err != nil {
  61. logrus.Errorf("can not get system network interfaces: %v", err)
  62. return nil
  63. }
  64. for _, ifc := range ifs {
  65. addrs, err := ifc.Addrs()
  66. if err != nil {
  67. logrus.Errorf("can not get interface addresses: %v", err)
  68. return nil
  69. }
  70. for _, addr := range addrs {
  71. //logrus.Debugf("addr: %s == %s", addr.String(), ipnet.String())
  72. if addr.String() == ipnet.String() {
  73. return &ifc
  74. }
  75. }
  76. }
  77. return nil
  78. }
  79. func routableIP(network string, ip net.IP) net.IP {
  80. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  81. return nil
  82. }
  83. switch network {
  84. case "ip4":
  85. if ip := ip.To4(); ip != nil {
  86. return ip
  87. }
  88. case "ip6":
  89. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  90. return nil
  91. }
  92. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  93. return ip
  94. }
  95. default:
  96. if ip := ip.To4(); ip != nil {
  97. return ip
  98. }
  99. if ip := ip.To16(); ip != nil {
  100. return ip
  101. }
  102. }
  103. return nil
  104. }
  105. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  106. func ensureNatEnabled() {
  107. // Nat enablerer
  108. go func() {
  109. for {
  110. err := enableNat()
  111. if err == nil {
  112. logrus.Debug("nat is enabled")
  113. return
  114. }
  115. logrus.Debugf("can not enable nat: %v", err)
  116. // TODO(cad): employ a exponential back-off approach here
  117. // instead of sleeping for the constant duration.
  118. time.Sleep(1 * time.Second)
  119. }
  120. }()
  121. }
  122. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  123. func enableNat() error {
  124. rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  125. if rif == nil {
  126. return fmt.Errorf("can not get routable network interface")
  127. }
  128. vpnIfc := vpnInterface()
  129. if vpnIfc == nil {
  130. return fmt.Errorf("can not get vpn network interface on the system")
  131. }
  132. // Enable ip forwarding.
  133. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  134. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  135. if err != nil {
  136. return fmt.Errorf("can not create new iptables object: %v", err)
  137. }
  138. // Append iptables nat rules.
  139. ipt.AppendUnique("nat", "POSTROUTING", "-o", rif.Name, "-j", "MASQUERADE")
  140. // TODO(cad): we should use the interface name that we get when we query the system
  141. // with the vpn server's internal ip address, instead of default "tun0".
  142. ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED, ESTABLISHED", "-j", "ACCEPT")
  143. ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT")
  144. return nil
  145. }