net.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package ovpm
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. "time"
  7. "github.com/Sirupsen/logrus"
  8. "github.com/asaskevich/govalidator"
  9. "github.com/coreos/go-iptables/iptables"
  10. "github.com/jinzhu/gorm"
  11. )
  // NetworkType distinguishes the different types of networks that are defined
// in the networks table.
type NetworkType uint

// NetworkTypes.
//
// The numeric values are stored in the Type column of DBNetwork, so they are
// spelled out explicitly and must never be renumbered.
const (
	// UNDEFINEDNET is the zero value and marks an unknown or unset type.
	UNDEFINEDNET NetworkType = 0
	// SERVERNET is the SERVERNET network type.
	SERVERNET NetworkType = 1
	// ROUTE is the ROUTE network type.
	ROUTE NetworkType = 2
)
  20. var networkTypes = [...]struct {
  21. Type NetworkType
  22. String string
  23. }{
  24. {UNDEFINEDNET, "UNDEFINEDNET"},
  25. {SERVERNET, "SERVERNET"},
  26. {ROUTE, "ROUTE"},
  27. }
  28. // NetworkTypeFromString returns string representation of the network type.
  29. func NetworkTypeFromString(typ string) NetworkType {
  30. for _, v := range networkTypes {
  31. if v.String == typ {
  32. return v.Type
  33. }
  34. }
  35. return UNDEFINEDNET
  36. }
  37. // GetAllNetworkTypes returns all network types defined in the system.
  38. func GetAllNetworkTypes() []NetworkType {
  39. var networkTypeList []NetworkType
  40. for _, v := range networkTypes {
  41. networkTypeList = append(networkTypeList, v.Type)
  42. }
  43. return networkTypeList
  44. }
  45. func (nt NetworkType) String() string {
  46. for _, v := range networkTypes {
  47. if v.Type == nt {
  48. return v.String
  49. }
  50. }
  51. return "UNDEFINEDNET"
  52. }
  53. // DBNetwork is database model for external networks on the VPN server.
  54. type DBNetwork struct {
  55. gorm.Model
  56. ServerID uint
  57. Server DBServer
  58. Name string `gorm:"unique_index"`
  59. CIDR string
  60. Type NetworkType
  61. Users []*DBUser `gorm:"many2many:network_users;"`
  62. }
  // GetNetwork returns the network specified by its name, with its associated
// users preloaded.
//
// It returns an error when the server has not been created yet, when name is
// empty or not alphanumeric, when the database query fails, or when no network
// with that name exists.
func GetNetwork(name string) (*DBNetwork, error) {
	if !IsInitialized() {
		return nil, fmt.Errorf("you first need to create server")
	}
	// Validate user input.
	if govalidator.IsNull(name) {
		// name is empty here, so interpolating it would yield a meaningless message.
		return nil, fmt.Errorf("validation error: network name can not be null")
	}
	if !govalidator.IsAlphanumeric(name) {
		return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
	}
	var network DBNetwork
	q := db.Preload("Users").Where(&DBNetwork{Name: name}).First(&network)
	// Surface real database failures instead of reporting them as "not found".
	if q.Error != nil && !q.RecordNotFound() {
		return nil, fmt.Errorf("can not fetch network %s: %v", name, q.Error)
	}
	// A blank primary key means no row was loaded into network.
	if db.NewRecord(&network) {
		return nil, fmt.Errorf("network not found %s", name)
	}
	return &network, nil
}
  // GetAllNetworks returns all networks defined in the system, with their
// associated users preloaded. It returns an error if the database query fails.
func GetAllNetworks() ([]*DBNetwork, error) {
	var networks []*DBNetwork
	if err := db.Preload("Users").Find(&networks).Error; err != nil {
		return nil, fmt.Errorf("can not fetch networks from the db: %v", err)
	}
	return networks, nil
}
  // CreateNewNetwork creates a new network definition in the system.
