net.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "log"
  6. "net"
  7. "time"
  8. "github.com/Sirupsen/logrus"
  9. "github.com/asaskevich/govalidator"
  10. "github.com/coreos/go-iptables/iptables"
  11. "github.com/jinzhu/gorm"
  12. )
  13. // NetworkType distinguishes different types of networks that is defined in the networks table.
  14. type NetworkType uint
  15. // NetworkTypes
  16. const (
  17. UNDEFINEDNET NetworkType = iota
  18. SERVERNET
  19. ROUTE
  20. )
  21. var networkTypes = [...]struct {
  22. Type NetworkType
  23. String string
  24. }{
  25. {UNDEFINEDNET, "UNDEFINEDNET"},
  26. {SERVERNET, "SERVERNET"},
  27. {ROUTE, "ROUTE"},
  28. }
  29. // NetworkTypeFromString returns string representation of the network type.
  30. func NetworkTypeFromString(typ string) NetworkType {
  31. for _, v := range networkTypes {
  32. if v.String == typ {
  33. return v.Type
  34. }
  35. }
  36. return UNDEFINEDNET
  37. }
  38. // GetAllNetworkTypes returns all network types defined in the system.
  39. func GetAllNetworkTypes() []NetworkType {
  40. var networkTypeList []NetworkType
  41. for _, v := range networkTypes {
  42. networkTypeList = append(networkTypeList, v.Type)
  43. }
  44. return networkTypeList
  45. }
  46. func (nt NetworkType) String() string {
  47. for _, v := range networkTypes {
  48. if v.Type == nt {
  49. return v.String
  50. }
  51. }
  52. return "UNDEFINEDNET"
  53. }
