1
0

interceptor.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package api
  2. import (
  3. "fmt"
  4. "net"
  5. "github.com/Sirupsen/logrus"
  6. "github.com/asaskevich/govalidator"
  7. gcontext "golang.org/x/net/context"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/metadata"
  10. )
  11. // AuthUnaryInterceptor is a interceptor function.
  12. //
  13. // See https://godoc.org/google.golang.org/grpc#UnaryServerInterceptor.
  14. func AuthUnaryInterceptor(ctx gcontext.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
  15. var enableAuthCheck bool
  16. md, ok := metadata.FromIncomingContext(ctx)
  17. if !ok {
  18. return nil, fmt.Errorf("Expected 2 metadata items in context; got %v", md)
  19. }
  20. // We enable auth check if we find a non-loopback
  21. // or invalid IP in the headers coming from the grpc-gateway.
  22. for _, userAgentIP := range md["x-forwarded-for"] {
  23. // Check if the remote user IP addr is a proper IP addr.
  24. if !govalidator.IsIP(userAgentIP) {
  25. enableAuthCheck = true
  26. logrus.Debugf("grpc request user agent ip can not be fetched from x-forwarded-for metadata, enabling auth check module '%s'", userAgentIP)
  27. break
  28. }
  29. // Check if the remote user IP addr is a loopback IP addr.
  30. if ip := net.ParseIP(userAgentIP); !ip.IsLoopback() {
  31. enableAuthCheck = true
  32. logrus.Debugf("grpc request user agent ips include non-loopback ip, enabling auth check module '%s'", userAgentIP)
  33. break
  34. }
  35. // TODO(cad): We assume gRPC endpoints are for cli only therefore
  36. // we are listening only on looback IP.
  37. //
  38. // But if we decide use gRPC endpoints publicly, we need to add
  39. // extra checks against gRPC remote peer IP to test if the request
  40. // is coming from a remote peer IP or also from a loopback ip.
  41. }
  42. if !enableAuthCheck {
  43. logrus.Debugf("rpc: auth-check not enabled: %s", md["x-forwarded-for"])
  44. ctx = NewUsernameContext(ctx, "root")
  45. }
  46. if enableAuthCheck {
  47. switch info.FullMethod {
  48. // AuthService methods
  49. case "/pb.AuthService/Status":
  50. return authRequired(ctx, req, handler)
  51. // UserService methods
  52. case "/pb.UserService/List":
  53. return authRequired(ctx, req, handler)
  54. case "/pb.UserService/Create":
  55. return authRequired(ctx, req, handler)
  56. case "/pb.UserService/Update":
  57. return authRequired(ctx, req, handler)
  58. case "/pb.UserService/Delete":
  59. return authRequired(ctx, req, handler)
  60. case "/pb.UserService/Renew":
  61. return authRequired(ctx, req, handler)
  62. case "/pb.UserService/GenConfig":
  63. return authRequired(ctx, req, handler)
  64. // VPNService methods
  65. case "/pb.VPNService/Status":
  66. return authRequired(ctx, req, handler)
  67. case "/pb.VPNService/Init":
  68. return authRequired(ctx, req, handler)
  69. case "/pb.VPNService/Update":
  70. return authRequired(ctx, req, handler)
  71. // NetworkService methods
  72. case "/pb.NetworkService/Create":
  73. return authRequired(ctx, req, handler)
  74. case "/pb.NetworkService/List":
  75. return authRequired(ctx, req, handler)
  76. case "/pb.NetworkService/Delete":
  77. return authRequired(ctx, req, handler)
  78. case "/pb.NetworkService/GetAllTypes":
  79. return authRequired(ctx, req, handler)
  80. case "/pb.NetworkService/GetAssociatedUsers":
  81. return authRequired(ctx, req, handler)
  82. case "/pb.NetworkService/Associate":
  83. return authRequired(ctx, req, handler)
  84. case "/pb.NetworkService/Dissociate":
  85. return authRequired(ctx, req, handler)
  86. default:
  87. logrus.Debugln("rpc: auth is not required for this endpoint: '%s'", info.FullMethod)
  88. }
  89. }
  90. return handler(ctx, req)
  91. }