//
// name must be a non-empty alphanumeric string and cidr a network in CIDR form.
// The CIDR is stored normalized to its network address (e.g. "10.0.0.5/24"
// becomes "10.0.0.0/24"). nettype must be one of the defined network types
// other than UNDEFINEDNET.
//
// It returns an error when the server has not been created yet, when any input
// fails validation, or when the network can not be saved (for example because
// the name is already taken).
func CreateNewNetwork(name, cidr string, nettype NetworkType) (*DBNetwork, error) {
	if !IsInitialized() {
		return nil, fmt.Errorf("you first need to create server")
	}
	// Validate user input.
	if govalidator.IsNull(name) {
		// name is empty here, so interpolating it would yield a meaningless message.
		return nil, fmt.Errorf("validation error: network name can not be null")
	}
	if !govalidator.IsAlphanumeric(name) {
		return nil, fmt.Errorf("validation error: `%s` can only contain letters and numbers", name)
	}
	if !govalidator.IsCIDR(cidr) {
		return nil, fmt.Errorf("validation error: `%s` must be a network in the CIDR form", cidr)
	}
	// String reports "UNDEFINEDNET" both for UNDEFINEDNET itself and for any
	// value missing from the networkTypes table, so this rejects both.
	if nettype.String() == UNDEFINEDNET.String() {
		return nil, fmt.Errorf("validation error: `%s` must be a valid network type", nettype)
	}
	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return nil, fmt.Errorf("can not parse CIDR %s: %v", cidr, err)
	}
	network := DBNetwork{
		Name:  name,
		CIDR:  ipnet.String(),
		Type:  nettype,
		Users: []*DBUser{},
	}
	if err := db.Save(&network).Error; err != nil {
		return nil, fmt.Errorf("can not create network in the db: %v", err)
	}
	if db.NewRecord(&network) {
		return nil, fmt.Errorf("can not create network in the db")
	}
	return &network, nil
}
  // Delete permanently removes this network definition from the database
// (Unscoped, so the soft-delete column is bypassed).
//
// It returns an error when the server has not been created yet or when the
// database delete fails; the deletion is logged only on success.
func (n *DBNetwork) Delete() error {
	if !IsInitialized() {
		return fmt.Errorf("you first need to create server")
	}
	if err := db.Unscoped().Delete(n).Error; err != nil {
		return fmt.Errorf("can not delete network %s: %v", n.Name, err)
	}
	logrus.Infof("network deleted: %s", n.Name)
	return nil
}
  // Associate allows the user identified by username access to this network.
//
// It returns an error when the server is not initialized, when the user can
// not be fetched, when the user is already associated with this network, or
// when the association reports an error.
func (n *DBNetwork) Associate(username string) error {
	if !IsInitialized() {
		return fmt.Errorf("you first need to create server")
	}
	user, err := GetUser(username)
	if err != nil {
		return fmt.Errorf("user can not be fetched: %v", err)
	}

	assoc := db.Model(&n).Association("Users")
	var members []DBUser
	assoc.Find(&members)

	// Look for the user among the current members before appending.
	isMember := false
	for i := range members {
		if members[i].ID == user.ID {
			isMember = true
			break
		}
	}
	if isMember {
		return fmt.Errorf("user %s is already associated with the network %s", user.Username, n.Name)
	}

	assoc.Append(user)
	if assoc.Error != nil {
		return fmt.Errorf("association failed: %v", assoc.Error)
	}
	logrus.Infof("user '%s' is associated with the network '%s'", user.GetUsername(), n.Name)
	return nil
}
  // Dissociate breaks up the association between the user identified by
// username and this network.
//
// It returns an error when the server is not initialized, when the user can
// not be fetched, when the user is not associated with this network, or when
// removing the association reports an error.
func (n *DBNetwork) Dissociate(username string) error {
	if !IsInitialized() {
		return fmt.Errorf("you first need to create server")
	}
	user, err := GetUser(username)
	if err != nil {
		return fmt.Errorf("user can not be fetched: %v", err)
	}

	assoc := db.Model(&n).Association("Users")
	var members []DBUser
	assoc.Find(&members)

	isMember := func() bool {
		for _, m := range members {
			if m.ID == user.ID {
				return true
			}
		}
		return false
	}()
	if !isMember {
		return fmt.Errorf("user %s is already not associated with the network %s", user.Username, n.Name)
	}

	assoc.Delete(user)
	if assoc.Error != nil {
		return fmt.Errorf("disassociation failed: %v", assoc.Error)
	}
	logrus.Infof("user '%s' is dissociated with the network '%s'", user.GetUsername(), n.Name)
	return nil
}
  189. // GetName returns network's name.
  190. func (n *DBNetwork) GetName() string {
  191. return n.Name
  192. }
  193. // GetCIDR returns network's CIDR.
  194. func (n *DBNetwork) GetCIDR() string {
  195. return n.CIDR
  196. }
  // GetCreatedAt returns the network's creation time (gorm.Model's CreatedAt),
