net.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "time"
  9. "github.com/Sirupsen/logrus"
  10. "github.com/asaskevich/govalidator"
  11. "github.com/coreos/go-iptables/iptables"
  12. "github.com/jinzhu/gorm"
  13. )
// NetworkType distinguishes the different types of networks that are defined
// in the networks table.
type NetworkType uint

// NetworkTypes
//
// UNDEFINEDNET is the zero value and stands for an unknown network type. The
// human readable names and descriptions of every type live in networkTypes.
const (
	UNDEFINEDNET NetworkType = iota // unknown network type
	SERVERNET                       // network behind vpn server
	ROUTE                           // network to be pushed as route
)
// networkTypes is the lookup table backing NetworkTypeFromString,
// GetAllNetworkTypes, NetworkType.String and NetworkType.Description.
// Each entry maps a NetworkType to its canonical string name and a short
// human readable description.
var networkTypes = [...]struct {
	Type        NetworkType
	String      string
	Description string
}{
	{UNDEFINEDNET, "UNDEFINEDNET", "unknown network type"},
	{SERVERNET, "SERVERNET", "network behind vpn server"},
	{ROUTE, "ROUTE", "network to be pushed as route"},
}
  31. // NetworkTypeFromString returns string representation of the network type.
  32. func NetworkTypeFromString(typ string) NetworkType {
  33. for _, v := range networkTypes {
  34. if v.String == typ {
  35. return v.Type
  36. }
  37. }
  38. return UNDEFINEDNET
  39. }
  40. // GetAllNetworkTypes returns all network types defined in the system.
  41. func GetAllNetworkTypes() []NetworkType {
  42. var networkTypeList []NetworkType
  43. for _, v := range networkTypes {
  44. networkTypeList = append(networkTypeList, v.Type)
  45. }
  46. return networkTypeList
  47. }
  48. func (nt NetworkType) String() string {
  49. for _, v := range networkTypes {
  50. if v.Type == nt {
  51. return v.String
  52. }
  53. }
  54. return "UNDEFINEDNET"
  55. }
  56. // Description gives description about the network type.
  57. func (nt NetworkType) Description() string {
  58. for _, v := range networkTypes {
  59. if v.Type == nt {
  60. return v.Description
  61. }
  62. }
  63. return "UNDEFINEDNET"
  64. }
