net.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "time"
  9. "github.com/Sirupsen/logrus"
  10. "github.com/asaskevich/govalidator"
  11. "github.com/coreos/go-iptables/iptables"
  12. "github.com/jinzhu/gorm"
  13. )
  14. // NetworkType distinguishes different types of networks that is defined in the networks table.
  15. type NetworkType uint
  16. // NetworkTypes
  17. const (
  18. UNDEFINEDNET NetworkType = iota
  19. SERVERNET
  20. ROUTE
  21. )
  22. var networkTypes = [...]struct {
  23. Type NetworkType
  24. String string
  25. Description string
  26. }{
  27. {UNDEFINEDNET, "UNDEFINEDNET", "unknown network type"},
  28. {SERVERNET, "SERVERNET", "network behind vpn server"},
  29. {ROUTE, "ROUTE", "network to be pushed as route"},
  30. }
  31. // NetworkTypeFromString returns string representation of the network type.
  32. func NetworkTypeFromString(typ string) NetworkType {
  33. for _, v := range networkTypes {
  34. if v.String == typ {
  35. return v.Type
  36. }
  37. }
  38. return UNDEFINEDNET
  39. }
  40. // GetAllNetworkTypes returns all network types defined in the system.
  41. func GetAllNetworkTypes() []NetworkType {
  42. var networkTypeList []NetworkType
  43. for _, v := range networkTypes {
  44. networkTypeList = append(networkTypeList, v.Type)
  45. }
  46. return networkTypeList
  47. }
  48. func (nt NetworkType) String() string {
  49. for _, v := range networkTypes {
  50. if v.Type == nt {
  51. return v.String
  52. }
  53. }
  54. return "UNDEFINEDNET"
  55. }
  56. func (nt NetworkType) Description() string {
  57. for _, v := range networkTypes {
  58. if v.Type == nt {
  59. return v.Description
  60. }
  61. }
  62. return "UNDEFINEDNET"
  63. }
