Support the stream interceptor on server side.
This commit is contained in:
29
server.go
29
server.go
@ -100,6 +100,7 @@ type options struct {
|
|||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
unaryInt UnaryServerInterceptor
|
unaryInt UnaryServerInterceptor
|
||||||
|
streamInt StreamServerInterceptor
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
useHandlerImpl bool // use http.Handler-based server
|
useHandlerImpl bool // use http.Handler-based server
|
||||||
}
|
}
|
||||||
@ -142,8 +143,8 @@ func Creds(c credentials.Credentials) ServerOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
|
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
|
||||||
// server. Only one interceptor can be installed. The construction of multiple interceptors
|
// server. Only one unary interceptor can be installed. The construction of multiple
|
||||||
// (e.g., chaining) can be implemented at the caller.
|
// interceptors (e.g., chaining) can be implemented at the caller.
|
||||||
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
|
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
if o.unaryInt != nil {
|
if o.unaryInt != nil {
|
||||||
@ -153,6 +154,17 @@ func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
|
||||||
|
// server. Only one stream interceptor can be installed.
|
||||||
|
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
if o.streamInt != nil {
|
||||||
|
panic("The stream server interceptor has been set.")
|
||||||
|
}
|
||||||
|
o.streamInt = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer creates a gRPC server which has no service registered and has not
|
// NewServer creates a gRPC server which has no service registered and has not
|
||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
@ -585,7 +597,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
ss.mu.Unlock()
|
ss.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if appErr := sd.Handler(srv.server, ss); appErr != nil {
|
var appErr error
|
||||||
|
if s.opts.streamInt == nil {
|
||||||
|
appErr = sd.Handler(srv.server, ss)
|
||||||
|
} else {
|
||||||
|
info := &StreamServerInfo{
|
||||||
|
FullMethod: stream.Method(),
|
||||||
|
IsClientStream: sd.ClientStreams,
|
||||||
|
IsServerStream: sd.ServerStreams,
|
||||||
|
}
|
||||||
|
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
||||||
|
}
|
||||||
|
if appErr != nil {
|
||||||
if err, ok := appErr.(rpcError); ok {
|
if err, ok := appErr.(rpcError); ok {
|
||||||
ss.statusCode = err.code
|
ss.statusCode = err.code
|
||||||
ss.statusDesc = err.desc
|
ss.statusDesc = err.desc
|
||||||
|
@ -421,6 +421,7 @@ type test struct {
|
|||||||
clientCompression bool
|
clientCompression bool
|
||||||
serverCompression bool
|
serverCompression bool
|
||||||
unaryInt grpc.UnaryServerInterceptor
|
unaryInt grpc.UnaryServerInterceptor
|
||||||
|
streamInt grpc.StreamServerInterceptor
|
||||||
|
|
||||||
// srv and srvAddr are set once startServer is called.
|
// srv and srvAddr are set once startServer is called.
|
||||||
srv *grpc.Server
|
srv *grpc.Server
|
||||||
@ -472,6 +473,9 @@ func (te *test) startServer() {
|
|||||||
if te.unaryInt != nil {
|
if te.unaryInt != nil {
|
||||||
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
|
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
|
||||||
}
|
}
|
||||||
|
if te.streamInt != nil {
|
||||||
|
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
|
||||||
|
}
|
||||||
la := "localhost:0"
|
la := "localhost:0"
|
||||||
switch e.network {
|
switch e.network {
|
||||||
case "unix":
|
case "unix":
|
||||||
@ -1725,7 +1729,62 @@ func testUnaryServerInterceptor(t *testing.T, e env) {
|
|||||||
|
|
||||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.PermissionDenied {
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.PermissionDenied {
|
||||||
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code %d", err, codes.PermissionDenied)
|
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamServerInterceptor(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testStreamServerInterceptor(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if info.FullMethod == "/grpc.testing.TestService/FullDuplexCall" {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
// Reject the other methods.
|
||||||
|
return grpc.Errorf(codes.PermissionDenied, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStreamServerInterceptor(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.streamInt = fullDuplexOnly
|
||||||
|
te.startServer()
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
respParam := []*testpb.ResponseParameters{
|
||||||
|
{
|
||||||
|
Size: proto.Int32(int32(1)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req := &testpb.StreamingOutputCallRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseParameters: respParam,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
s1, err := tc.StreamingOutputCall(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, <nil>", tc, err)
|
||||||
|
}
|
||||||
|
if _, err := s1.Recv(); grpc.Code(err) != codes.PermissionDenied {
|
||||||
|
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied)
|
||||||
|
}
|
||||||
|
s2, err := tc.FullDuplexCall(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||||
|
}
|
||||||
|
if err := s2.Send(req); err != nil {
|
||||||
|
t.Fatalf("%v.Send(_) = %v, want <nil>", s2, err)
|
||||||
|
}
|
||||||
|
if _, err := s2.Recv(); err != nil {
|
||||||
|
t.Fatalf("%v.Recv() = _, %v, want _, <nil>", s2, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user