main.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. //go:generate go-bindata template/
  2. package main
  3. import (
  4. "context"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "os/signal"
  11. "strconv"
  12. "syscall"
  13. "time"
  14. "google.golang.org/grpc"
  15. "github.com/Sirupsen/logrus"
  16. "github.com/cad/ovpm"
  17. "github.com/cad/ovpm/api"
  18. "github.com/urfave/cli"
  19. )
  20. var action string
  21. var db *ovpm.DB
  22. func main() {
  23. app := cli.NewApp()
  24. app.Name = "ovpmd"
  25. app.Usage = "OpenVPN Manager Daemon"
  26. app.Version = ovpm.Version
  27. app.Flags = []cli.Flag{
  28. cli.BoolFlag{
  29. Name: "verbose",
  30. Usage: "verbose output",
  31. },
  32. cli.StringFlag{
  33. Name: "port",
  34. Usage: "port number for gRPC API daemon",
  35. },
  36. cli.StringFlag{
  37. Name: "web-port",
  38. Usage: "port number for the REST API daemon",
  39. },
  40. }
  41. app.Before = func(c *cli.Context) error {
  42. logrus.SetLevel(logrus.InfoLevel)
  43. if c.GlobalBool("verbose") {
  44. logrus.SetLevel(logrus.DebugLevel)
  45. }
  46. db = ovpm.CreateDB("sqlite3", "")
  47. return nil
  48. }
  49. app.After = func(c *cli.Context) error {
  50. db.Cease()
  51. return nil
  52. }
  53. app.Action = func(c *cli.Context) error {
  54. port := c.String("port")
  55. if port == "" {
  56. port = "9090"
  57. }
  58. webPort := c.String("web-port")
  59. if webPort == "" {
  60. webPort = "8080"
  61. }
  62. s := newServer(port, webPort)
  63. s.start()
  64. s.waitForInterrupt()
  65. s.stop()
  66. return nil
  67. }
  68. app.Run(os.Args)
  69. }
  70. type server struct {
  71. grpcPort string
  72. lis net.Listener
  73. grpcServer *grpc.Server
  74. restServer http.Handler
  75. restCancel context.CancelFunc
  76. restPort string
  77. signal chan os.Signal
  78. done chan bool
  79. }
  80. func newServer(port, webPort string) *server {
  81. sigs := make(chan os.Signal, 1)
  82. done := make(chan bool, 1)
  83. signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
  84. go func() {
  85. sig := <-sigs
  86. fmt.Println()
  87. fmt.Println(sig)
  88. done <- true
  89. }()
  90. if !ovpm.Testing {
  91. // NOTE(cad): gRPC endpoint listens on localhost. This is important
  92. // because we don't authanticate requests coming from localhost.
  93. // So gRPC endpoint should never listen on something else then
  94. // localhost.
  95. lis, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%s", port))
  96. if err != nil {
  97. logrus.Fatalf("could not listen to port %s: %v", port, err)
  98. }
  99. rpcServer := api.NewRPCServer()
  100. restServer, restCancel, err := api.NewRESTServer(port)
  101. if err != nil {
  102. logrus.Fatalf("could not get new rest server :%v", err)
  103. }
  104. return &server{
  105. lis: lis,
  106. grpcServer: rpcServer,
  107. restServer: restServer,
  108. restCancel: context.CancelFunc(restCancel),
  109. restPort: webPort,
  110. signal: sigs,
  111. done: done,
  112. grpcPort: port,
  113. }
  114. }
  115. return &server{}
  116. }
  117. func (s *server) start() {
  118. logrus.Infof("OVPM %s is running gRPC:%s, REST:%s ...", ovpm.Version, s.grpcPort, s.restPort)
  119. go s.grpcServer.Serve(s.lis)
  120. go http.ListenAndServe(":"+s.restPort, s.restServer)
  121. ovpm.TheServer().StartVPNProc()
  122. }
  123. func (s *server) stop() {
  124. logrus.Info("OVPM is shutting down ...")
  125. s.grpcServer.Stop()
  126. s.restCancel()
  127. ovpm.TheServer().StopVPNProc()
  128. }
  129. func (s *server) waitForInterrupt() {
  130. <-s.done
  131. go timeout(8 * time.Second)
  132. }
  133. func timeout(interval time.Duration) {
  134. time.Sleep(interval)
  135. log.Println("Timeout! Killing the main thread...")
  136. os.Exit(-1)
  137. }
  138. func stringInSlice(a string, list []string) bool {
  139. for _, b := range list {
  140. if b == a {
  141. return true
  142. }
  143. }
  144. return false
  145. }
  146. func increasePort(p string) string {
  147. i, err := strconv.Atoi(p)
  148. if err != nil {
  149. logrus.Panicf(fmt.Sprintf("can't convert %s to int: %v", p, err))
  150. }
  151. i++
  152. return fmt.Sprintf("%d", i)
  153. }