// dbNetworkModel is the database model for networks defined on the VPN server.
type dbNetworkModel struct {
	gorm.Model
	// NOTE(review): presumably links the network to its owning server; neither
	// ServerID nor Server is set by the code in this file — TODO confirm.
	ServerID uint
	Server   dbServerModel
	// Name identifies the network; it carries a unique index.
	Name string `gorm:"unique_index"`
	// CIDR is the network range; CreateNewNetwork stores it in canonical
	// network form.
	CIDR string
	Type NetworkType
	// Via is an IPv4 gateway address; CreateNewNetwork only keeps it for ROUTE
	// networks and stores "" for every other type.
	Via string
	// Users are the users associated with the network through the
	// network_users join table.
	Users []*dbUserModel `gorm:"many2many:network_users;"`
}
// Network represents a VPN related network. It embeds dbNetworkModel; the
// Get* accessor methods below read their values from its fields.
type Network struct {
	dbNetworkModel
}
  80. // GetNetwork returns a network specified by its name.
  81. func GetNetwork(name string) (*Network, error) {
  82. if !IsInitialized() {
  83. return nil, fmt.Errorf("you first need to create server")
  84. }
  85. // Validate user input.
  86. if govalidator.IsNull(name) {
  87. return nil, fmt.Errorf("validation error: %s can not be null", name)
  88. }
  89. if !govalidator.IsAlphanumeric(name) {
  90. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  91. }
  92. var network dbNetworkModel
  93. db.Preload("Users").Where(&dbNetworkModel{Name: name}).First(&network)
  94. if db.NewRecord(&network) {
  95. return nil, fmt.Errorf("network not found %s", name)
  96. }
  97. return &Network{dbNetworkModel: network}, nil
  98. }
  99. // GetAllNetworks returns all networks defined in the system.
  100. func GetAllNetworks() []*Network {
  101. var networks []*Network
  102. var dbNetworks []*dbNetworkModel
  103. db.Preload("Users").Find(&dbNetworks)
  104. for _, n := range dbNetworks {
  105. networks = append(networks, &Network{dbNetworkModel: *n})
  106. }
  107. return networks
  108. }
  109. // CreateNewNetwork creates a new network definition in the system.
  110. func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
  111. if !IsInitialized() {
  112. return nil, fmt.Errorf("you first need to create server")
  113. }
  114. // Validate user input.
  115. if govalidator.IsNull(name) {
  116. return nil, fmt.Errorf("validation error: %s can not be null", name)
  117. }
  118. if !govalidator.IsAlphanumeric(name) {
  119. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  120. }
  121. if !govalidator.IsCIDR(cidr) {
  122. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
  123. }
  124. if via != "" && !govalidator.IsIPv4(via) {
  125. return nil, fmt.Errorf("validation error: `%s` must be a network in the IPv4 form", via)
  126. }
  127. if nettype == UNDEFINEDNET {
  128. return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
  129. }
  130. _, ipnet, err := net.ParseCIDR(cidr)
  131. if err != nil {
  132. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  133. }
  134. // Overwrite via with the parsed IPv4 string.
  135. if nettype == ROUTE && via != "" {
  136. viaIP := net.ParseIP(via).To4()
  137. if err != nil {
  138. return nil, fmt.Errorf("can not parse IPv4 %s: %v", via, err)
  139. }
  140. via = viaIP.String()
  141. } else {
  142. via = ""
  143. }
  144. network := dbNetworkModel{
  145. Name: name,
  146. CIDR: ipnet.String(),
  147. Type: nettype,
  148. Users: []*dbUserModel{},
  149. Via: via,
  150. }
  151. db.Save(&network)
  152. if db.NewRecord(&network) {
  153. return nil, fmt.Errorf("can not create network in the db")
  154. }
  155. Emit()
  156. logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
  157. return &Network{dbNetworkModel: network}, nil
  158. }
  159. // Delete deletes a network definition in the system.
  160. func (n *Network) Delete() error {
  161. if !IsInitialized() {
  162. return fmt.Errorf("you first need to create server")
  163. }
  164. db.Unscoped().Delete(n.dbNetworkModel)
  165. Emit()
  166. logrus.Infof("network deleted: %s", n.Name)
  167. return nil
  168. }
  169. // Associate allows the given user access to this network.
  170. func (n *Network) Associate(username string) error {
  171. if !IsInitialized() {
  172. return fmt.Errorf("you first need to create server")
  173. }
  174. user, err := GetUser(username)
  175. if err != nil {
  176. return fmt.Errorf("user can not be fetched: %v", err)
  177. }
  178. var users []dbUserModel
  179. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  180. userAssoc.Find(&users)
  181. var found bool
  182. for _, u := range users {
  183. if u.ID == user.ID {
  184. found = true
  185. break
  186. }
  187. }
  188. if found {
  189. return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
  190. }
  191. userAssoc.Append(user.dbUserModel)
  192. if userAssoc.Error != nil {
  193. return fmt.Errorf("association failed: %v", userAssoc.Error)
  194. }
  195. Emit()
  196. logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
  197. return nil
  198. }
  199. // Dissociate breaks up the given users association to the said network.
  200. func (n *Network) Dissociate(username string) error {
  201. if !IsInitialized() {
  202. return fmt.Errorf("you first need to create server")
  203. }
  204. user, err := GetUser(username)
  205. if err != nil {
  206. return fmt.Errorf("user can not be fetched: %v", err)
  207. }
  208. var users []dbUserModel
  209. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  210. userAssoc.Find(&users)
  211. var found bool
  212. for _, u := range users {
  213. if u.ID == user.ID {
  214. found = true
  215. break
  216. }
  217. }
  218. if !found {
  219. return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
  220. }
  221. userAssoc.Delete(user.dbUserModel)
  222. if userAssoc.Error != nil {
  223. return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
  224. }
  225. Emit()
  226. logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
  227. return nil
  228. }
  229. // GetName returns network's name.
  230. func (n *Network) GetName() string {
  231. return n.Name
  232. }
  233. // GetCIDR returns network's CIDR.
  234. func (n *Network) GetCIDR() string {
  235. return n.CIDR
  236. }
  237. // GetCreatedAt returns network's name.
  238. func (n *Network) GetCreatedAt() string {
  239. return n.CreatedAt.Format(time.UnixDate)
  240. }
  241. // GetType returns network's network type.
  242. func (n *Network) GetType() NetworkType {
  243. return NetworkType(n.Type)
  244. }
  245. // GetAssociatedUsers returns network's associated users.
  246. func (n *Network) GetAssociatedUsers() []*User {
  247. var users []*User
  248. for _, u := range n.Users {
  249. users = append(users, &User{dbUserModel: *u})
  250. }
  251. return users
  252. }
  253. // GetAssociatedUsernames returns network's associated user names.
  254. func (n *Network) GetAssociatedUsernames() []string {
  255. var usernames []string
  256. for _, user := range n.GetAssociatedUsers() {
  257. usernames = append(usernames, user.Username)
  258. }
  259. return usernames
  260. }
  261. // GetVia returns network' via.
  262. func (n *Network) GetVia() string {
  263. return n.Via
  264. }
  265. // interfaceOfIP returns a network interface that has the given IP.
  266. func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
  267. ifaces, err := net.Interfaces()
  268. if err != nil {
  269. return nil
  270. }
  271. for _, iface := range ifaces {
  272. addrs, err := iface.Addrs()
  273. if err != nil {
  274. logrus.Error(err)
  275. return nil
  276. }
  277. for _, addr := range addrs {
  278. switch addr := addr.(type) {
  279. case *net.IPAddr:
  280. if ip := addr.IP; ip != nil {
  281. if ipnet.Contains(ip) {
  282. return &iface
  283. }
  284. }
  285. case *net.IPNet:
  286. if ip := addr.IP; ip != nil {
  287. if ipnet.Contains(ip) {
  288. return &iface
  289. }
  290. }
  291. }
  292. }
  293. }
  294. return nil
  295. }
  296. func routedInterface(network string, flags net.Flags) *net.Interface {
  297. switch network {
  298. case "ip", "ip4", "ip6":
  299. default:
  300. return nil
  301. }
  302. ift, err := net.Interfaces()
  303. if err != nil {
  304. return nil
  305. }
  306. for _, ifi := range ift {
  307. if ifi.Flags&flags != flags {
  308. continue
  309. }
  310. if _, ok := hasRoutableIP(network, &ifi); !ok {
  311. continue
  312. }
  313. return &ifi
  314. }
  315. return nil
  316. }
  317. func getOutboundInterface() *net.Interface {
  318. conn, err := net.Dial("udp", "8.8.8.8:80")
  319. if err != nil {
  320. log.Fatal(err)
  321. }
  322. defer conn.Close()
  323. localAddr := conn.LocalAddr().(*net.UDPAddr)
  324. ipnet := net.IPNet{
  325. IP: localAddr.IP.To4(),
  326. Mask: localAddr.IP.To4().DefaultMask(),
  327. }
  328. return interfaceOfIP(&ipnet)
  329. }
  330. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  331. ifat, err := ifi.Addrs()
  332. if err != nil {
  333. return nil, false
  334. }
  335. for _, ifa := range ifat {
  336. switch ifa := ifa.(type) {
  337. case *net.IPAddr:
  338. if ip := routableIP(network, ifa.IP); ip != nil {
  339. return ip, true
  340. }
  341. case *net.IPNet:
  342. if ip := routableIP(network, ifa.IP); ip != nil {
  343. return ip, true
  344. }
  345. }
  346. }
  347. return nil, false
  348. }
  349. func vpnInterface() *net.Interface {
  350. server, err := GetServerInstance()
  351. if err != nil {
  352. logrus.Errorf("can't get server instance: %v", err)
  353. return nil
  354. }
  355. mask := net.IPMask(net.ParseIP(server.Mask))
  356. prefix := net.ParseIP(server.Net)
  357. netw := prefix.Mask(mask).To4()
  358. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  359. ipnet := net.IPNet{IP: netw, Mask: mask}
  360. return interfaceOfIP(&ipnet)
  361. }
  362. func routableIP(network string, ip net.IP) net.IP {
  363. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  364. return nil
  365. }
  366. switch network {
  367. case "ip4":
  368. if ip := ip.To4(); ip != nil {
  369. return ip
  370. }
  371. case "ip6":
  372. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  373. return nil
  374. }
  375. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  376. return ip
  377. }
  378. default:
  379. if ip := ip.To4(); ip != nil {
  380. return ip
  381. }
  382. if ip := ip.To16(); ip != nil {
  383. return ip
  384. }
  385. }
  386. return nil
  387. }
  388. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  389. func ensureNatEnabled() {
  390. // Nat enablerer
  391. go func() {
  392. for {
  393. err := enableNat()
  394. if err == nil {
  395. logrus.Debug("nat is enabled")
  396. return
  397. }
  398. logrus.Debugf("can not enable nat: %v", err)
  399. // TODO(cad): employ a exponential back-off approach here
  400. // instead of sleeping for the constant duration.
  401. time.Sleep(1 * time.Second)
  402. }
  403. }()
  404. }
  405. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  406. func enableNat() error {
  407. if Testing {
  408. return nil
  409. }
  410. // rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  411. // if rif == nil {
  412. // return fmt.Errorf("can not get routable network interface")
  413. // }
  414. rif := getOutboundInterface()
  415. if rif == nil {
  416. return fmt.Errorf("can not get default gw interface")
  417. }
  418. vpnIfc := vpnInterface()
  419. if vpnIfc == nil {
  420. return fmt.Errorf("can not get vpn network interface on the system")
  421. }
  422. // Enable ip forwarding.
  423. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  424. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  425. if err != nil {
  426. return fmt.Errorf("can not create new iptables object: %v", err)
  427. }
  428. server, err := GetServerInstance()
  429. if err != nil {
  430. logrus.Errorf("can't get server instance: %v", err)
  431. return nil
  432. }
  433. mask := net.IPMask(net.ParseIP(server.Mask))
  434. prefix := net.ParseIP(server.Net)
  435. netw := prefix.Mask(mask).To4()
  436. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  437. ipnet := net.IPNet{IP: netw, Mask: mask}
  438. // Append iptables nat rules.
  439. if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", ipnet.String(), "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
  440. return err
  441. }
  442. if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
  443. return err
  444. }
  445. if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
  446. return err
  447. }
  448. return nil
  449. }
  450. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  451. func HostID2IP(hostid uint32) net.IP {
  452. ip := make([]byte, 4)
  453. binary.BigEndian.PutUint32(ip, hostid)
  454. return net.IP(ip)
  455. }
  456. // IP2HostID converts an IP address to a host id (32-bit unsigned integer).
  457. func IP2HostID(ip net.IP) uint32 {
  458. hostid := binary.BigEndian.Uint32(ip)
  459. return hostid
  460. }
  461. // IncrementIP will return next ip address within the network.
  462. func IncrementIP(ip, mask string) (string, error) {
  463. if !govalidator.IsIPv4(ip) {
  464. return "", fmt.Errorf("'ip' is expected to be a valid IPv4 %s", ip)
  465. }
  466. if !govalidator.IsIPv4(ip) {
  467. return "", fmt.Errorf("'mask' is expected to be a valid IPv4 %s", mask)
  468. }
  469. ipAddr := net.ParseIP(ip).To4()
  470. netMask := net.IPMask(net.ParseIP(mask).To4())
  471. ipNet := net.IPNet{IP: ipAddr, Mask: netMask}
  472. for i := len(ipAddr) - 1; i >= 0; i-- {
  473. ipAddr[i]++
  474. if ip[i] != 0 {
  475. break
  476. }
  477. }
  478. if !ipNet.Contains(ipAddr) {
  479. return ip, errors.New("overflowed CIDR while incrementing IP")
  480. }
  481. return ipAddr.String(), nil
  482. }