// formatted with time.UnixDate, e.g. "Mon Jan  2 15:04:05 MST 2006".
func (n *DBNetwork) GetCreatedAt() string {
	return n.CreatedAt.Format(time.UnixDate)
}
  201. // GetType returns network's network type.
  202. func (n *DBNetwork) GetType() NetworkType {
  203. return NetworkType(n.Type)
  204. }
  205. // GetAssociatedUsers returns network's associated users.
  206. func (n *DBNetwork) GetAssociatedUsers() []*DBUser {
  207. return n.Users
  208. }
  // routedInterface returns the first system network interface that has every
// flag in flags set and carries an address routable over network. Network
// must be "ip", "ip4" or "ip6". It returns nil for any other network value,
// when the interface list can not be read, or when no interface qualifies.
func routedInterface(network string, flags net.Flags) *net.Interface {
	if network != "ip" && network != "ip4" && network != "ip6" {
		return nil
	}
	interfaces, err := net.Interfaces()
	if err != nil {
		return nil
	}
	for i := range interfaces {
		candidate := &interfaces[i]
		if candidate.Flags&flags != flags {
			continue
		}
		if _, routable := hasRoutableIP(network, candidate); routable {
			return candidate
		}
	}
	return nil
}
  // hasRoutableIP reports the first address of ifi that routableIP accepts for
// network, together with true. It returns (nil, false) when the interface's
// addresses can not be read or none of them qualifies. Only *net.IPAddr and
// *net.IPNet addresses are considered.
func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) {
	addrs, err := ifi.Addrs()
	if err != nil {
		return nil, false
	}
	for _, addr := range addrs {
		var candidate net.IP
		switch a := addr.(type) {
		case *net.IPAddr:
			candidate = a.IP
		case *net.IPNet:
			candidate = a.IP
		default:
			continue
		}
		if ip := routableIP(network, candidate); ip != nil {
			return ip, true
		}
	}
	return nil, false
}
  // vpnInterface returns the system network interface that carries the VPN
// server address, or nil when no such interface is found.
//
// The server address is the network _DefaultServerNetwork masked by
// _DefaultServerNetMask with its last octet set to 1 (xxx.xxx.xxx.1). An
// interface matches when one of its addresses, in CIDR notation, is exactly
// that address with that mask (e.g. "10.9.0.1/24").
func vpnInterface() *net.Interface {
	// Work with 4-byte forms; a nil result means the constant is not IPv4.
	mask := net.IPMask(net.ParseIP(_DefaultServerNetMask).To4())
	prefix := net.ParseIP(_DefaultServerNetwork).To4()
	if mask == nil || prefix == nil {
		logrus.Errorf("invalid default server network %s/%s", _DefaultServerNetwork, _DefaultServerNetMask)
		return nil
	}
	serverIP := prefix.Mask(mask)
	serverIP[3] = byte(1) // Server always gets xxx.xxx.xxx.1
	want := (&net.IPNet{IP: serverIP, Mask: mask}).String()

	ifs, err := net.Interfaces()
	if err != nil {
		logrus.Errorf("can not get system network interfaces: %v", err)
		return nil
	}
	for i := range ifs {
		addrs, err := ifs[i].Addrs()
		if err != nil {
			// One unreadable interface must not prevent finding the VPN interface.
			logrus.Errorf("can not get interface addresses: %v", err)
			continue
		}
		for _, addr := range addrs {
			if addr.String() == want {
				return &ifs[i]
			}
		}
	}
	return nil
}
  // routableIP returns ip in the form appropriate for network, or nil when ip is
