diff --git a/auth/example/main.go b/auth/example/main.go index 4d416b3..251d3b1 100644 --- a/auth/example/main.go +++ b/auth/example/main.go @@ -9,6 +9,9 @@ import ( "github.com/go-gost/plugin/auth/proto" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) var ( @@ -20,6 +23,11 @@ type server struct { } func (s *server) Authenticate(ctx context.Context, in *proto.AuthenticateRequest) (*proto.AuthenticateReply, error) { + token := s.getCredentials(ctx) + if token != "gost" { + return nil, status.Error(codes.Unauthenticated, codes.Unauthenticated.String()) + } + reply := &proto.AuthenticateReply{} if in.GetUsername() == "gost" && in.GetPassword() == "gost" { reply.Ok = true @@ -28,6 +36,14 @@ func (s *server) Authenticate(ctx context.Context, in *proto.AuthenticateRequest return reply, nil } +func (s *server) getCredentials(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if ok && len(md["token"]) > 0 { + return md["token"][0] + } + return "" +} + func main() { flag.Parse() lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))