net.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "time"
  9. "github.com/Sirupsen/logrus"
  10. "github.com/asaskevich/govalidator"
  11. "github.com/coreos/go-iptables/iptables"
  12. "github.com/jinzhu/gorm"
  13. )
// NetworkType distinguishes the different kinds of networks that are defined
// in the networks table. The known values, with their names and descriptions,
// are listed in networkTypes.
type NetworkType uint
// Network types. The zero value, UNDEFINEDNET, marks an unknown or unset type
// and is rejected by CreateNewNetwork.
const (
	// UNDEFINEDNET is the zero value: an unknown or unset network type.
	UNDEFINEDNET NetworkType = iota
	// SERVERNET is a network behind the VPN server.
	SERVERNET
	// ROUTE is a network to be pushed as a route; it may carry a gateway
	// ("via") address.
	ROUTE
)
// networkTypes is the lookup table behind NetworkType's String and
// Description methods and the name<->type helpers (NetworkTypeFromString,
// IsNetworkType, GetAllNetworkTypes). Each NetworkType constant should appear
// here exactly once; all lookups stop at the first match.
var networkTypes = [...]struct {
	Type        NetworkType // the enum value
	String      string      // canonical name, matched exactly by the lookups
	Description string      // human-readable description
}{
	{UNDEFINEDNET, "UNDEFINEDNET", "unknown network type"},
	{SERVERNET, "SERVERNET", "network behind vpn server"},
	{ROUTE, "ROUTE", "network to be pushed as route"},
}
  31. // NetworkTypeFromString returns string representation of the network type.
  32. func NetworkTypeFromString(typ string) NetworkType {
  33. for _, v := range networkTypes {
  34. if v.String == typ {
  35. return v.Type
  36. }
  37. }
  38. return UNDEFINEDNET
  39. }
  40. // GetAllNetworkTypes returns all network types defined in the system.
  41. func GetAllNetworkTypes() []NetworkType {
  42. var networkTypeList []NetworkType
  43. for _, v := range networkTypes {
  44. networkTypeList = append(networkTypeList, v.Type)
  45. }
  46. return networkTypeList
  47. }
  48. func (nt NetworkType) String() string {
  49. for _, v := range networkTypes {
  50. if v.Type == nt {
  51. return v.String
  52. }
  53. }
  54. return "UNDEFINEDNET"
  55. }
  56. // Description gives description about the network type.
  57. func (nt NetworkType) Description() string {
  58. for _, v := range networkTypes {
  59. if v.Type == nt {
  60. return v.Description
  61. }
  62. }
  63. return "UNDEFINEDNET"
  64. }
  65. // IsNetworkType returns if the s is a valid network type or not.
  66. func IsNetworkType(s string) bool {
  67. for _, v := range networkTypes {
  68. if s == v.String {
  69. return true
  70. }
  71. }
  72. return false
  73. }
