add support for user-provided UnknownStreamHandler
This commit is contained in:
35
server.go
35
server.go
@ -116,6 +116,7 @@ type options struct {
|
|||||||
statsHandler stats.Handler
|
statsHandler stats.Handler
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
useHandlerImpl bool // use http.Handler-based server
|
useHandlerImpl bool // use http.Handler-based server
|
||||||
|
unknownStreamDesc *StreamDesc
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||||
@ -208,6 +209,24 @@ func StatsHandler(h stats.Handler) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnknownServiceHandler returns a ServerOption that allows for adding a custom
|
||||||
|
// unknown service handler. The provided method is a bidi-streaming RPC service
|
||||||
|
// handler that will be invoked instead of returning the the "unimplemented" gRPC
|
||||||
|
// error whenever a request is received for an unregistered service or method.
|
||||||
|
// The handling function has full access to the Context of the request and the
|
||||||
|
// stream, and the invocation passes through interceptors.
|
||||||
|
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.unknownStreamDesc = &StreamDesc{
|
||||||
|
StreamName: "unknown_service_handler",
|
||||||
|
Handler: streamHandler,
|
||||||
|
// We need to assume that the users of the streamHandler will want to use both.
|
||||||
|
ClientStreams: true,
|
||||||
|
ServerStreams: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer creates a gRPC server which has no service registered and has not
|
// NewServer creates a gRPC server which has no service registered and has not
|
||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
@ -815,15 +834,19 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
var appErr error
|
var appErr error
|
||||||
|
var server interface{}
|
||||||
|
if srv != nil {
|
||||||
|
server = srv.server
|
||||||
|
}
|
||||||
if s.opts.streamInt == nil {
|
if s.opts.streamInt == nil {
|
||||||
appErr = sd.Handler(srv.server, ss)
|
appErr = sd.Handler(server, ss)
|
||||||
} else {
|
} else {
|
||||||
info := &StreamServerInfo{
|
info := &StreamServerInfo{
|
||||||
FullMethod: stream.Method(),
|
FullMethod: stream.Method(),
|
||||||
IsClientStream: sd.ClientStreams,
|
IsClientStream: sd.ClientStreams,
|
||||||
IsServerStream: sd.ServerStreams,
|
IsServerStream: sd.ServerStreams,
|
||||||
}
|
}
|
||||||
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
appErr = s.opts.streamInt(server, ss, info, sd.Handler)
|
||||||
}
|
}
|
||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
if err, ok := appErr.(*rpcError); ok {
|
if err, ok := appErr.(*rpcError); ok {
|
||||||
@ -883,6 +906,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
|||||||
method := sm[pos+1:]
|
method := sm[pos+1:]
|
||||||
srv, ok := s.m[service]
|
srv, ok := s.m[service]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
|
||||||
|
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
|
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
|
||||||
trInfo.tr.SetError()
|
trInfo.tr.SetError()
|
||||||
@ -913,6 +940,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
|||||||
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
|
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
|
||||||
trInfo.tr.SetError()
|
trInfo.tr.SetError()
|
||||||
}
|
}
|
||||||
|
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
|
||||||
|
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
errDesc := fmt.Sprintf("unknown method %v", method)
|
errDesc := fmt.Sprintf("unknown method %v", method)
|
||||||
if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
|
if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
|
@ -428,6 +428,7 @@ type test struct {
|
|||||||
streamClientInt grpc.StreamClientInterceptor
|
streamClientInt grpc.StreamClientInterceptor
|
||||||
unaryServerInt grpc.UnaryServerInterceptor
|
unaryServerInt grpc.UnaryServerInterceptor
|
||||||
streamServerInt grpc.StreamServerInterceptor
|
streamServerInt grpc.StreamServerInterceptor
|
||||||
|
unknownHandler grpc.StreamHandler
|
||||||
sc <-chan grpc.ServiceConfig
|
sc <-chan grpc.ServiceConfig
|
||||||
|
|
||||||
// srv and srvAddr are set once startServer is called.
|
// srv and srvAddr are set once startServer is called.
|
||||||
@ -493,6 +494,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||||||
if te.streamServerInt != nil {
|
if te.streamServerInt != nil {
|
||||||
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
|
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
|
||||||
}
|
}
|
||||||
|
if te.unknownHandler != nil {
|
||||||
|
sopts = append(sopts, grpc.UnknownServiceHandler(te.unknownHandler))
|
||||||
|
}
|
||||||
la := "localhost:0"
|
la := "localhost:0"
|
||||||
switch te.e.network {
|
switch te.e.network {
|
||||||
case "unix":
|
case "unix":
|
||||||
@ -1234,6 +1238,33 @@ func testHealthCheckOff(t *testing.T, e env) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnknownHandler(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
// An example unknownHandler that returns a different code and a different method, making sure that we do not
|
||||||
|
// expose what methods are implemented to a client that is not authenticated.
|
||||||
|
unknownHandler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
return grpc.Errorf(codes.Unauthenticated, "user unauthenticated")
|
||||||
|
}
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
// TODO(bradfitz): Temporarily skip this env due to #619.
|
||||||
|
if e.name == "handler-tls" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
testUnknownHandler(t, e, unknownHandler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.unknownHandler = unknownHandler
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
want := grpc.Errorf(codes.Unauthenticated, "user unauthenticated")
|
||||||
|
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) {
|
||||||
|
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHealthCheckServingStatus(t *testing.T) {
|
func TestHealthCheckServingStatus(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
for _, e := range listTestEnv() {
|
for _, e := range listTestEnv() {
|
||||||
|
Reference in New Issue
Block a user