net.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "time"
  9. "github.com/Sirupsen/logrus"
  10. "github.com/asaskevich/govalidator"
  11. "github.com/coreos/go-iptables/iptables"
  12. "github.com/jinzhu/gorm"
  13. )
  14. // NetworkType distinguishes different types of networks that is defined in the networks table.
  15. type NetworkType uint
  16. // NetworkTypes
  17. const (
  18. UNDEFINEDNET NetworkType = iota
  19. SERVERNET
  20. ROUTE
  21. )
  22. var networkTypes = [...]struct {
  23. Type NetworkType
  24. String string
  25. Description string
  26. }{
  27. {UNDEFINEDNET, "UNDEFINEDNET", "unknown network type"},
  28. {SERVERNET, "SERVERNET", "network behind vpn server"},
  29. {ROUTE, "ROUTE", "network to be pushed as route"},
  30. }
  31. // NetworkTypeFromString returns string representation of the network type.
  32. func NetworkTypeFromString(typ string) NetworkType {
  33. for _, v := range networkTypes {
  34. if v.String == typ {
  35. return v.Type
  36. }
  37. }
  38. return UNDEFINEDNET
  39. }
  40. // GetAllNetworkTypes returns all network types defined in the system.
  41. func GetAllNetworkTypes() []NetworkType {
  42. var networkTypeList []NetworkType
  43. for _, v := range networkTypes {
  44. networkTypeList = append(networkTypeList, v.Type)
  45. }
  46. return networkTypeList
  47. }
  48. func (nt NetworkType) String() string {
  49. for _, v := range networkTypes {
  50. if v.Type == nt {
  51. return v.String
  52. }
  53. }
  54. return "UNDEFINEDNET"
  55. }
  56. // Description gives description about the network type.
  57. func (nt NetworkType) Description() string {
  58. for _, v := range networkTypes {
  59. if v.Type == nt {
  60. return v.Description
  61. }
  62. }
  63. return "UNDEFINEDNET"
  64. }
  65. // dbNetworkModel is database model for external networks on the VPN server.
  66. type dbNetworkModel struct {
  67. gorm.Model
  68. ServerID uint
  69. Server dbServerModel
  70. Name string `gorm:"unique_index"`
  71. CIDR string
  72. Type NetworkType
  73. Via string
  74. Users []*dbUserModel `gorm:"many2many:network_users;"`
  75. }
  76. // Network represents a VPN related network.
  77. type Network struct {
  78. dbNetworkModel
  79. }
  80. // GetNetwork returns a network specified by its name.
  81. func GetNetwork(name string) (*Network, error) {
  82. if !IsInitialized() {
  83. return nil, fmt.Errorf("you first need to create server")
  84. }
  85. // Validate user input.
  86. if govalidator.IsNull(name) {
  87. return nil, fmt.Errorf("validation error: %s can not be null", name)
  88. }
  89. var network dbNetworkModel
  90. db.Preload("Users").Where(&dbNetworkModel{Name: name}).First(&network)
  91. if db.NewRecord(&network) {
  92. return nil, fmt.Errorf("network not found %s", name)
  93. }
  94. return &Network{dbNetworkModel: network}, nil
  95. }
  96. // GetAllNetworks returns all networks defined in the system.
  97. func GetAllNetworks() []*Network {
  98. var networks []*Network
  99. var dbNetworks []*dbNetworkModel
  100. db.Preload("Users").Find(&dbNetworks)
  101. for _, n := range dbNetworks {
  102. networks = append(networks, &Network{dbNetworkModel: *n})
  103. }
  104. return networks
  105. }
  106. // CreateNewNetwork creates a new network definition in the system.
  107. func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
  108. if !IsInitialized() {
  109. return nil, fmt.Errorf("you first need to create server")
  110. }
  111. // Validate user input.
  112. if govalidator.IsNull(name) {
  113. return nil, fmt.Errorf("validation error: %s can not be null", name)
  114. }
  115. if !govalidator.Matches(name, "^([\\w\\.]+)$") { // allow alphanumeric, underscore and dot
  116. return nil, fmt.Errorf("validation error: `%s` can only contain letters, numbers, underscores and dots", name)
  117. }
  118. if !govalidator.IsCIDR(cidr) {
  119. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
  120. }
  121. if via != "" && !govalidator.IsIPv4(via) {
  122. return nil, fmt.Errorf("validation error: `%s` must be a network in the IPv4 form", via)
  123. }
  124. if nettype == UNDEFINEDNET {
  125. return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
  126. }
  127. _, ipnet, err := net.ParseCIDR(cidr)
  128. if err != nil {
  129. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  130. }
  131. // Overwrite via with the parsed IPv4 string.
  132. if nettype == ROUTE && via != "" {
  133. viaIP := net.ParseIP(via).To4()
  134. if err != nil {
  135. return nil, fmt.Errorf("can not parse IPv4 %s: %v", via, err)
  136. }
  137. via = viaIP.String()
  138. } else {
  139. via = ""
  140. }
  141. network := dbNetworkModel{
  142. Name: name,
  143. CIDR: ipnet.String(),
  144. Type: nettype,
  145. Users: []*dbUserModel{},
  146. Via: via,
  147. }
  148. db.Save(&network)
  149. if db.NewRecord(&network) {
  150. return nil, fmt.Errorf("can not create network in the db")
  151. }
  152. Emit()
  153. logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
  154. return &Network{dbNetworkModel: network}, nil
  155. }
  156. // Delete deletes a network definition in the system.
  157. func (n *Network) Delete() error {
  158. if !IsInitialized() {
  159. return fmt.Errorf("you first need to create server")
  160. }
  161. db.Unscoped().Delete(n.dbNetworkModel)
  162. Emit()
  163. logrus.Infof("network deleted: %s", n.Name)
  164. return nil
  165. }
  166. // Associate allows the given user access to this network.
  167. func (n *Network) Associate(username string) error {
  168. if !IsInitialized() {
  169. return fmt.Errorf("you first need to create server")
  170. }
  171. user, err := GetUser(username)
  172. if err != nil {
  173. return fmt.Errorf("user can not be fetched: %v", err)
  174. }
  175. var users []dbUserModel
  176. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  177. userAssoc.Find(&users)
  178. var found bool
  179. for _, u := range users {
  180. if u.ID == user.ID {
  181. found = true
  182. break
  183. }
  184. }
  185. if found {
  186. return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
  187. }
  188. userAssoc.Append(user.dbUserModel)
  189. if userAssoc.Error != nil {
  190. return fmt.Errorf("association failed: %v", userAssoc.Error)
  191. }
  192. Emit()
  193. logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
  194. return nil
  195. }
  196. // Dissociate breaks up the given users association to the said network.
  197. func (n *Network) Dissociate(username string) error {
  198. if !IsInitialized() {
  199. return fmt.Errorf("you first need to create server")
  200. }
  201. user, err := GetUser(username)
  202. if err != nil {
  203. return fmt.Errorf("user can not be fetched: %v", err)
  204. }
  205. var users []dbUserModel
  206. userAssoc := db.Model(&n.dbNetworkModel).Association("Users")
  207. userAssoc.Find(&users)
  208. var found bool
  209. for _, u := range users {
  210. if u.ID == user.ID {
  211. found = true
  212. break
  213. }
  214. }
  215. if !found {
  216. return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
  217. }
  218. userAssoc.Delete(user.dbUserModel)
  219. if userAssoc.Error != nil {
  220. return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
  221. }
  222. Emit()
  223. logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
  224. return nil
  225. }
  226. // GetName returns network's name.
  227. func (n *Network) GetName() string {
  228. return n.Name
  229. }
  230. // GetCIDR returns network's CIDR.
  231. func (n *Network) GetCIDR() string {
  232. return n.CIDR
  233. }
  234. // GetCreatedAt returns network's name.
  235. func (n *Network) GetCreatedAt() string {
  236. return n.CreatedAt.Format(time.UnixDate)
  237. }
  238. // GetType returns network's network type.
  239. func (n *Network) GetType() NetworkType {
  240. return NetworkType(n.Type)
  241. }
  242. // GetAssociatedUsers returns network's associated users.
  243. func (n *Network) GetAssociatedUsers() []*User {
  244. var users []*User
  245. for _, u := range n.Users {
  246. users = append(users, &User{dbUserModel: *u})
  247. }
  248. return users
  249. }
  250. // GetAssociatedUsernames returns network's associated user names.
  251. func (n *Network) GetAssociatedUsernames() []string {
  252. var usernames []string
  253. for _, user := range n.GetAssociatedUsers() {
  254. usernames = append(usernames, user.Username)
  255. }
  256. return usernames
  257. }
  258. // GetVia returns network' via.
  259. func (n *Network) GetVia() string {
  260. return n.Via
  261. }
  262. // interfaceOfIP returns a network interface that has the given IP.
  263. func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
  264. ifaces, err := net.Interfaces()
  265. if err != nil {
  266. return nil
  267. }
  268. for _, iface := range ifaces {
  269. addrs, err := iface.Addrs()
  270. if err != nil {
  271. logrus.Error(err)
  272. return nil
  273. }
  274. for _, addr := range addrs {
  275. switch addr := addr.(type) {
  276. case *net.IPAddr:
  277. if ip := addr.IP; ip != nil {
  278. if ipnet.Contains(ip) {
  279. return &iface
  280. }
  281. }
  282. case *net.IPNet:
  283. if ip := addr.IP; ip != nil {
  284. if ipnet.Contains(ip) {
  285. return &iface
  286. }
  287. }
  288. }
  289. }
  290. }
  291. return nil
  292. }
  293. func routedInterface(network string, flags net.Flags) *net.Interface {
  294. switch network {
  295. case "ip", "ip4", "ip6":
  296. default:
  297. return nil
  298. }
  299. ift, err := net.Interfaces()
  300. if err != nil {
  301. return nil
  302. }
  303. for _, ifi := range ift {
  304. if ifi.Flags&flags != flags {
  305. continue
  306. }
  307. if _, ok := hasRoutableIP(network, &ifi); !ok {
  308. continue
  309. }
  310. return &ifi
  311. }
  312. return nil
  313. }
  314. func getOutboundInterface() *net.Interface {
  315. conn, err := net.Dial("udp", "8.8.8.8:80")
  316. if err != nil {
  317. log.Fatal(err)
  318. }
  319. defer conn.Close()
  320. localAddr := conn.LocalAddr().(*net.UDPAddr)
  321. ipnet := net.IPNet{
  322. IP: localAddr.IP.To4(),
  323. Mask: localAddr.IP.To4().DefaultMask(),
  324. }
  325. return interfaceOfIP(&ipnet)
  326. }
  327. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  328. ifat, err := ifi.Addrs()
  329. if err != nil {
  330. return nil, false
  331. }
  332. for _, ifa := range ifat {
  333. switch ifa := ifa.(type) {
  334. case *net.IPAddr:
  335. if ip := routableIP(network, ifa.IP); ip != nil {
  336. return ip, true
  337. }
  338. case *net.IPNet:
  339. if ip := routableIP(network, ifa.IP); ip != nil {
  340. return ip, true
  341. }
  342. }
  343. }
  344. return nil, false
  345. }
  346. func vpnInterface() *net.Interface {
  347. server, err := GetServerInstance()
  348. if err != nil {
  349. logrus.Errorf("can't get server instance: %v", err)
  350. return nil
  351. }
  352. mask := net.IPMask(net.ParseIP(server.Mask))
  353. prefix := net.ParseIP(server.Net)
  354. netw := prefix.Mask(mask).To4()
  355. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  356. ipnet := net.IPNet{IP: netw, Mask: mask}
  357. return interfaceOfIP(&ipnet)
  358. }
  359. func routableIP(network string, ip net.IP) net.IP {
  360. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  361. return nil
  362. }
  363. switch network {
  364. case "ip4":
  365. if ip := ip.To4(); ip != nil {
  366. return ip
  367. }
  368. case "ip6":
  369. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  370. return nil
  371. }
  372. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  373. return ip
  374. }
  375. default:
  376. if ip := ip.To4(); ip != nil {
  377. return ip
  378. }
  379. if ip := ip.To16(); ip != nil {
  380. return ip
  381. }
  382. }
  383. return nil
  384. }
  385. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  386. func ensureNatEnabled() {
  387. // Nat enablerer
  388. go func() {
  389. for {
  390. err := enableNat()
  391. if err == nil {
  392. logrus.Debug("nat is enabled")
  393. return
  394. }
  395. logrus.Debugf("can not enable nat: %v", err)
  396. // TODO(cad): employ a exponential back-off approach here
  397. // instead of sleeping for the constant duration.
  398. time.Sleep(1 * time.Second)
  399. }
  400. }()
  401. }
  402. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  403. func enableNat() error {
  404. if Testing {
  405. return nil
  406. }
  407. // rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  408. // if rif == nil {
  409. // return fmt.Errorf("can not get routable network interface")
  410. // }
  411. rif := getOutboundInterface()
  412. if rif == nil {
  413. return fmt.Errorf("can not get default gw interface")
  414. }
  415. vpnIfc := vpnInterface()
  416. if vpnIfc == nil {
  417. return fmt.Errorf("can not get vpn network interface on the system")
  418. }
  419. // Enable ip forwarding.
  420. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  421. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  422. if err != nil {
  423. return fmt.Errorf("can not create new iptables object: %v", err)
  424. }
  425. server, err := GetServerInstance()
  426. if err != nil {
  427. logrus.Errorf("can't get server instance: %v", err)
  428. return nil
  429. }
  430. mask := net.IPMask(net.ParseIP(server.Mask))
  431. prefix := net.ParseIP(server.Net)
  432. netw := prefix.Mask(mask).To4()
  433. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  434. ipnet := net.IPNet{IP: netw, Mask: mask}
  435. // Append iptables nat rules.
  436. if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", ipnet.String(), "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
  437. return err
  438. }
  439. if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
  440. return err
  441. }
  442. if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
  443. return err
  444. }
  445. return nil
  446. }
  447. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  448. func HostID2IP(hostid uint32) net.IP {
  449. ip := make([]byte, 4)
  450. binary.BigEndian.PutUint32(ip, hostid)
  451. return net.IP(ip)
  452. }
  453. // IP2HostID converts an IP address to a host id (32-bit unsigned integer).
  454. func IP2HostID(ip net.IP) uint32 {
  455. hostid := binary.BigEndian.Uint32(ip)
  456. return hostid
  457. }
  458. // IncrementIP will return next ip address within the network.
  459. func IncrementIP(ip, mask string) (string, error) {
  460. if !govalidator.IsIPv4(ip) {
  461. return "", fmt.Errorf("'ip' is expected to be a valid IPv4 %s", ip)
  462. }
  463. if !govalidator.IsIPv4(ip) {
  464. return "", fmt.Errorf("'mask' is expected to be a valid IPv4 %s", mask)
  465. }
  466. ipAddr := net.ParseIP(ip).To4()
  467. netMask := net.IPMask(net.ParseIP(mask).To4())
  468. ipNet := net.IPNet{IP: ipAddr, Mask: netMask}
  469. for i := len(ipAddr) - 1; i >= 0; i-- {
  470. ipAddr[i]++
  471. if ip[i] != 0 {
  472. break
  473. }
  474. }
  475. if !ipNet.Contains(ipAddr) {
  476. return ip, errors.New("overflowed CIDR while incrementing IP")
  477. }
  478. return ipAddr.String(), nil
  479. }