// DBNetwork is database model for external networks on the VPN server.
type DBNetwork struct {
	gorm.Model
	// ServerID/Server presumably link the network to its owning server via
	// gorm's belongs-to naming convention — TODO confirm.
	ServerID uint
	Server   DBServer
	// Name identifies the network; the gorm tag puts a unique index on it.
	// GetNetwork and CreateNewNetwork require it to be alphanumeric.
	Name string `gorm:"unique_index"`
	// CIDR is the network in CIDR notation, normalized to its network
	// address by CreateNewNetwork.
	CIDR string
	// Type is the kind of network (see NetworkType).
	Type NetworkType
	// Via is an optional value in CIDR form; CreateNewNetwork only keeps it
	// for ROUTE networks and clears it otherwise.
	Via string
	// Users are the users associated with this network, stored in the
	// network_users join table.
	Users []*DBUser `gorm:"many2many:network_users;"`
}
  75. // GetNetwork returns a network specified by its name.
  76. func GetNetwork(name string) (*DBNetwork, error) {
  77. if !IsInitialized() {
  78. return nil, fmt.Errorf("you first need to create server")
  79. }
  80. // Validate user input.
  81. if govalidator.IsNull(name) {
  82. return nil, fmt.Errorf("validation error: %s can not be null", name)
  83. }
  84. if !govalidator.IsAlphanumeric(name) {
  85. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  86. }
  87. var network DBNetwork
  88. db.Preload("Users").Where(&DBNetwork{Name: name}).First(&network)
  89. if db.NewRecord(&network) {
  90. return nil, fmt.Errorf("network not found %s", name)
  91. }
  92. return &network, nil
  93. }
  94. // GetAllNetworks returns all networks defined in the system.
  95. func GetAllNetworks() []*DBNetwork {
  96. var networks []*DBNetwork
  97. db.Preload("Users").Find(&networks)
  98. return networks
  99. }
  100. // CreateNewNetwork creates a new network definition in the system.
  101. func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*DBNetwork, error) {
  102. if !IsInitialized() {
  103. return nil, fmt.Errorf("you first need to create server")
  104. }
  105. // Validate user input.
  106. if govalidator.IsNull(name) {
  107. return nil, fmt.Errorf("validation error: %s can not be null", name)
  108. }
  109. if !govalidator.IsAlphanumeric(name) {
  110. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  111. }
  112. if !govalidator.IsCIDR(cidr) {
  113. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
  114. }
  115. if !govalidator.IsCIDR(via) && via != "" {
  116. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", via)
  117. }
  118. if nettype == UNDEFINEDNET {
  119. return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
  120. }
  121. _, ipnet, err := net.ParseCIDR(cidr)
  122. if err != nil {
  123. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  124. }
  125. // Overwrite via with the parsed CIDR string.
  126. if nettype == ROUTE && via != "" {
  127. _, viaNet, err := net.ParseCIDR(via)
  128. if err != nil {
  129. return nil, fmt.Errorf("can not parse CIDR %s: %v", via, err)
  130. }
  131. via = viaNet.String()
  132. } else {
  133. via = ""
  134. }
  135. network := DBNetwork{
  136. Name: name,
  137. CIDR: ipnet.String(),
  138. Type: nettype,
  139. Users: []*DBUser{},
  140. Via: via,
  141. }
  142. db.Save(&network)
  143. if db.NewRecord(&network) {
  144. return nil, fmt.Errorf("can not create network in the db")
  145. }
  146. Emit()
  147. logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
  148. return &network, nil
  149. }
  150. // Delete deletes a network definition in the system.
  151. func (n *DBNetwork) Delete() error {
  152. if !IsInitialized() {
  153. return fmt.Errorf("you first need to create server")
  154. }
  155. db.Unscoped().Delete(n)
  156. Emit()
  157. logrus.Infof("network deleted: %s", n.Name)
  158. return nil
  159. }
  160. // Associate allows the given user access to this network.
  161. func (n *DBNetwork) Associate(username string) error {
  162. if !IsInitialized() {
  163. return fmt.Errorf("you first need to create server")
  164. }
  165. user, err := GetUser(username)
  166. if err != nil {
  167. return fmt.Errorf("user can not be fetched: %v", err)
  168. }
  169. var users []DBUser
  170. userAssoc := db.Model(&n).Association("Users")
  171. userAssoc.Find(&users)
  172. var found bool
  173. for _, u := range users {
  174. if u.ID == user.ID {
  175. found = true
  176. break
  177. }
  178. }
  179. if found {
  180. return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
  181. }
  182. userAssoc.Append(user)
  183. if userAssoc.Error != nil {
  184. return fmt.Errorf("association failed: %v", userAssoc.Error)
  185. }
  186. Emit()
  187. logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
  188. return nil
  189. }
  190. // Dissociate breaks up the given users association to the said network.
  191. func (n *DBNetwork) Dissociate(username string) error {
  192. if !IsInitialized() {
  193. return fmt.Errorf("you first need to create server")
  194. }
  195. user, err := GetUser(username)
  196. if err != nil {
  197. return fmt.Errorf("user can not be fetched: %v", err)
  198. }
  199. var users []DBUser
  200. userAssoc := db.Model(&n).Association("Users")
  201. userAssoc.Find(&users)
  202. var found bool
  203. for _, u := range users {
  204. if u.ID == user.ID {
  205. found = true
  206. break
  207. }
  208. }
  209. if !found {
  210. return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
  211. }
  212. userAssoc.Delete(user)
  213. if userAssoc.Error != nil {
  214. return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
  215. }
  216. Emit()
  217. logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
  218. return nil
  219. }
  220. // GetName returns network's name.
  221. func (n *DBNetwork) GetName() string {
  222. return n.Name
  223. }
  224. // GetCIDR returns network's CIDR.
  225. func (n *DBNetwork) GetCIDR() string {
  226. return n.CIDR
  227. }
  228. // GetCreatedAt returns network's name.
  229. func (n *DBNetwork) GetCreatedAt() string {
  230. return n.CreatedAt.Format(time.UnixDate)
  231. }
  232. // GetType returns network's network type.
  233. func (n *DBNetwork) GetType() NetworkType {
  234. return NetworkType(n.Type)
  235. }
  236. // GetAssociatedUsers returns network's associated users.
  237. func (n *DBNetwork) GetAssociatedUsers() []*DBUser {
  238. return n.Users
  239. }
  240. // GetAssociatedUsernames returns network's associated user names.
  241. func (n *DBNetwork) GetAssociatedUsernames() []string {
  242. var usernames []string
  243. for _, user := range n.GetAssociatedUsers() {
  244. usernames = append(usernames, user.Username)
  245. }
  246. return usernames
  247. }
  248. // GetVia returns network' via.
  249. func (n *DBNetwork) GetVia() string {
  250. return n.Via
  251. }
  252. // interfaceOfIP returns a network interface that has the given IP.
  253. func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
  254. ifaces, err := net.Interfaces()
  255. if err != nil {
  256. return nil
  257. }
  258. for _, iface := range ifaces {
  259. addrs, err := iface.Addrs()
  260. if err != nil {
  261. logrus.Error(err)
  262. return nil
  263. }
  264. for _, addr := range addrs {
  265. switch addr := addr.(type) {
  266. case *net.IPAddr:
  267. if ip := addr.IP; ip != nil {
  268. if ipnet.Contains(ip) {
  269. return &iface
  270. }
  271. }
  272. case *net.IPNet:
  273. if ip := addr.IP; ip != nil {
  274. if ipnet.Contains(ip) {
  275. return &iface
  276. }
  277. }
  278. }
  279. }
  280. }
  281. return nil
  282. }
  283. func routedInterface(network string, flags net.Flags) *net.Interface {
  284. switch network {
  285. case "ip", "ip4", "ip6":
  286. default:
  287. return nil
  288. }
  289. ift, err := net.Interfaces()
  290. if err != nil {
  291. return nil
  292. }
  293. for _, ifi := range ift {
  294. if ifi.Flags&flags != flags {
  295. continue
  296. }
  297. if _, ok := hasRoutableIP(network, &ifi); !ok {
  298. continue
  299. }
  300. return &ifi
  301. }
  302. return nil
  303. }
  304. func getOutboundInterface() *net.Interface {
  305. conn, err := net.Dial("udp", "8.8.8.8:80")
  306. if err != nil {
  307. log.Fatal(err)
  308. }
  309. defer conn.Close()
  310. localAddr := conn.LocalAddr().(*net.UDPAddr)
  311. ipnet := net.IPNet{
  312. IP: localAddr.IP.To4(),
  313. Mask: localAddr.IP.To4().DefaultMask(),
  314. }
  315. return interfaceOfIP(&ipnet)
  316. }
  317. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  318. ifat, err := ifi.Addrs()
  319. if err != nil {
  320. return nil, false
  321. }
  322. for _, ifa := range ifat {
  323. switch ifa := ifa.(type) {
  324. case *net.IPAddr:
  325. if ip := routableIP(network, ifa.IP); ip != nil {
  326. return ip, true
  327. }
  328. case *net.IPNet:
  329. if ip := routableIP(network, ifa.IP); ip != nil {
  330. return ip, true
  331. }
  332. }
  333. }
  334. return nil, false
  335. }
  336. func vpnInterface() *net.Interface {
  337. server, err := GetServerInstance()
  338. if err != nil {
  339. logrus.Errorf("can't get server instance: %v", err)
  340. return nil
  341. }
  342. mask := net.IPMask(net.ParseIP(server.Mask))
  343. prefix := net.ParseIP(server.Net)
  344. netw := prefix.Mask(mask).To4()
  345. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  346. ipnet := net.IPNet{IP: netw, Mask: mask}
  347. return interfaceOfIP(&ipnet)
  348. }
  349. func routableIP(network string, ip net.IP) net.IP {
  350. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  351. return nil
  352. }
  353. switch network {
  354. case "ip4":
  355. if ip := ip.To4(); ip != nil {
  356. return ip
  357. }
  358. case "ip6":
  359. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  360. return nil
  361. }
  362. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  363. return ip
  364. }
  365. default:
  366. if ip := ip.To4(); ip != nil {
  367. return ip
  368. }
  369. if ip := ip.To16(); ip != nil {
  370. return ip
  371. }
  372. }
  373. return nil
  374. }
  375. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  376. func ensureNatEnabled() {
  377. // Nat enablerer
  378. go func() {
  379. for {
  380. err := enableNat()
  381. if err == nil {
  382. logrus.Debug("nat is enabled")
  383. return
  384. }
  385. logrus.Debugf("can not enable nat: %v", err)
  386. // TODO(cad): employ a exponential back-off approach here
  387. // instead of sleeping for the constant duration.
  388. time.Sleep(1 * time.Second)
  389. }
  390. }()
  391. }
  392. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  393. func enableNat() error {
  394. if Testing {
  395. return nil
  396. }
  397. // rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  398. // if rif == nil {
  399. // return fmt.Errorf("can not get routable network interface")
  400. // }
  401. rif := getOutboundInterface()
  402. if rif == nil {
  403. return fmt.Errorf("can not get default gw interface")
  404. }
  405. vpnIfc := vpnInterface()
  406. if vpnIfc == nil {
  407. return fmt.Errorf("can not get vpn network interface on the system")
  408. }
  409. // Enable ip forwarding.
  410. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  411. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  412. if err != nil {
  413. return fmt.Errorf("can not create new iptables object: %v", err)
  414. }
  415. server, err := GetServerInstance()
  416. if err != nil {
  417. logrus.Errorf("can't get server instance: %v", err)
  418. return nil
  419. }
  420. mask := net.IPMask(net.ParseIP(server.Mask))
  421. prefix := net.ParseIP(server.Net)
  422. netw := prefix.Mask(mask).To4()
  423. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  424. ipnet := net.IPNet{IP: netw, Mask: mask}
  425. // Append iptables nat rules.
  426. if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", ipnet.String(), "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
  427. return err
  428. }
  429. if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
  430. return err
  431. }
  432. if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
  433. return err
  434. }
  435. return nil
  436. }
  437. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  438. func HostID2IP(hostid uint32) net.IP {
  439. ip := make([]byte, 4)
  440. binary.BigEndian.PutUint32(ip, hostid)
  441. return net.IP(ip)
  442. }
  443. //IP2HostID converts an IP address to a host id (32-bit unsigned integer).
  444. func IP2HostID(ip net.IP) uint32 {
  445. hostid := binary.BigEndian.Uint32(ip)
  446. return hostid
  447. }
  448. // IncrementIP will return next ip address within the network.
  449. func IncrementIP(ip, mask string) (string, error) {
  450. ipAddr := net.ParseIP(ip).To4()
  451. netMask := net.IPMask(net.ParseIP(mask).To4())
  452. ipNet := net.IPNet{IP: ipAddr, Mask: netMask}
  453. for i := len(ipAddr) - 1; i >= 0; i-- {
  454. ipAddr[i]++
  455. if ip[i] != 0 {
  456. break
  457. }
  458. }
  459. if !ipNet.Contains(ipAddr) {
  460. return ip, errors.New("overflowed CIDR while incrementing IP")
  461. }
  462. return ipAddr.String(), nil
  463. }