diff --git a/server.go b/server.go index 53a2a92c..d3a8073d 100644 --- a/server.go +++ b/server.go @@ -100,6 +100,7 @@ type options struct { cp Compressor dc Decompressor unaryInt UnaryServerInterceptor + streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } @@ -142,8 +143,8 @@ func Creds(c credentials.Credentials) ServerOption { } // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the -// server. Only one interceptor can be installed. The construction of multiple interceptors -// (e.g., chaining) can be implemented at the caller. +// server. Only one unary interceptor can be installed. The construction of multiple +// interceptors (e.g., chaining) can be implemented at the caller. func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { return func(o *options) { if o.unaryInt != nil { @@ -153,6 +154,17 @@ func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { } } +// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the +// server. Only one stream interceptor can be installed. +func StreamInterceptor(i StreamServerInterceptor) ServerOption { + return func(o *options) { + if o.streamInt != nil { + panic("The stream server interceptor has been set.") + } + o.streamInt = i + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -585,7 +597,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() }() } - if appErr := sd.Handler(srv.server, ss); appErr != nil { + var appErr error + if s.opts.streamInt == nil { + appErr = sd.Handler(srv.server, ss) + } else { + info := &StreamServerInfo{ + FullMethod: stream.Method(), + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, + } + appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) + } + if appErr != nil { if err, ok := appErr.(rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc diff --git a/test/end2end_test.go b/test/end2end_test.go index 8bd84f49..d3be1eb0 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -421,6 +421,7 @@ type test struct { clientCompression bool serverCompression bool unaryInt grpc.UnaryServerInterceptor + streamInt grpc.StreamServerInterceptor // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -472,6 +473,9 @@ func (te *test) startServer() { if te.unaryInt != nil { sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt)) } + if te.streamInt != nil { + sopts = append(sopts, grpc.StreamInterceptor(te.streamInt)) + } la := "localhost:0" switch e.network { case "unix": @@ -1725,7 +1729,62 @@ func testUnaryServerInterceptor(t *testing.T, e env) { tc := testpb.NewTestServiceClient(te.clientConn()) if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.PermissionDenied { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code %d", err, codes.PermissionDenied) + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied) + } +} + +func TestStreamServerInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testStreamServerInterceptor(t, e) + } +} + +func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if info.FullMethod == "/grpc.testing.TestService/FullDuplexCall" { + return handler(srv, ss) + } + // Reject the other methods. + return grpc.Errorf(codes.PermissionDenied, "") +} + +func testStreamServerInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.streamInt = fullDuplexOnly + te.startServer() + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(1)), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + s1, err := tc.StreamingOutputCall(context.Background(), req) + if err != nil { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, ", tc, err) + } + if _, err := s1.Recv(); grpc.Code(err) != codes.PermissionDenied { + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied) + } + s2, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := s2.Send(req); err != nil { + t.Fatalf("%v.Send(_) = %v, want ", s2, err) + } + if _, err := s2.Recv(); err != nil { + t.Fatalf("%v.Recv() = _, %v, want _, ", s2, err) } }