main.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. //go:generate go-bindata template/
  2. package main
  3. import (
  4. "context"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "os/signal"
  11. "strconv"
  12. "syscall"
  13. "time"
  14. "google.golang.org/grpc"
  15. "github.com/Sirupsen/logrus"
  16. "github.com/cad/ovpm"
  17. "github.com/cad/ovpm/api"
  18. "github.com/urfave/cli"
  19. )
  20. var action string
  21. var db *ovpm.DB
  22. func main() {
  23. app := cli.NewApp()
  24. app.Name = "ovpmd"
  25. app.Usage = "OpenVPN Manager Daemon"
  26. app.Version = ovpm.Version
  27. app.Flags = []cli.Flag{
  28. cli.BoolFlag{
  29. Name: "verbose",
  30. Usage: "verbose output",
  31. },
  32. cli.StringFlag{
  33. Name: "port",
  34. Usage: "port number for daemon to listen on",
  35. },
  36. }
  37. app.Before = func(c *cli.Context) error {
  38. logrus.SetLevel(logrus.InfoLevel)
  39. if c.GlobalBool("verbose") {
  40. logrus.SetLevel(logrus.DebugLevel)
  41. }
  42. db = ovpm.CreateDB("sqlite3", "")
  43. return nil
  44. }
  45. app.After = func(c *cli.Context) error {
  46. db.Cease()
  47. return nil
  48. }
  49. app.Action = func(c *cli.Context) error {
  50. port := c.String("port")
  51. if port == "" {
  52. port = "9090"
  53. }
  54. s := newServer(port)
  55. s.start()
  56. s.waitForInterrupt()
  57. s.stop()
  58. return nil
  59. }
  60. app.Run(os.Args)
  61. }
// server bundles the daemon's gRPC/REST endpoints and its shutdown plumbing.
// newServer populates every field outside ovpm.Testing mode and returns a
// zero-valued server inside it.
type server struct {
	// grpcPort is the port passed to newServer; gRPC binds 127.0.0.1:grpcPort.
	grpcPort string
	// lis is the localhost listener created in newServer and served by
	// grpcServer in start.
	lis net.Listener
	// grpcServer is the server returned by api.NewRPCServer.
	grpcServer *grpc.Server
	// restServer is the handler returned by api.NewRESTServer; start serves it
	// on ":"+restPort (all interfaces).
	restServer http.Handler
	// restCancel is the cancel func returned by api.NewRESTServer; stop calls it.
	restCancel context.CancelFunc
	// restPort is grpcPort+1, computed by increasePort.
	restPort string
	// signal is the channel registered with signal.Notify for SIGINT/SIGTERM.
	signal chan os.Signal
	// done receives true once one of those signals arrives; waitForInterrupt
	// blocks on it.
	done chan bool
}
  72. func newServer(port string) *server {
  73. sigs := make(chan os.Signal, 1)
  74. done := make(chan bool, 1)
  75. signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
  76. go func() {
  77. sig := <-sigs
  78. fmt.Println()
  79. fmt.Println(sig)
  80. done <- true
  81. }()
  82. if !ovpm.Testing {
  83. // NOTE(cad): gRPC endpoint listens on localhost. This is important
  84. // because we don't authanticate requests coming from localhost.
  85. // So gRPC endpoint should never listen on something else then
  86. // localhost.
  87. lis, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%s", port))
  88. if err != nil {
  89. logrus.Fatalf("could not listen to port %s: %v", port, err)
  90. }
  91. rpcServer := api.NewRPCServer()
  92. restServer, restCancel, err := api.NewRESTServer(port)
  93. if err != nil {
  94. logrus.Fatalf("could not get new rest server :%v", err)
  95. }
  96. return &server{
  97. lis: lis,
  98. grpcServer: rpcServer,
  99. restServer: restServer,
  100. restCancel: context.CancelFunc(restCancel),
  101. restPort: increasePort(port),
  102. signal: sigs,
  103. done: done,
  104. grpcPort: port,
  105. }
  106. }
  107. return &server{}
  108. }
  109. func (s *server) start() {
  110. logrus.Infof("OVPM is running gRPC:%s, REST:%s ...", s.grpcPort, s.restPort)
  111. go s.grpcServer.Serve(s.lis)
  112. go http.ListenAndServe(":"+s.restPort, s.restServer)
  113. ovpm.StartVPNProc()
  114. }
  115. func (s *server) stop() {
  116. logrus.Info("OVPM is shutting down ...")
  117. s.grpcServer.Stop()
  118. s.restCancel()
  119. ovpm.StopVPNProc()
  120. }
  121. func (s *server) waitForInterrupt() {
  122. <-s.done
  123. go timeout(8 * time.Second)
  124. }
  125. func timeout(interval time.Duration) {
  126. time.Sleep(interval)
  127. log.Println("Timeout! Killing the main thread...")
  128. os.Exit(-1)
  129. }
  130. func stringInSlice(a string, list []string) bool {
  131. for _, b := range list {
  132. if b == a {
  133. return true
  134. }
  135. }
  136. return false
  137. }
  138. func increasePort(p string) string {
  139. i, err := strconv.Atoi(p)
  140. if err != nil {
  141. logrus.Panicf(fmt.Sprintf("can't convert %s to int: %v", p, err))
  142. }
  143. i++
  144. return fmt.Sprintf("%d", i)
  145. }