Browse Source

Merge branch 'feat/handle-sigint-10' into dev

closes #10
Mustafa Arici 8 years ago
parent
commit
8084483cd6
2 changed files with 74 additions and 10 deletions
  1. 62 10
      cmd/ovpmd/main.go
  2. 12 0
      vpn.go

+ 62 - 10
cmd/ovpmd/main.go

@@ -4,8 +4,12 @@ package main
 
 import (
 	"fmt"
+	"log"
 	"net"
 	"os"
+	"os/signal"
+	"syscall"
+	"time"
 
 	"google.golang.org/grpc"
 
@@ -50,21 +54,69 @@ func main() {
 		if port == "" {
 			port = "9090"
 		}
-		lis, err := net.Listen("tcp", fmt.Sprintf(":%s", port))
-		if err != nil {
-			logrus.Fatalf("could not listen to port %s: %v", port, err)
-		}
-		s := grpc.NewServer()
-		pb.RegisterUserServiceServer(s, &api.UserService{})
-		pb.RegisterVPNServiceServer(s, &api.VPNService{})
-		logrus.Infof("OVPM is running :%s ...", port)
-		ovpm.RestartVPNProc()
-		s.Serve(lis)
+		s := newServer(port)
+		s.start()
+		s.waitForInterrupt()
+		s.stop()
 		return nil
 	}
 	app.Run(os.Args)
 }
 
+type server struct {
+	port       string
+	lis        net.Listener
+	grpcServer *grpc.Server
+	signal     chan os.Signal
+	done       chan bool
+}
+
+func newServer(port string) *server {
+	sigs := make(chan os.Signal, 1)
+	done := make(chan bool, 1)
+
+	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+
+	go func() {
+		sig := <-sigs
+		fmt.Println()
+		fmt.Println(sig)
+		done <- true
+	}()
+
+	lis, err := net.Listen("tcp", fmt.Sprintf(":%s", port))
+	if err != nil {
+		logrus.Fatalf("could not listen to port %s: %v", port, err)
+	}
+	s := grpc.NewServer()
+	pb.RegisterUserServiceServer(s, &api.UserService{})
+	pb.RegisterVPNServiceServer(s, &api.VPNService{})
+	return &server{lis: lis, grpcServer: s, signal: sigs, done: done, port: port}
+}
+
+func (s *server) start() {
+	logrus.Infof("OVPM is running :%s ...", s.port)
+	go s.grpcServer.Serve(s.lis)
+	ovpm.RestartVPNProc()
+}
+
+func (s *server) stop() {
+	logrus.Info("OVPM is shutting down ...")
+	s.grpcServer.Stop()
+	ovpm.StopVPNProc()
+}
+
+func (s *server) waitForInterrupt() {
+	<-s.done
+	go timeout(8 * time.Second)
+}
+
+func timeout(interval time.Duration) {
+	time.Sleep(interval)
+	log.Println("Timeout! Killing the main thread...")
+	os.Exit(-1)
+}
+
 func stringInSlice(a string, list []string) bool {
 	for _, b := range list {
 		if b == a {

+ 12 - 0
vpn.go

@@ -224,6 +224,18 @@ func RestartVPNProc() {
 	vpnProc.Restart()
 }
 
+// StopVPNProc stops the OpenVPN process.
+func StopVPNProc() {
+	if !vpnProc.IsRunning() {
+		logrus.Error("OpenVPN is already stopped")
+		return
+	}
+	if vpnProc == nil {
+		panic(fmt.Sprintf("vpnProc is not initialized!"))
+	}
+	vpnProc.Stop()
+}
+
 // Emit generates all needed files for the OpenVPN server and dumps them to their corresponding paths defined in the config.
 func Emit() error {
 	// Check dependencies