// dbNetworkModel is the database model for external networks on the VPN
// server. Code outside the persistence layer uses it through the Network
// wrapper.
type dbNetworkModel struct {
	gorm.Model // provides ID and timestamps such as CreatedAt (read by GetCreatedAt)

	// NOTE(review): presumably the owning server; CreateNewNetwork does not set
	// these fields — confirm how they are populated.
	ServerID uint
	Server   dbServerModel

	Name string `gorm:"unique_index"` // network name, unique across the table
	CIDR string                       // network in CIDR form; CreateNewNetwork stores the normalized network address
	Type NetworkType
	Via  string // gateway IPv4; CreateNewNetwork keeps it only for ROUTE networks

	// Users granted access to this network, linked through the network_users
	// join table.
	Users []*dbUserModel `gorm:"many2many:network_users;"`
}
// Network represents a VPN related network. It embeds dbNetworkModel, so the
// model's fields (Name, CIDR, Type, Via, Users, ...) are accessible directly.
type Network struct {
	dbNetworkModel
}
  89. // GetNetwork returns a network specified by its name.
  90. func GetNetwork(name string) (*Network, error) {
  91. if svr := TheServer(); !svr.IsInitialized() {
  92. return nil, fmt.Errorf("you first need to create server")
  93. }
  94. // Validate user input.
  95. if govalidator.IsNull(name) {
  96. return nil, fmt.Errorf("validation error: %s can not be null", name)
  97. }
  98. var network dbNetworkModel
  99. db.Preload("Users").Where(&dbNetworkModel{Name: name}).First(&network)
  100. if db.NewRecord(&network) {
  101. return nil, fmt.Errorf("network not found %s", name)
  102. }
  103. return &Network{dbNetworkModel: network}, nil
  104. }
  105. // GetAllNetworks returns all networks defined in the system.
  106. func GetAllNetworks() []*Network {
  107. var networks []*Network
  108. var dbNetworks []*dbNetworkModel
  109. db.Preload("Users").Find(&dbNetworks)
  110. for _, n := range dbNetworks {
  111. networks = append(networks, &Network{dbNetworkModel: *n})
  112. }
  113. return networks
  114. }
  115. // CreateNewNetwork creates a new network definition in the system.
  116. func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
  117. if svr := TheServer(); !svr.IsInitialized() {
  118. return nil, fmt.Errorf("you first need to create server")
  119. }
  120. // Validate user input.
  121. if govalidator.IsNull(name) {
  122. return nil, fmt.Errorf("validation error: %s can not be null", name)
  123. }
  124. if !govalidator.Matches(name, "^([\\w\\.]+)$") { // allow alphanumeric, underscore and dot
  125. return nil, fmt.Errorf("validation error: `%s` can only contain letters, numbers, underscores and dots", name)
  126. }
  127. if !govalidator.IsCIDR(cidr) {
  128. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
  129. }
  130. if via != "" && !govalidator.IsIPv4(via) {
  131. return nil, fmt.Errorf("validation error: `%s` must be a network in the IPv4 form", via)
  132. }
  133. if nettype == UNDEFINEDNET {
  134. return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
  135. }
  136. _, ipnet, err := net.ParseCIDR(cidr)
  137. if err != nil {
  138. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  139. }
  140. // Overwrite via with the parsed IPv4 string.
  141. if nettype == ROUTE && via != "" {
  142. viaIP := net.ParseIP(via).To4()
  143. if err != nil {
  144. return nil, fmt.Errorf("can not parse IPv4 %s: %v", via, err)
  145. }
  146. via = viaIP.String()
  147. } else {
  148. via = ""
  149. }
  150. network := dbNetworkModel{
  151. Name: name,
  152. CIDR: ipnet.String(),
  153. Type: nettype,
  154. Users: []*dbUserModel{},
  155. Via: via,
  156. }
  157. db.Save(&network)
  158. if db.NewRecord(&network) {
  159. return nil, fmt.Errorf("can not create network in the db")
  160. }
  161. TheServer().EmitWithRestart()
  162. logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
  163. return &Network{dbNetworkModel: network}, nil
  164. }
  165. // Delete deletes a network definition in the system.
  166. func (n *Network) Delete() error {
  167. svr := TheServer()
  168. if !svr.IsInitialized() {
  169. return fmt.Errorf("you first need to create server")
  170. }
  171. db.Unscoped().Delete(n.dbNetworkModel)
  172. svr.EmitWithRestart()
  173. logrus.Infof("network deleted: %s", n.Name)
  174. return nil
  175. }
  176. // Associate allows the given user access to this network.
  177. func (n *Network) Associate(username string) error {
  178. if svr := TheServer(); !svr.IsInitialized() {
  179. return fmt.Errorf("you first need to create server")
  180. }
  181. user, err := GetUser(username)
  182. if err != nil {
  183. return fmt.Errorf("user can not be fetched: %v", err)
  184. }
  185. var users []dbUserModel
  186. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  187. userAssoc.Find(&users)
  188. var found bool
  189. for _, u := range users {
  190. if u.ID == user.ID {
  191. found = true
  192. break
  193. }
  194. }
  195. if found {
  196. return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
  197. }
  198. userAssoc.Append(user.dbUserModel)
  199. if userAssoc.Error != nil {
  200. return fmt.Errorf("association failed: %v", userAssoc.Error)
  201. }
  202. TheServer().EmitWithRestart()
  203. logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
  204. return nil
  205. }
  206. // Dissociate breaks up the given users association to the said network.
  207. func (n *Network) Dissociate(username string) error {
  208. svr := TheServer()
  209. if !svr.IsInitialized() {
  210. return fmt.Errorf("you first need to create server")
  211. }
  212. user, err := GetUser(username)
  213. if err != nil {
  214. return fmt.Errorf("user can not be fetched: %v", err)
  215. }
  216. var users []dbUserModel
  217. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  218. userAssoc.Find(&users)
  219. var found bool
  220. for _, u := range users {
  221. if u.ID == user.ID {
  222. found = true
  223. break
  224. }
  225. }
  226. if !found {
  227. return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
  228. }
  229. userAssoc.Delete(user.dbUserModel)
  230. if userAssoc.Error != nil {
  231. return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
  232. }
  233. svr.EmitWithRestart()
  234. logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
  235. return nil
  236. }
  237. // GetName returns network's name.
  238. func (n *Network) GetName() string {
  239. return n.Name
  240. }
  241. // GetCIDR returns network's CIDR.
  242. func (n *Network) GetCIDR() string {
  243. return n.CIDR
  244. }
  245. // GetCreatedAt returns network's name.
  246. func (n *Network) GetCreatedAt() string {
  247. return n.CreatedAt.Format(time.UnixDate)
  248. }
  249. // GetType returns network's network type.
  250. func (n *Network) GetType() NetworkType {
  251. return NetworkType(n.Type)
  252. }
  253. // GetAssociatedUsers returns network's associated users.
  254. func (n *Network) GetAssociatedUsers() []*User {
  255. var users []*User
  256. for _, u := range n.Users {
  257. users = append(users, &User{dbUserModel: *u})
  258. }
  259. return users
  260. }
  261. // GetAssociatedUsernames returns network's associated user names.
  262. func (n *Network) GetAssociatedUsernames() []string {
  263. var usernames []string
  264. for _, user := range n.GetAssociatedUsers() {
  265. usernames = append(usernames, user.Username)
  266. }
  267. return usernames
  268. }
  269. // GetVia returns network' via.
  270. func (n *Network) GetVia() string {
  271. return n.Via
  272. }
  273. // interfaceOfIP returns a network interface that has the given IP.
  274. func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
  275. ifaces, err := net.Interfaces()
  276. if err != nil {
  277. return nil
  278. }
  279. for _, iface := range ifaces {
  280. addrs, err := iface.Addrs()
  281. if err != nil {
  282. logrus.Error(err)
  283. return nil
  284. }
  285. for _, addr := range addrs {
  286. switch addr := addr.(type) {
  287. case *net.IPAddr:
  288. if ip := addr.IP; ip != nil {
  289. if ipnet.Contains(ip) {
  290. return &iface
  291. }
  292. }
  293. case *net.IPNet:
  294. if ip := addr.IP; ip != nil {
  295. if ipnet.Contains(ip) {
  296. return &iface
  297. }
  298. }
  299. }
  300. }
  301. }
  302. return nil
  303. }
  304. // routedInterface returns the interface who has a routable IP address on it.
  305. func routedInterface(network string, flags net.Flags) *net.Interface {
  306. switch network {
  307. case "ip", "ip4", "ip6":
  308. default:
  309. return nil
  310. }
  311. ift, err := net.Interfaces()
  312. if err != nil {
  313. return nil
  314. }
  315. for _, ifi := range ift {
  316. if ifi.Flags&flags != flags {
  317. continue
  318. }
  319. if _, ok := hasRoutableIP(network, &ifi); !ok {
  320. continue
  321. }
  322. return &ifi
  323. }
  324. return nil
  325. }
  326. // getOutboundInterface will return the outbound interface if there is one.
  327. func getOutboundInterface() *net.Interface {
  328. conn, err := net.Dial("udp", "8.8.8.8:80")
  329. if err != nil {
  330. log.Fatal(err)
  331. }
  332. defer conn.Close()
  333. localAddr := conn.LocalAddr().(*net.UDPAddr)
  334. ipnet := net.IPNet{
  335. IP: localAddr.IP.To4(),
  336. Mask: localAddr.IP.To4().DefaultMask(),
  337. }
  338. return interfaceOfIP(&ipnet)
  339. }
  340. // hasRoutableIP returns if the received interface has a routable IP
  341. // address attached on it.
  342. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  343. ifat, err := ifi.Addrs()
  344. if err != nil {
  345. return nil, false
  346. }
  347. for _, ifa := range ifat {
  348. switch ifa := ifa.(type) {
  349. case *net.IPAddr:
  350. if ip := routableIP(network, ifa.IP); ip != nil {
  351. return ip, true
  352. }
  353. case *net.IPNet:
  354. if ip := routableIP(network, ifa.IP); ip != nil {
  355. return ip, true
  356. }
  357. }
  358. }
  359. return nil, false
  360. }
  361. // vpnInterface returns the interface which belongs to the VPN server.
  362. func vpnInterface() *net.Interface {
  363. svr := TheServer()
  364. mask := net.IPMask(net.ParseIP(svr.Mask))
  365. prefix := net.ParseIP(svr.Net)
  366. netw := prefix.Mask(mask).To4()
  367. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  368. ipnet := net.IPNet{IP: netw, Mask: mask}
  369. return interfaceOfIP(&ipnet)
  370. }
  371. // routableIP returns if the received IP is routable.
  372. func routableIP(network string, ip net.IP) net.IP {
  373. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  374. return nil
  375. }
  376. switch network {
  377. case "ip4":
  378. if ip := ip.To4(); ip != nil {
  379. return ip
  380. }
  381. case "ip6":
  382. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  383. return nil
  384. }
  385. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  386. return ip
  387. }
  388. default:
  389. if ip := ip.To4(); ip != nil {
  390. return ip
  391. }
  392. if ip := ip.To16(); ip != nil {
  393. return ip
  394. }
  395. }
  396. return nil
  397. }
  398. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  399. func ensureNatEnabled() {
  400. // Nat enablerer
  401. go func() {
  402. for {
  403. err := enableNat()
  404. if err == nil {
  405. logrus.Debug("nat is enabled")
  406. return
  407. }
  408. logrus.Debugf("can not enable nat: %v", err)
  409. // TODO(cad): employ a exponential back-off approach here
  410. // instead of sleeping for the constant duration.
  411. time.Sleep(1 * time.Second)
  412. }
  413. }()
  414. }
  415. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  416. func enableNat() error {
  417. if Testing {
  418. return nil
  419. }
  420. // rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  421. // if rif == nil {
  422. // return fmt.Errorf("can not get routable network interface")
  423. // }
  424. rif := getOutboundInterface()
  425. if rif == nil {
  426. return fmt.Errorf("can not get default gw interface")
  427. }
  428. vpnIfc := vpnInterface()
  429. if vpnIfc == nil {
  430. return fmt.Errorf("can not get vpn network interface on the system")
  431. }
  432. // Enable ip forwarding.
  433. TheServer().emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  434. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  435. if err != nil {
  436. return fmt.Errorf("can not create new iptables object: %v", err)
  437. }
  438. svr := TheServer()
  439. mask := net.IPMask(net.ParseIP(svr.Mask))
  440. prefix := net.ParseIP(svr.Net)
  441. netw := prefix.Mask(mask).To4()
  442. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  443. ipnet := net.IPNet{IP: netw, Mask: mask}
  444. // Append iptables nat rules.
  445. if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", ipnet.String(), "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
  446. return err
  447. }
  448. if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
  449. return err
  450. }
  451. if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
  452. return err
  453. }
  454. return nil
  455. }
  456. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  457. func HostID2IP(hostid uint32) net.IP {
  458. ip := make([]byte, 4)
  459. binary.BigEndian.PutUint32(ip, hostid)
  460. return net.IP(ip).To4()
  461. }
  462. // IP2HostID converts an IP address to a host id (32-bit unsigned integer).
  463. func IP2HostID(ip net.IP) uint32 {
  464. hostid := binary.BigEndian.Uint32(ip.To4())
  465. return hostid
  466. }
  467. // IncrementIP will return next ip address within the network.
  468. func IncrementIP(ip, mask string) (string, error) {
  469. if !govalidator.IsIPv4(ip) {
  470. return "", fmt.Errorf("'ip' is expected to be a valid IPv4 %s", ip)
  471. }
  472. if !govalidator.IsIPv4(ip) {
  473. return "", fmt.Errorf("'mask' is expected to be a valid IPv4 %s", mask)
  474. }
  475. ipAddr := net.ParseIP(ip).To4()
  476. netMask := net.IPMask(net.ParseIP(mask).To4())
  477. ipNet := net.IPNet{IP: ipAddr, Mask: netMask}
  478. for i := len(ipAddr) - 1; i >= 0; i-- {
  479. ipAddr[i]++
  480. if ip[i] != 0 {
  481. break
  482. }
  483. }
  484. if !ipNet.Contains(ipAddr) {
  485. return ip, errors.New("overflowed CIDR while incrementing IP")
  486. }
  487. return ipAddr.String(), nil
  488. }