diff --git a/server.go b/server.go index 985226d6..157f35ee 100644 --- a/server.go +++ b/server.go @@ -116,6 +116,7 @@ type options struct { statsHandler stats.Handler maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server + unknownStreamDesc *StreamDesc } var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit @@ -208,6 +209,24 @@ func StatsHandler(h stats.Handler) ServerOption { } } +// UnknownServiceHandler returns a ServerOption that allows for adding a custom +// unknown service handler. The provided method is a bidi-streaming RPC service +// handler that will be invoked instead of returning the the "unimplemented" gRPC +// error whenever a request is received for an unregistered service or method. +// The handling function has full access to the Context of the request and the +// stream, and the invocation passes through interceptors. +func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { + return func(o *options) { + o.unknownStreamDesc = &StreamDesc{ + StreamName: "unknown_service_handler", + Handler: streamHandler, + // We need to assume that the users of the streamHandler will want to use both. + ClientStreams: true, + ServerStreams: true, + } + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -815,15 +834,19 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp }() } var appErr error + var server interface{} + if srv != nil { + server = srv.server + } if s.opts.streamInt == nil { - appErr = sd.Handler(srv.server, ss) + appErr = sd.Handler(server, ss) } else { info := &StreamServerInfo{ FullMethod: stream.Method(), IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) + appErr = s.opts.streamInt(server, ss, info, sd.Handler) } if appErr != nil { if err, ok := appErr.(*rpcError); ok { @@ -883,6 +906,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str method := sm[pos+1:] srv, ok := s.m[service] if !ok { + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) trInfo.tr.SetError() @@ -913,6 +940,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.SetError() } + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } errDesc := fmt.Sprintf("unknown method %v", method) if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { if trInfo != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 4a7311c7..9bcea032 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -428,6 +428,7 @@ type test struct { streamClientInt grpc.StreamClientInterceptor unaryServerInt grpc.UnaryServerInterceptor streamServerInt grpc.StreamServerInterceptor + unknownHandler grpc.StreamHandler sc <-chan grpc.ServiceConfig // srv and srvAddr are set once startServer is called. @@ -493,6 +494,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) { if te.streamServerInt != nil { sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt)) } + if te.unknownHandler != nil { + sopts = append(sopts, grpc.UnknownServiceHandler(te.unknownHandler)) + } la := "localhost:0" switch te.e.network { case "unix": @@ -1234,6 +1238,33 @@ func testHealthCheckOff(t *testing.T, e env) { } } +func TestUnknownHandler(t *testing.T) { + defer leakCheck(t)() + // An example unknownHandler that returns a different code and a different method, making sure that we do not + // expose what methods are implemented to a client that is not authenticated. + unknownHandler := func(srv interface{}, stream grpc.ServerStream) error { + return grpc.Errorf(codes.Unauthenticated, "user unauthenticated") + } + for _, e := range listTestEnv() { + // TODO(bradfitz): Temporarily skip this env due to #619. + if e.name == "handler-tls" { + continue + } + testUnknownHandler(t, e, unknownHandler) + } +} + +func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler) { + te := newTest(t, e) + te.unknownHandler = unknownHandler + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + want := grpc.Errorf(codes.Unauthenticated, "user unauthenticated") + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { + t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) + } +} + func TestHealthCheckServingStatus(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() {