// not usable for it.
//
// Only loopback, link-local unicast and global unicast addresses qualify.
// "ip4" yields the 4-byte form of IPv4 addresses. "ip6" yields the 16-byte
// form of non-loopback IPv6 addresses (IPv4 addresses are rejected). Any other
// network value prefers the 4-byte IPv4 form and falls back to the 16-byte
// form.
func routableIP(network string, ip net.IP) net.IP {
	routable := ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsGlobalUnicast()
	if !routable {
		return nil
	}
	v4 := ip.To4()
	switch network {
	case "ip4":
		return v4
	case "ip6":
		// The addressing scope of the loopback address depends on each implementation.
		if ip.IsLoopback() || v4 != nil {
			return nil
		}
		return ip.To16()
	default:
		if v4 != nil {
			return v4
		}
		return ip.To16()
	}
}
  305. // ensureNatEnabled launches a goroutine that constantly tries to enable nat.
  306. func ensureNatEnabled() {
  307. // Nat enablerer
  308. go func() {
  309. for {
  310. err := enableNat()
  311. if err == nil {
  312. logrus.Debug("nat is enabled")
  313. return
  314. }
  315. logrus.Debugf("can not enable nat: %v", err)
  316. // TODO(cad): employ a exponential back-off approach here
  317. // instead of sleeping for the constant duration.
  318. time.Sleep(1 * time.Second)
  319. }
  320. }()
  321. }
  // enableNat is an idempotent command that ensures nat is enabled for the vpn server.
//
// It enables IPv4 forwarding and installs, if not already present:
//   - a MASQUERADE rule for traffic leaving through the routed interface,
//   - a FORWARD rule accepting RELATED/ESTABLISHED traffic from the routed
//     interface to the VPN interface,
//   - a FORWARD rule accepting traffic from the VPN interface to the routed
//     interface.
//
// Any failure is returned so that callers (such as ensureNatEnabled) can retry.
func enableNat() error {
	rif := routedInterface("ip", net.FlagUp|net.FlagBroadcast)
	if rif == nil {
		return fmt.Errorf("can not get routable network interface")
	}
	vpnIfc := vpnInterface()
	if vpnIfc == nil {
		return fmt.Errorf("can not get vpn network interface on the system")
	}
	// Enable ip forwarding.
	// NOTE(review): any result of emitToFile is ignored here; confirm whether it reports errors.
	emitToFile("/proc/sys/net/ipv4/ip_forward", "1", 0)
	ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
	if err != nil {
		return fmt.Errorf("can not create new iptables object: %v", err)
	}
	rules := []struct {
		table, chain string
		spec         []string
	}{
		{"nat", "POSTROUTING", []string{"-o", rif.Name, "-j", "MASQUERADE"}},
		// The state list must not contain spaces: iptables rejects
		// "RELATED, ESTABLISHED" with a "Bad state" error.
		{"filter", "FORWARD", []string{"-i", rif.Name, "-o", vpnIfc.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"}},
		{"filter", "FORWARD", []string{"-i", vpnIfc.Name, "-o", rif.Name, "-j", "ACCEPT"}},
	}
	// Append iptables nat rules.
	for _, r := range rules {
		if err := ipt.AppendUnique(r.table, r.chain, r.spec...); err != nil {
			return fmt.Errorf("can not append iptables rule to %s/%s: %v", r.table, r.chain, err)
		}
	}
	return nil
}
  // HostID2IP converts a host id (32-bit unsigned integer) to a 4-byte IPv4
// address, reading the id in big-endian (network) byte order.
func HostID2IP(hostid uint32) net.IP {
	return net.IP{
		byte(hostid >> 24),
		byte(hostid >> 16),
		byte(hostid >> 8),
		byte(hostid),
	}
}
  // IP2HostID converts an IP address to a host id (32-bit unsigned integer),
// reading the address in big-endian (network) byte order.
//
// IPv4 addresses are accepted in both their 4-byte and 16-byte (IPv4-mapped)
// forms, so IP2HostID(net.ParseIP("10.9.0.1")) == 0x0a090001. For other
// addresses, the first four bytes are used. An IP shorter than four bytes
// (including nil) yields 0 instead of panicking.
func IP2HostID(ip net.IP) uint32 {
	// net.ParseIP returns IPv4 addresses in 16-byte form, whose leading bytes
	// are zero; normalize to the 4-byte form first.
	if v4 := ip.To4(); v4 != nil {
		ip = v4
	}
	if len(ip) < net.IPv4len {
		return 0
	}
	return binary.BigEndian.Uint32(ip)
}