main.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. //go:generate go-bindata template/
  2. package main
  3. import (
  4. "context"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "os/signal"
  11. "strconv"
  12. "syscall"
  13. "time"
  14. "google.golang.org/grpc"
  15. "github.com/Sirupsen/logrus"
  16. "github.com/cad/ovpm"
  17. "github.com/cad/ovpm/api"
  18. "github.com/grpc-ecosystem/grpc-gateway/runtime"
  19. "github.com/urfave/cli"
  20. )
  21. var action string
  22. var db *ovpm.DB
  23. func main() {
  24. app := cli.NewApp()
  25. app.Name = "ovpmd"
  26. app.Usage = "OpenVPN Manager Daemon"
  27. app.Version = ovpm.Version
  28. app.Flags = []cli.Flag{
  29. cli.BoolFlag{
  30. Name: "verbose",
  31. Usage: "verbose output",
  32. },
  33. cli.StringFlag{
  34. Name: "port",
  35. Usage: "port number for daemon to listen on",
  36. },
  37. }
  38. app.Before = func(c *cli.Context) error {
  39. logrus.SetLevel(logrus.InfoLevel)
  40. if c.GlobalBool("verbose") {
  41. logrus.SetLevel(logrus.DebugLevel)
  42. }
  43. db = ovpm.CreateDB("sqlite3", "")
  44. return nil
  45. }
  46. app.After = func(c *cli.Context) error {
  47. db.Cease()
  48. return nil
  49. }
  50. app.Action = func(c *cli.Context) error {
  51. port := c.String("port")
  52. if port == "" {
  53. port = "9090"
  54. }
  55. s := newServer(port)
  56. s.start()
  57. s.waitForInterrupt()
  58. s.stop()
  59. return nil
  60. }
  61. app.Run(os.Args)
  62. }
  63. type server struct {
  64. grpcPort string
  65. lis net.Listener
  66. grpcServer *grpc.Server
  67. restServer *runtime.ServeMux
  68. restCancel context.CancelFunc
  69. restPort string
  70. signal chan os.Signal
  71. done chan bool
  72. }
  73. func newServer(port string) *server {
  74. sigs := make(chan os.Signal, 1)
  75. done := make(chan bool, 1)
  76. signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
  77. go func() {
  78. sig := <-sigs
  79. fmt.Println()
  80. fmt.Println(sig)
  81. done <- true
  82. }()
  83. if !ovpm.Testing {
  84. lis, err := net.Listen("tcp", fmt.Sprintf(":%s", port))
  85. if err != nil {
  86. logrus.Fatalf("could not listen to port %s: %v", port, err)
  87. }
  88. rpcServer := api.NewRPCServer()
  89. restServer, restCancel, err := api.NewRESTServer()
  90. if err != nil {
  91. logrus.Fatalf("could not get new rest server :%v", err)
  92. }
  93. return &server{
  94. lis: lis,
  95. grpcServer: rpcServer,
  96. restServer: restServer,
  97. restCancel: context.CancelFunc(restCancel),
  98. restPort: increasePort(port),
  99. signal: sigs,
  100. done: done,
  101. grpcPort: port,
  102. }
  103. }
  104. return &server{}
  105. }
  106. func (s *server) start() {
  107. logrus.Infof("OVPM is running gRPC:%s, REST:%s ...", s.grpcPort, s.restPort)
  108. go s.grpcServer.Serve(s.lis)
  109. go http.ListenAndServe(":"+s.restPort, s.restServer)
  110. ovpm.StartVPNProc()
  111. }
  112. func (s *server) stop() {
  113. logrus.Info("OVPM is shutting down ...")
  114. s.grpcServer.Stop()
  115. s.restCancel()
  116. ovpm.StopVPNProc()
  117. }
  118. func (s *server) waitForInterrupt() {
  119. <-s.done
  120. go timeout(8 * time.Second)
  121. }
  122. func timeout(interval time.Duration) {
  123. time.Sleep(interval)
  124. log.Println("Timeout! Killing the main thread...")
  125. os.Exit(-1)
  126. }
  127. func stringInSlice(a string, list []string) bool {
  128. for _, b := range list {
  129. if b == a {
  130. return true
  131. }
  132. }
  133. return false
  134. }
  135. func increasePort(p string) string {
  136. i, err := strconv.Atoi(p)
  137. if err != nil {
  138. logrus.Panicf(fmt.Sprintf("can't convert %s to int: %v", p, err))
  139. }
  140. i++
  141. return fmt.Sprintf("%d", i)
  142. }