net.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. "time"
  7. "github.com/Sirupsen/logrus"
  8. "github.com/asaskevich/govalidator"
  9. "github.com/coreos/go-iptables/iptables"
  10. "github.com/jinzhu/gorm"
  11. )
  12. // DBNetwork is database model for external networks on the VPN server.
  13. type DBNetwork struct {
  14. gorm.Model
  15. ServerID uint
  16. Server DBServer
  17. Name string `gorm:"unique_index"`
  18. CIDR string
  19. }
  20. // GetNetwork returns a network specified by its name.
  21. func GetNetwork(name string) (*DBNetwork, error) {
  22. if !IsInitialized() {
  23. return nil, fmt.Errorf("you first need to create server")
  24. }
  25. // Validate user input.
  26. if govalidator.IsNull(name) {
  27. return nil, fmt.Errorf("validation error: %s can not be null", name)
  28. }
  29. if !govalidator.IsAlphanumeric(name) {
  30. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  31. }
  32. var network DBNetwork
  33. db.Where(&DBNetwork{Name: name}).First(&network)
  34. if db.NewRecord(&network) {
  35. return nil, fmt.Errorf("network not found %s", name)
  36. }
  37. return &network, nil
  38. }
  39. // GetAllNetworks returns all networks defined in the system.
  40. func GetAllNetworks() ([]*DBNetwork, error) {
  41. var networks []*DBNetwork
  42. db.Find(&networks)
  43. return networks, nil
  44. }
  45. // CreateNewNetwork creates a new network definition in the system.
  46. func CreateNewNetwork(name, cidr string) (*DBNetwork, error) {
  47. if !IsInitialized() {
  48. return nil, fmt.Errorf("you first need to create server")
  49. }
  50. // Validate user input.
  51. if govalidator.IsNull(name) {
  52. return nil, fmt.Errorf("validation error: %s can not be null", name)
  53. }
  54. if !govalidator.IsAlphanumeric(name) {
  55. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  56. }
  57. if !govalidator.IsCIDR(cidr) {
  58. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", name)
  59. }
  60. _, ipnet, err := net.ParseCIDR(cidr)
  61. if err != nil {
  62. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  63. }
  64. network := DBNetwork{
  65. Name: name,
  66. CIDR: ipnet.String(),
  67. }
  68. db.Save(&network)
  69. if db.NewRecord(&network) {
  70. return nil, fmt.Errorf("can not create network in the db")
  71. }
  72. return &network, nil
  73. }
  74. // Delete deletes a network definition in the system.
  75. func (n *DBNetwork) Delete() error {
  76. if !IsInitialized() {
  77. return fmt.Errorf("you first need to create server")
  78. }
  79. db.Unscoped().Delete(n)
  80. logrus.Infof("network deleted: %s", n.Name)
  81. return nil
  82. }
  83. // GetName returns network's name.
  84. func (n *DBNetwork) GetName() string {
  85. return n.Name
  86. }
  87. // GetCIDR returns network's CIDR.
  88. func (n *DBNetwork) GetCIDR() string {
  89. return n.CIDR
  90. }
  91. // GetCreatedAt returns network's name.
  92. func (n *DBNetwork) GetCreatedAt() string {
  93. return n.CreatedAt.Format(time.UnixDate)
  94. }
  95. // routedInterface returns a network interface that can route IP
  96. // traffic and satisfies flags. It returns nil when an appropriate
  97. // network interface is not found. Network must be "ip", "ip4" or
  98. // "ip6".
  99. func routedInterface(network string, flags net.Flags) *net.Interface {
  100. switch network {
  101. case "ip", "ip4", "ip6":
  102. default:
  103. return nil
  104. }
  105. ift, err := net.Interfaces()
  106. if err != nil {
  107. return nil
  108. }
  109. for _, ifi := range ift {
  110. if ifi.Flags&flags != flags {
  111. continue
  112. }
  113. if _, ok := hasRoutableIP(network, &ifi); !ok {
  114. continue
  115. }
  116. return &ifi
  117. }
  118. return nil
  119. }
  120. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  121. ifat, err := ifi.Addrs()
  122. if err != nil {
  123. return nil, false
  124. }
  125. for _, ifa := range ifat {
  126. switch ifa := ifa.(type) {
  127. case *net.IPAddr:
  128. if ip := routableIP(network, ifa.IP); ip != nil {
  129. return ip, true
  130. }
  131. case *net.IPNet:
  132. if ip := routableIP(network, ifa.IP); ip != nil {
  133. return ip, true
  134. }
  135. }
  136. }
  137. return nil, false
  138. }
  139. func vpnInterface() *net.Interface {
  140. mask := net.IPMask(net.ParseIP(_DefaultServerNetMask))
  141. prefix := net.ParseIP(_DefaultServerNetwork)
  142. netw := prefix.Mask(mask).To4()
  143. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  144. ipnet := net.IPNet{IP: netw, Mask: mask}
  145. ifs, err := net.Interfaces()
  146. if err != nil {
  147. logrus.Errorf("can not get system network interfaces: %v", err)
  148. return nil
  149. }
  150. for _, ifc := range ifs {
  151. addrs, err := ifc.Addrs()
  152. if err != nil {
  153. logrus.Errorf("can not get interface addresses: %v", err)
  154. return nil
  155. }
  156. for _, addr := range addrs {
  157. //logrus.Debugf("addr: %s == %s", addr.String(), ipnet.String())
  158. if addr.String() == ipnet.String() {
  159. return &ifc
  160. }
  161. }
  162. }
  163. return nil
  164. }
  // routableIP returns ip in the form appropriate for network, or nil when