// DBNetwork is database model for external networks on the VPN server.
//
// Fields:
//   - gorm.Model: embedded gorm bookkeeping; provides the primary key used by
//     db.NewRecord checks and the CreatedAt timestamp read by GetCreatedAt.
//   - Name: network identifier; the unique_index tag makes it unique in the
//     table. GetNetwork/CreateNewNetwork only accept alphanumeric names.
//   - CIDR: network address as stored by CreateNewNetwork, i.e. normalized
//     through net.IPNet.String.
//   - Type: how the network is used (see NetworkType).
//   - Users: users granted access, linked through the network_users join
//     table (see Associate/Dissociate).
//
// NOTE(review): ServerID/Server are never set within this file; confirm
// whether they are populated elsewhere.
type DBNetwork struct {
	gorm.Model
	ServerID uint
	Server   DBServer
	Name     string `gorm:"unique_index"`
	CIDR     string
	Type     NetworkType
	Users    []*DBUser `gorm:"many2many:network_users;"`
}
  64. // GetNetwork returns a network specified by its name.
  65. func GetNetwork(name string) (*DBNetwork, error) {
  66. if !IsInitialized() {
  67. return nil, fmt.Errorf("you first need to create server")
  68. }
  69. // Validate user input.
  70. if govalidator.IsNull(name) {
  71. return nil, fmt.Errorf("validation error: %s can not be null", name)
  72. }
  73. if !govalidator.IsAlphanumeric(name) {
  74. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  75. }
  76. var network DBNetwork
  77. db.Preload("Users").Where(&DBNetwork{Name: name}).First(&network)
  78. if db.NewRecord(&network) {
  79. return nil, fmt.Errorf("network not found %s", name)
  80. }
  81. return &network, nil
  82. }
  83. // GetAllNetworks returns all networks defined in the system.
  84. func GetAllNetworks() []*DBNetwork {
  85. var networks []*DBNetwork
  86. db.Preload("Users").Find(&networks)
  87. return networks
  88. }
  89. // CreateNewNetwork creates a new network definition in the system.
  90. func CreateNewNetwork(name, cidr string, nettype NetworkType) (*DBNetwork, error) {
  91. if !IsInitialized() {
  92. return nil, fmt.Errorf("you first need to create server")
  93. }
  94. // Validate user input.
  95. if govalidator.IsNull(name) {
  96. return nil, fmt.Errorf("validation error: %s can not be null", name)
  97. }
  98. if !govalidator.IsAlphanumeric(name) {
  99. return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
  100. }
  101. if !govalidator.IsCIDR(cidr) {
  102. return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
  103. }
  104. if nettype == UNDEFINEDNET {
  105. return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
  106. }
  107. _, ipnet, err := net.ParseCIDR(cidr)
  108. if err != nil {
  109. return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
  110. }
  111. network := DBNetwork{
  112. Name: name,
  113. CIDR: ipnet.String(),
  114. Type: nettype,
  115. Users: []*DBUser{},
  116. }
  117. db.Save(&network)
  118. if db.NewRecord(&network) {
  119. return nil, fmt.Errorf("can not create network in the db")
  120. }
  121. Emit()
  122. logrus.Infof("network defined: %s (%s)", network.Name, network.CIDR)
  123. return &network, nil
  124. }
  125. // Delete deletes a network definition in the system.
  126. func (n *DBNetwork) Delete() error {
  127. if !IsInitialized() {
  128. return fmt.Errorf("you first need to create server")
  129. }
  130. db.Unscoped().Delete(n)
  131. Emit()
  132. logrus.Infof("network deleted: %s", n.Name)
  133. return nil
  134. }
  135. // Associate allows the given user access to this network.
  136. func (n *DBNetwork) Associate(username string) error {
  137. if !IsInitialized() {
  138. return fmt.Errorf("you first need to create server")
  139. }
  140. user, err := GetUser(username)
  141. if err != nil {
  142. return fmt.Errorf("user can not be fetched: %v", err)
  143. }
  144. var users []DBUser
  145. userAssoc := db.Model(&n).Association("Users")
  146. userAssoc.Find(&users)
  147. var found bool
  148. for _, u := range users {
  149. if u.ID == user.ID {
  150. found = true
  151. break
  152. }
  153. }
  154. if found {
  155. return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
  156. }
  157. userAssoc.Append(user)
  158. if userAssoc.Error != nil {
  159. return fmt.Errorf("association failed: %v", userAssoc.Error)
  160. }
  161. Emit()
  162. logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
  163. return nil
  164. }
  165. // Dissociate breaks up the given users association to the said network.
  166. func (n *DBNetwork) Dissociate(username string) error {
  167. if !IsInitialized() {
  168. return fmt.Errorf("you first need to create server")
  169. }
  170. user, err := GetUser(username)
  171. if err != nil {
  172. return fmt.Errorf("user can not be fetched: %v", err)
  173. }
  174. var users []DBUser
  175. userAssoc := db.Model(&n).Association("Users")
  176. userAssoc.Find(&users)
  177. var found bool
  178. for _, u := range users {
  179. if u.ID == user.ID {
  180. found = true
  181. break
  182. }
  183. }
  184. if !found {
  185. return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
  186. }
  187. userAssoc.Delete(user)
  188. if userAssoc.Error != nil {
  189. return fmt.Errorf("disassociation failed: %v", userAssoc.Error)
  190. }
  191. Emit()
  192. logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
  193. return nil
  194. }
  195. // GetName returns network's name.
  196. func (n *DBNetwork) GetName() string {
  197. return n.Name
  198. }
  199. // GetCIDR returns network's CIDR.
  200. func (n *DBNetwork) GetCIDR() string {
  201. return n.CIDR
  202. }
  203. // GetCreatedAt returns network's name.
  204. func (n *DBNetwork) GetCreatedAt() string {
  205. return n.CreatedAt.Format(time.UnixDate)
  206. }
  207. // GetType returns network's network type.
  208. func (n *DBNetwork) GetType() NetworkType {
  209. return NetworkType(n.Type)
  210. }
  211. // GetAssociatedUsers returns network's associated users.
  212. func (n *DBNetwork) GetAssociatedUsers() []*DBUser {
  213. return n.Users
  214. }
  215. // GetAssociatedUsernames returns network's associated user names.
  216. func (n *DBNetwork) GetAssociatedUsernames() []string {
  217. var usernames []string
  218. for _, user := range n.GetAssociatedUsers() {
  219. usernames = append(usernames, user.Username)
  220. }
  221. return usernames
  222. }
  223. // interfaceOfIP returns a network interface that has the given IP.
  224. func interfaceOfIP(ipnet *net.IPNet) *net.Interface {
  225. ifaces, err := net.Interfaces()
  226. if err != nil {
  227. return nil
  228. }
  229. for _, iface := range ifaces {
  230. addrs, err := iface.Addrs()
  231. if err != nil {
  232. logrus.Error(err)
  233. return nil
  234. }
  235. for _, addr := range addrs {
  236. switch addr := addr.(type) {
  237. case *net.IPAddr:
  238. if ip := addr.IP; ip != nil {
  239. if ipnet.Contains(ip) {
  240. return &iface
  241. }
  242. }
  243. case *net.IPNet:
  244. if ip := addr.IP; ip != nil {
  245. if ipnet.Contains(ip) {
  246. return &iface
  247. }
  248. }
  249. }
  250. }
  251. }
  252. return nil
  253. }
  254. func routedInterface(network string, flags net.Flags) *net.Interface {
  255. switch network {
  256. case "ip", "ip4", "ip6":
  257. default:
  258. return nil
  259. }
  260. ift, err := net.Interfaces()
  261. if err != nil {
  262. return nil
  263. }
  264. for _, ifi := range ift {
  265. if ifi.Flags&flags != flags {
  266. continue
  267. }
  268. if _, ok := hasRoutableIP(network, &ifi); !ok {
  269. continue
  270. }
  271. return &ifi
  272. }
  273. return nil
  274. }
  275. func getOutboundInterface() *net.Interface {
  276. conn, err := net.Dial("udp", "8.8.8.8:80")
  277. if err != nil {
  278. log.Fatal(err)
  279. }
  280. defer conn.Close()
  281. localAddr := conn.LocalAddr().(*net.UDPAddr)
  282. ipnet := net.IPNet{
  283. IP: localAddr.IP.To4(),
  284. Mask: localAddr.IP.To4().DefaultMask(),
  285. }
  286. return interfaceOfIP(&ipnet)
  287. }
  288. func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
  289. ifat, err := ifi.Addrs()
  290. if err != nil {
  291. return nil, false
  292. }
  293. for _, ifa := range ifat {
  294. switch ifa := ifa.(type) {
  295. case *net.IPAddr:
  296. if ip := routableIP(network, ifa.IP); ip != nil {
  297. return ip, true
  298. }
  299. case *net.IPNet:
  300. if ip := routableIP(network, ifa.IP); ip != nil {
  301. return ip, true
  302. }
  303. }
  304. }
  305. return nil, false
  306. }
  307. func vpnInterface() *net.Interface {
  308. server, err := GetServerInstance()
  309. if err != nil {
  310. logrus.Errorf("can't get server instance: %v", err)
  311. return nil
  312. }
  313. mask := net.IPMask(net.ParseIP(server.Mask))
  314. prefix := net.ParseIP(server.Net)
  315. netw := prefix.Mask(mask).To4()
  316. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  317. ipnet := net.IPNet{IP: netw, Mask: mask}
  318. return interfaceOfIP(&ipnet)
  319. }
  320. func routableIP(network string, ip net.IP) net.IP {
  321. if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() {
  322. return nil
  323. }
  324. switch network {
  325. case "ip4":
  326. if ip := ip.To4(); ip != nil {
  327. return ip
  328. }
  329. case "ip6":
  330. if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation
  331. return nil
  332. }
  333. if ip := ip.To16(); ip != nil && ip.To4() == nil {
  334. return ip
  335. }
  336. default:
  337. if ip := ip.To4(); ip != nil {
  338. return ip
  339. }
  340. if ip := ip.To16(); ip != nil {
  341. return ip
  342. }
  343. }
  344. return nil
  345. }
  346. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  347. func ensureNatEnabled() {
  348. // Nat enablerer
  349. go func() {
  350. for {
  351. err := enableNat()
  352. if err == nil {
  353. logrus.Debug("nat is enabled")
  354. return
  355. }
  356. logrus.Debugf("can not enable nat: %v", err)
  357. // TODO(cad): employ a exponential back-off approach here
  358. // instead of sleeping for the constant duration.
  359. time.Sleep(1 * time.Second)
  360. }
  361. }()
  362. }
  363. // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
  364. func enableNat() error {
  365. if Testing {
  366. return nil
  367. }
  368. // rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
  369. // if rif == nil {
  370. // return fmt.Errorf("can not get routable network interface")
  371. // }
  372. rif := getOutboundInterface()
  373. if rif == nil {
  374. return fmt.Errorf("can not get default gw interface")
  375. }
  376. vpnIfc := vpnInterface()
  377. if vpnIfc == nil {
  378. return fmt.Errorf("can not get vpn network interface on the system")
  379. }
  380. // Enable ip forwarding.
  381. emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
  382. ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  383. if err != nil {
  384. return fmt.Errorf("can not create new iptables object: %v", err)
  385. }
  386. server, err := GetServerInstance()
  387. if err != nil {
  388. logrus.Errorf("can't get server instance: %v", err)
  389. return nil
  390. }
  391. mask := net.IPMask(net.ParseIP(server.Mask))
  392. prefix := net.ParseIP(server.Net)
  393. netw := prefix.Mask(mask).To4()
  394. netw[3] = byte(1) // Server is always gets xxx.xxx.xxx.1
  395. ipnet := net.IPNet{IP: netw, Mask: mask}
  396. // Append iptables nat rules.
  397. if err := ipt.AppendUnique("nat", "POSTROUTING", "-s", ipnet.String(), "-o", rif.Name, "-j", "MASQUERADE"); err != nil {
  398. return err
  399. }
  400. if err := ipt.AppendUnique("filter", "FORWARD", "-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
  401. return err
  402. }
  403. if err := ipt.AppendUnique("filter", "FORWARD", "-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"); err != nil {
  404. return err
  405. }
  406. return nil
  407. }
  408. // HostID2IP converts a host id (32-bit unsigned integer) to an IP address.
  409. func HostID2IP(hostid uint32) net.IP {
  410. ip := make([]byte, 4)
  411. binary.BigEndian.PutUint32(ip, hostid)
  412. return net.IP(ip)
  413. }
  414. //IP2HostID converts an IP address to a host id (32-bit unsigned integer).
  415. func IP2HostID(ip net.IP) uint32 {
  416. hostid := binary.BigEndian.Uint32(ip)
  417. return hostid
  418. }