1
0

net.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "time"
  9. "github.com/Sirupsen/logrus"
  10. "github.com/asaskevich/govalidator"
  11. "github.com/coreos/go-iptables/iptables"
  12. "github.com/jinzhu/gorm"
  13. )
  // NetworkType identifies what kind of network a stored network definition
  // describes (see dbNetworkModel.Type).
  type NetworkType uint

  // Known network types. The numbers are written out instead of being derived
  // from iota because they are stored in dbNetworkModel.Type; renumbering them
  // would change the meaning of rows that already exist.
  const (
      // UNDEFINEDNET is the zero value and stands for an unknown network type.
      UNDEFINEDNET NetworkType = 0
      // SERVERNET is a network behind the VPN server.
      SERVERNET NetworkType = 1
      // ROUTE is a network to be pushed to clients as a route.
      ROUTE NetworkType = 2
  )
  22. var networkTypes = [...]struct {
  23. Type NetworkType
  24. String string
  25. Description string
  26. }{
  27. {UNDEFINEDNET, "UNDEFINEDNET", "unknown network type"},
  28. {SERVERNET, "SERVERNET", "network behind vpn server"},
  29. {ROUTE, "ROUTE", "network to be pushed as route"},
  30. }
  31. // NetworkTypeFromString returns string representation of the network type.
  32. func NetworkTypeFromString(typ string) NetworkType {
  33. for _, v := range networkTypes {
  34. if v.String == typ {
  35. return v.Type
  36. }
  37. }
  38. return UNDEFINEDNET
  39. }
  40. // GetAllNetworkTypes returns all network types defined in the system.
  41. func GetAllNetworkTypes() []NetworkType {
  42. var networkTypeList []NetworkType
  43. for _, v := range networkTypes {
  44. networkTypeList = append(networkTypeList, v.Type)
  45. }
  46. return networkTypeList
  47. }
  48. func (nt NetworkType) String() string {
  49. for _, v := range networkTypes {
  50. if v.Type == nt {
  51. return v.String
  52. }
  53. }
  54. return "UNDEFINEDNET"
  55. }
  56. // Description gives description about the network type.
  57. func (nt NetworkType) Description() string {
  58. for _, v := range networkTypes {
  59. if v.Type == nt {
  60. return v.Description
  61. }
  62. }
  63. return "UNDEFINEDNET"
  64. }
  65. // IsNetworkType returns if the s is a valid network type or not.
  66. func IsNetworkType(s string) bool {
  67. for _, v := range networkTypes {
  68. if s == v.String {
  69. return true
  70. }
  71. }
  72. return false
  73. }
  74. // dbNetworkModel is database model for external networks on the VPN server.
  75. type dbNetworkModel struct {
  76. gorm.Model
  77. ServerID uint
  78. Server dbServerModel
  79. Name string `gorm:"unique_index"`
  80. CIDR string
  81. Type NetworkType
  82. Via string
  83. Users []*dbUserModel `gorm:"many2many:network_users;"`
  84. }
  85. // Network represents a VPN related network.
  86. type Network struct {
  87. dbNetworkModel
  88. }
  // GetNetwork looks up the network definition called name, with its
  // associated users preloaded.
  //
  // It fails when the server has not been created yet, when name is empty, or
  // when no network with that name exists.
  func GetNetwork(name string) (*Network, error) {
      if !IsInitialized() {
          return nil, fmt.Errorf("you first need to create server")
      }
      if govalidator.IsNull(name) {
          return nil, fmt.Errorf("validation error: %s can not be null", name)
      }

      var found dbNetworkModel
      query := db.Preload("Users").Where(&dbNetworkModel{Name: name})
      query.First(&found)
      // NewRecord is the not-found signal here: the primary key stays blank
      // when First loads nothing.
      if db.NewRecord(&found) {
          return nil, fmt.Errorf("network not found %s", name)
      }
      return &Network{dbNetworkModel: found}, nil
  }
  105. // GetAllNetworks returns all networks defined in the system.
  106. func GetAllNetworks() []*Network {
  107. var networks []*Network
  108. var dbNetworks []*dbNetworkModel
  109. db.Preload("Users").Find(&dbNetworks)
  110. for _, n := range dbNetworks {
  111. networks = append(networks, &Network{dbNetworkModel: *n})
  112. }
  113. return networks
  114. }
  // CreateNewNetwork validates and stores a new network definition, then calls
  // EmitWithRestart.
  //
  // Parameters:
  //   - name: may only contain letters, digits, underscores and dots.
  //   - cidr: must be in CIDR form. It is stored normalized to the network
  //     address plus prefix length.
  //   - nettype: must not be UNDEFINEDNET.
  //   - via: optional gateway, and must be IPv4 when given. It is kept only
  //     for ROUTE networks and cleared for every other type.
  //
  // It returns the stored network, or an error when the server is not
  // initialized, the input is invalid, or the database write fails.
  func CreateNewNetwork(name, cidr string, nettype NetworkType, via string) (*Network, error) {
      if !IsInitialized() {
          return nil, fmt.Errorf("you first need to create server")
      }
      // Validate user input.
      if govalidator.IsNull(name) {
          return nil, fmt.Errorf("validation error: %s can not be null", name)
      }
      if !govalidator.Matches(name, "^([\\w\\.]+)$") { // allow alphanumeric, underscore and dot
          return nil, fmt.Errorf("validation error: `%s` can only contain letters, numbers, underscores and dots", name)
      }
      if !govalidator.IsCIDR(cidr) {
          return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
      }
      if via != "" && !govalidator.IsIPv4(via) {
          return nil, fmt.Errorf("validation error: `%s` must be a network in the IPv4 form", via)
      }
      if nettype == UNDEFINEDNET {
          return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
      }

      _, ipnet, err := net.ParseCIDR(cidr)
      if err != nil {
          return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
      }

      // Only ROUTE networks carry a gateway. Normalize it to its dotted IPv4
      // form, and drop it for every other type.
      if nettype == ROUTE && via != "" {
          viaIP := net.ParseIP(via).To4()
          if viaIP == nil {
              // Check the parse result itself. The previous code tested a stale
              // err, which let "<nil>" be stored as the gateway.
              return nil, fmt.Errorf("can not parse IPv4 %s", via)
          }
          via = viaIP.String()
      } else {
          via = ""
      }

      network := dbNetworkModel{
          Name:  name,
          CIDR:  ipnet.String(),
          Type:  nettype,
          Users: []*dbUserModel{},
          Via:   via,
      }
      if err := db.Save(&network).Error; err != nil {
          return nil, fmt.Errorf("can not create network in the db: %v", err)
      }
      if db.NewRecord(&network) {
          return nil, fmt.Errorf("can not create network in the db")
      }
      EmitWithRestart()
      logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
      return &Network{dbNetworkModel: network}, nil
  }
  // Delete permanently removes this network definition from the database
  // (Unscoped, so gorm's soft-delete column is bypassed), then calls
  // EmitWithRestart.
  //
  // It fails when the server is not initialized, when the network has no
  // primary key, or when the database delete fails.
  func (n *Network) Delete() error {
      if !IsInitialized() {
          return fmt.Errorf("you first need to create server")
      }
      // gorm v1 builds the delete condition from the primary key. With a zero
      // ID nothing restricts the statement, so refuse rather than wipe the table.
      if n.ID == 0 {
          return fmt.Errorf("can not delete network %s: it has no primary key", n.Name)
      }
      if err := db.Unscoped().Delete(n.dbNetworkModel).Error; err != nil {
          return fmt.Errorf("can not delete network %s: %v", n.Name, err)
      }
      EmitWithRestart()
      logrus.Infof("network deleted: %s", n.Name)
      return nil
  }
  // Associate gives the user called username access to this network by
  // appending them to the network's Users association, then calls
  // EmitWithRestart.
  //
  // It fails when the server is not initialized, the user cannot be fetched,
  // the user is already associated, or the association update reports an error.
  func (n *Network) Associate(username string) error {
      if !IsInitialized() {
          return fmt.Errorf("you first need to create server")
      }
      user, err := GetUser(username)
      if err != nil {
          return fmt.Errorf("user can not be fetched: %v", err)
      }

      assoc := db.Model(&n.dbNetworkModel).Association("Users")
      var current []dbUserModel
      assoc.Find(&current)
      for i := range current {
          if current[i].ID == user.ID {
              return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
          }
      }

      assoc.Append(user.dbUserModel)
      if assoc.Error != nil {
          return fmt.Errorf("association failed: %v", assoc.Error)
      }
      EmitWithRestart()
      logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
      return nil
  }
  // Dissociate revokes the access of the user called username to this network
  // by removing them from the network's Users association, then calls
  // EmitWithRestart.
  //
  // It fails when the server is not initialized, the user cannot be fetched,
  // the user is not associated, or the association update reports an error.
  func (n *Network) Dissociate(username string) error {
      if !IsInitialized() {
          return fmt.Errorf("you first need to create server")
      }
      user, err := GetUser(username)
      if err != nil {
          return fmt.Errorf("user can not be fetched: %v", err)
      }

      assoc := db.Model(&n.dbNetworkModel).Association("Users")
      var current []dbUserModel
      assoc.Find(&current)
      linked := false
      for i := 0; i < len(current) && !linked; i++ {
          linked = current[i].ID == user.ID
      }
      if !linked {
          return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
      }

      assoc.Delete(user.dbUserModel)
      if assoc.Error != nil {
          return fmt.Errorf("disassociation failed: %v", assoc.Error)
      }
      EmitWithRestart()
      logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
      return nil
  }
  235. // GetName returns network's name.
  236. func (n *Network) GetName() string {
  237. return n.Name
  238. }
  239. // GetCIDR returns network's CIDR.
  240. func (n *Network) GetCIDR() string {
  241. return n.CIDR
  242. }
  // GetCreatedAt returns the creation time of the network record, formatted
  // with time.UnixDate (for example "Mon Jan  2 15:04:05 MST 2006").
  func (n *Network) GetCreatedAt() string {
      return n.CreatedAt.Format(time.UnixDate)
  }
  247. // GetType returns network's network type.
  248. func (n *Network) GetType() NetworkType {
  249. return NetworkType(n.Type)
  250. }
  251. // GetAssociatedUsers returns network's associated users.
  252. func (n *Network) GetAssociatedUsers() []*User {
  253. var users []*User
  254. for _, u := range n.Users {
  255. users = append(users, &User{dbUserModel: *u})
  256. }
  257. return users
  258. }
  259. // GetAssociatedUsernames returns network's associated user names.
  260. func (n *Network) GetAssociatedUsernames() []string {
  261. var usernames []string
  262. for _, user := range n.GetAssociatedUsers() {
  263. usernames = append(usernames, user.Username)
  264. }
  265. return usernames
  266. }
  267. // GetVia returns network' via.
  268. func (n *Network) GetVia() string {
  269. return n.Via
  270. }
  // interfaceOfIP returns the first local network interface that has an
  // address inside ipnet, or nil if there is none or the interface list
  // cannot be read.
  func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
      ifaces, err := net.Interfaces()
      if err != nil {
          logrus.Errorf("can not list network interfaces: %v", err)
          return nil
      }
      for i := range ifaces {
          addrs, err := ifaces[i].Addrs()
          if err != nil {
              // Skip this interface instead of giving up: another interface
              // may still hold a matching address.
              logrus.Errorf("can not get addresses of interface %s: %v", ifaces[i].Name, err)
              continue
          }
          for _, addr := range addrs {
              var ip net.IP
              switch a := addr.(type) {
              case *net.IPAddr:
                  ip = a.IP
              case *net.IPNet:
                  ip = a.IP
              }
              if ip != nil && ipnet.Contains(ip) {
                  return &ifaces[i]
              }
          }
      }
      return nil
  }
  // routedInterface returns the first interface whose flags include every bit
  // in flags and that has a routable address for network ("ip", "ip4" or
  // "ip6"). It returns nil for any other network value, when the interface
  // list cannot be read, or when nothing matches.
  func routedInterface(network string, flags net.Flags) *net.Interface {
      if network != "ip" && network != "ip4" && network != "ip6" {
          return nil
      }
      ifaces, err := net.Interfaces()
      if err != nil {
          return nil
      }
      for i := range ifaces {
          candidate := &ifaces[i]
          if candidate.Flags&flags == flags {
              if _, ok := hasRoutableIP(network, candidate); ok {
                  return candidate
              }
          }
      }
      return nil
  }
  // getOutboundInterface returns the local interface the kernel would use to
  // reach a public host (probed with 8.8.8.8). It returns nil when that cannot
  // be determined.
  func getOutboundInterface() *net.Interface {
      // Connecting a UDP socket only selects a route and a local address; no
      // packet is sent.
      conn, err := net.Dial("udp", "8.8.8.8:80")
      if err != nil {
          // Callers such as enableNat treat nil as a failure they can report
          // and retry, so never terminate the process here.
          log.Printf("can not determine outbound interface: %v", err)
          return nil
      }
      defer conn.Close()

      localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
      if !ok {
          return nil
      }
      ip := localAddr.IP.To4()
      if ip == nil {
          return nil
      }
      // Match the exact source address (/32). A classful default mask could
      // also match other interfaces in the same range, such as a VPN tun
      // device on 10.x.
      return interfaceOfIP(&net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)})
  }
  // hasRoutableIP scans the addresses of ifi and returns the first one that
  // routableIP accepts for network, together with true. It returns nil, false
  // when none qualifies or the addresses cannot be read.
  func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
      addrs, err := ifi.Addrs()
      if err != nil {
          return nil, false
      }
      for _, addr := range addrs {
          var candidate net.IP
          switch a := addr.(type) {
          case *net.IPAddr:
              candidate = a.IP
          case *net.IPNet:
              candidate = a.IP
          default:
              continue
          }
          if ip := routableIP(network, candidate); ip != nil {
              return ip, true
          }
      }
      return nil, false
  }
  // vpnInterface returns the local interface that holds an address inside the
  // VPN server network (server.Net / server.Mask). It returns nil when the
  // server instance or its network settings cannot be read, or when no
  // interface matches.
  func vpnInterface() *net.Interface {
      server, err := GetServerInstance()
      if err != nil {
          logrus.Errorf("can't get server instance: %v", err)
          return nil
      }
      // ParseIP returns the 16-byte form even for IPv4, so convert the mask to
      // 4 bytes to get a canonical IPv4 mask. The checks also prevent indexing
      // a nil slice when the settings are unparsable.
      mask := net.IPMask(net.ParseIP(server.Mask).To4())
      prefix := net.ParseIP(server.Net).To4()
      if prefix == nil || len(mask) != net.IPv4len {
          logrus.Errorf("can not parse vpn network %s/%s", server.Net, server.Mask)
          return nil
      }
      // The server itself uses a host address in this network. Contains masks
      // every candidate, so the network address is enough to find its
      // interface. Forcing the last octet to 1 would pick the wrong network
      // for prefixes longer than /24.
      return interfaceOfIP(&net.IPNet{IP: prefix.Mask(mask), Mask: mask})
  }
  // routableIP returns ip in the form expected for network, or nil.
  //
  // network is "ip4", "ip6", or anything else for either family. The result is
  // nil when ip is not a loopback, link-local unicast or global unicast
  // address, or when it does not belong to the requested family. IPv4 results
  // are always in 4-byte form.
  func routableIP(network string, ip net.IP) net.IP {
      if !(ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsGlobalUnicast()) {
          return nil
      }
      v4 := ip.To4()
      switch network {
      case "ip4":
          return v4
      case "ip6":
          // The addressing scope of the loopback address depends on each
          // implementation, so loopback is never offered for ip6.
          if ip.IsLoopback() || v4 != nil {
              return nil
          }
          return ip.To16()
      default:
          if v4 != nil {
              return v4
          }
          return ip.To16()
      }
  }
  // ensureNatEnabled starts a background goroutine that calls enableNat once
  // per second until it succeeds, logging each failure at debug level. The
  // goroutine cannot be stopped from outside; it only ends on success.
  func ensureNatEnabled() {
      go func() {
          // TODO(cad): employ a exponential back-off approach here
          // instead of sleeping for the constant duration.
          for err := enableNat(); err != nil; err = enableNat() {
              logrus.Debugf("can not enable nat: %v", err)
              time.Sleep(1 * time.Second)
          }
          logrus.Debug("nat is enabled")
      }()
  }
  // enableNat sets up forwarding and NAT for VPN client traffic. It returns nil
  // immediately when Testing is set.
  //
  // Steps:
  //  1. Locate the outbound and VPN interfaces.
  //  2. Turn on IPv4 forwarding.
  //  3. Add a MASQUERADE rule for the VPN subnet, plus FORWARD rules between
  //     the two interfaces.
  //
  // Rules are added with AppendUnique so repeated calls do not duplicate
  // them. Any error is returned so that ensureNatEnabled keeps retrying.
  func enableNat() error {
      if Testing {
          return nil
      }
      rif := getOutboundInterface()
      if rif == nil {
          return fmt.Errorf("can not get default gw interface")
      }
      vpnIfc := vpnInterface()
      if vpnIfc == nil {
          return fmt.Errorf("can not get vpn network interface on the system")
      }

      // Enable ip forwarding.
      // NOTE(review): any result of emitToFile is ignored; confirm whether a
      // failure here should abort.
      emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)

      ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
      if err != nil {
          return fmt.Errorf("can not create new iptables object: %v", err)
      }
      server, err := GetServerInstance()
      if err != nil {
          // Returning nil here would stop the retry loop even though no rule
          // has been installed.
          return fmt.Errorf("can't get server instance: %v", err)
      }

      // ParseIP returns the 16-byte form even for IPv4. A 16-byte mask would
      // make IPNet.String emit a hex mask that iptables rejects.
      mask := net.IPMask(net.ParseIP(server.Mask).To4())
      prefix := net.ParseIP(server.Net).To4()
      if prefix == nil || len(mask) != net.IPv4len {
          return fmt.Errorf("can not parse vpn network %s/%s", server.Net, server.Mask)
      }
      vpnNet := (&net.IPNet{IP: prefix.Mask(mask), Mask: mask}).String()

      // Append iptables nat rules.
      if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", vpnNet, "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
          return fmt.Errorf("can not append nat masquerade rule: %v", err)
      }
      if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
          return fmt.Errorf("can not append inbound forward rule: %v", err)
      }
      if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
          return fmt.Errorf("can not append outbound forward rule: %v", err)
      }
      return nil
  }
  456. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  457. func HostID2IP(hostid uint32) net.IP {
  458. ip := make([]byte, 4)
  459. binary.BigEndian.PutUint32(ip, hostid)
  460. return net.IP(ip).To4()
  461. }
  // IP2HostID converts an IPv4 address into its 32-bit host id, with the first
  // octet in the most significant byte. It is the inverse of HostID2IP.
  //
  // Inputs without an IPv4 form (nil, or IPv6 other than IPv4-mapped) return 0
  // instead of panicking. Note that 0 is also the id of 0.0.0.0.
  func IP2HostID(ip net.IP) uint32 {
      v4 := ip.To4()
      if v4 == nil {
          return 0
      }
      return binary.BigEndian.Uint32(v4)
  }
  // IncrementIP returns the address that follows ip, provided it still lies
  // inside the network defined by ip and mask.
  //
  // Both ip and mask must be dotted-quad IPv4 strings, otherwise ("", error)
  // is returned. When the next address leaves the network, or would wrap past
  // 255.255.255.255, the original ip is returned with an overflow error.
  func IncrementIP(ip, mask string) (string, error) {
      ipAddr := net.ParseIP(ip).To4()
      if ipAddr == nil {
          return "", fmt.Errorf("'ip' is expected to be a valid IPv4 %s", ip)
      }
      netMask := net.IPMask(net.ParseIP(mask).To4())
      if len(netMask) != net.IPv4len {
          return "", fmt.Errorf("'mask' is expected to be a valid IPv4 %s", mask)
      }

      // Do the arithmetic on a 32-bit integer so carries propagate across
      // octets, and build the result in a fresh slice so the network below
      // still describes the original address.
      nextID := binary.BigEndian.Uint32(ipAddr) + 1
      if nextID == 0 {
          return ip, errors.New("overflowed CIDR while incrementing IP")
      }
      next := make(net.IP, net.IPv4len)
      binary.BigEndian.PutUint32(next, nextID)

      network := net.IPNet{IP: ipAddr, Mask: netMask}
      if !network.Contains(next) {
          return ip, errors.New("overflowed CIDR while incrementing IP")
      }
      return next.String(), nil
  }