add support for user-provided UnknownStreamHandler
This commit is contained in:
35
server.go
35
server.go
@ -116,6 +116,7 @@ type options struct {
|
||||
statsHandler stats.Handler
|
||||
maxConcurrentStreams uint32
|
||||
useHandlerImpl bool // use http.Handler-based server
|
||||
unknownStreamDesc *StreamDesc
|
||||
}
|
||||
|
||||
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||
@ -208,6 +209,24 @@ func StatsHandler(h stats.Handler) ServerOption {
|
||||
}
|
||||
}
|
||||
|
||||
// UnknownServiceHandler returns a ServerOption that allows for adding a custom
|
||||
// unknown service handler. The provided method is a bidi-streaming RPC service
|
||||
// handler that will be invoked instead of returning the the "unimplemented" gRPC
|
||||
// error whenever a request is received for an unregistered service or method.
|
||||
// The handling function has full access to the Context of the request and the
|
||||
// stream, and the invocation passes through interceptors.
|
||||
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
|
||||
return func(o *options) {
|
||||
o.unknownStreamDesc = &StreamDesc{
|
||||
StreamName: "unknown_service_handler",
|
||||
Handler: streamHandler,
|
||||
// We need to assume that the users of the streamHandler will want to use both.
|
||||
ClientStreams: true,
|
||||
ServerStreams: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer creates a gRPC server which has no service registered and has not
|
||||
// started to accept requests yet.
|
||||
func NewServer(opt ...ServerOption) *Server {
|
||||
@ -815,15 +834,19 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
||||
}()
|
||||
}
|
||||
var appErr error
|
||||
var server interface{}
|
||||
if srv != nil {
|
||||
server = srv.server
|
||||
}
|
||||
if s.opts.streamInt == nil {
|
||||
appErr = sd.Handler(srv.server, ss)
|
||||
appErr = sd.Handler(server, ss)
|
||||
} else {
|
||||
info := &StreamServerInfo{
|
||||
FullMethod: stream.Method(),
|
||||
IsClientStream: sd.ClientStreams,
|
||||
IsServerStream: sd.ServerStreams,
|
||||
}
|
||||
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
||||
appErr = s.opts.streamInt(server, ss, info, sd.Handler)
|
||||
}
|
||||
if appErr != nil {
|
||||
if err, ok := appErr.(*rpcError); ok {
|
||||
@ -883,6 +906,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
||||
method := sm[pos+1:]
|
||||
srv, ok := s.m[service]
|
||||
if !ok {
|
||||
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
|
||||
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
|
||||
return
|
||||
}
|
||||
if trInfo != nil {
|
||||
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
|
||||
trInfo.tr.SetError()
|
||||
@ -913,6 +940,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
||||
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
|
||||
trInfo.tr.SetError()
|
||||
}
|
||||
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
|
||||
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
|
||||
return
|
||||
}
|
||||
errDesc := fmt.Sprintf("unknown method %v", method)
|
||||
if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
|
||||
if trInfo != nil {
|
||||
|
@ -428,6 +428,7 @@ type test struct {
|
||||
streamClientInt grpc.StreamClientInterceptor
|
||||
unaryServerInt grpc.UnaryServerInterceptor
|
||||
streamServerInt grpc.StreamServerInterceptor
|
||||
unknownHandler grpc.StreamHandler
|
||||
sc <-chan grpc.ServiceConfig
|
||||
|
||||
// srv and srvAddr are set once startServer is called.
|
||||
@ -493,6 +494,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
||||
if te.streamServerInt != nil {
|
||||
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
|
||||
}
|
||||
if te.unknownHandler != nil {
|
||||
sopts = append(sopts, grpc.UnknownServiceHandler(te.unknownHandler))
|
||||
}
|
||||
la := "localhost:0"
|
||||
switch te.e.network {
|
||||
case "unix":
|
||||
@ -1234,6 +1238,33 @@ func testHealthCheckOff(t *testing.T, e env) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownHandler(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
// An example unknownHandler that returns a different code and a different method, making sure that we do not
|
||||
// expose what methods are implemented to a client that is not authenticated.
|
||||
unknownHandler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||
return grpc.Errorf(codes.Unauthenticated, "user unauthenticated")
|
||||
}
|
||||
for _, e := range listTestEnv() {
|
||||
// TODO(bradfitz): Temporarily skip this env due to #619.
|
||||
if e.name == "handler-tls" {
|
||||
continue
|
||||
}
|
||||
testUnknownHandler(t, e, unknownHandler)
|
||||
}
|
||||
}
|
||||
|
||||
func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler) {
|
||||
te := newTest(t, e)
|
||||
te.unknownHandler = unknownHandler
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
want := grpc.Errorf(codes.Unauthenticated, "user unauthenticated")
|
||||
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) {
|
||||
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckServingStatus(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
|
Reference in New Issue
Block a user