// ip is not usable on that network.
//
// Only loopback, link-local unicast and global unicast addresses qualify.
//   - "ip4": the 4-byte form of an IPv4 address.
//   - "ip6": the 16-byte form of a non-IPv4, non-loopback address.
//   - anything else: the 4-byte form if ip is IPv4, otherwise the 16-byte form.
func routableIP(network string, ip net.IP) net.IP {
	usable := ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsGlobalUnicast()
	if !usable {
		return nil
	}

	v4 := ip.To4()
	if network == "ip4" {
		return v4
	}
	if network == "ip6" {
		// The addressing scope of the loopback address depends on each
		// implementation, so it is never reported for ip6.
		if ip.IsLoopback() || v4 != nil {
			return nil
		}
		return ip.To16()
	}
	if v4 != nil {
		return v4
	}
	return ip.To16()
}
  191. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  192. func ensureNatEnabled() {
  193. // Nat enablerer
  194. go func() {
  195. for {
  196. err := enableNat()
  197. if err == nil {
  198. logrus.Debug("nat is enabled")
  199. return
  200. }
  201. logrus.Debugf("can not enable nat: %v", err)
  202. // TODO(cad): employ a exponential back-off approach here
  203. // instead of sleeping for the constant duration.
  204. time.Sleep(1 * time.Second)
  205. }
  206. }()
  207. }
  208. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  209. func enableNat() error {
  210. rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  211. if rif == nil {
  212. return fmt.Errorf("can not get routable network interface")
  213. }
  214. vpnIfc := vpnInterface()
  215. if vpnIfc == nil {
  216. return fmt.Errorf("can not get vpn network interface on the system")
  217. }
  218. // Enable ip forwarding.
  219. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  220. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  221. if err != nil {
  222. return fmt.Errorf("can not create new iptables object: %v", err)
  223. }
  224. // Append iptables nat rules.
  225. ipt.AppendUnique("nat", "POSTROUTING", "-o", rif.Name, "-j", "MASQUERADE")
  226. // TODO(cad): we should use the interface name that we get when we query the system
  227. // with the vpn server's internal ip address, instead of default "tun0".
  228. ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED, ESTABLISHED", "-j", "ACCEPT")
  229. ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT")
  230. return nil
  231. }
  // HostID2IP converts a host id (32-bit unsigned integer) into the 4-byte
// IPv4 address it encodes, most significant byte first
// (e.g. 0x0a090001 becomes 10.9.0.1).
func HostID2IP(hostid uint32) net.IP {
	var octets [net.IPv4len]byte
	binary.BigEndian.PutUint32(octets[:], hostid)
	return octets[:]
}
  // IP2HostID converts an IPv4 address to its host id (32-bit unsigned
// integer); it is the inverse of HostID2IP.
//
// Both the 4-byte form and the 16-byte IPv4-mapped form (what net.ParseIP
// returns) are accepted. Addresses that are not IPv4, including nil or
// truncated slices, yield 0 instead of panicking or reading the wrong bytes.
func IP2HostID(ip net.IP) uint32 {
	v4 := ip.To4()
	if v4 == nil {
		return 0
	}
	return binary.BigEndian.Uint32(v4)
}