diff --git a/call.go b/call.go index fea07998..24af8ab0 100644 --- a/call.go +++ b/call.go @@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke sends the RPC request on the wire and returns after response is received. // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. -func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { +func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error { + if cc.dopts.unaryInt != nil { + return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) + } + return invoke(ctx, method, args, reply, cc, opts...) +} + +func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { diff --git a/clientconn.go b/clientconn.go index ccaff94a..0afcbb31 100644 --- a/clientconn.go +++ b/clientconn.go @@ -83,15 +83,17 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - codec Codec - cp Compressor - dc Decompressor - bs backoffStrategy - balancer Balancer - block bool - insecure bool - timeout time.Duration - copts transport.ConnectOptions + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + codec Codec + cp Compressor + dc Decompressor + bs backoffStrategy + balancer Balancer + block bool + insecure bool + timeout time.Duration + copts transport.ConnectOptions } // DialOption configures how we set up the connection. @@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption { } } +// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. +func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { + return func(o *dialOptions) { + o.unaryInt = f + } +} + +// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs. +func WithStreamInterceptor(f StreamClientInterceptor) DialOption { + return func(o *dialOptions) { + o.streamInt = f + } +} + // Dial creates a client connection to the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { return DialContext(context.Background(), target, opts...) diff --git a/interceptor.go b/interceptor.go index 588f59e5..8d932efe 100644 --- a/interceptor.go +++ b/interceptor.go @@ -37,6 +37,22 @@ import ( "golang.org/x/net/context" ) +// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs. +type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error + +// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC +// and it is the responsibility of the interceptor to call it. +// This is the EXPERIMENTAL API. +type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error + +// Streamer is called by StreamClientInterceptor to create a ClientStream. +type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) + +// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O +// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it. +// This is the EXPERIMENTAL API. +type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) + // UnaryServerInfo consists of various information about a unary RPC on // server side. All per-rpc information may be mutated by the interceptor. type UnaryServerInfo struct { diff --git a/stream.go b/stream.go index 51df3f01..89d6a198 100644 --- a/stream.go +++ b/stream.go @@ -97,7 +97,14 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. -func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + if cc.dopts.streamInt != nil { + return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...) + } + return newClientStream(ctx, desc, cc, method, opts...) +} + +func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport s *transport.Stream diff --git a/test/end2end_test.go b/test/end2end_test.go index e30830a6..645be6b2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -369,8 +369,10 @@ type test struct { userAgent string clientCompression bool serverCompression bool - unaryInt grpc.UnaryServerInterceptor - streamInt grpc.StreamServerInterceptor + unaryClientInt grpc.UnaryClientInterceptor + streamClientInt grpc.StreamClientInterceptor + unaryServerInt grpc.UnaryServerInterceptor + streamServerInt grpc.StreamServerInterceptor // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -425,11 +427,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.unaryInt != nil { - sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt)) + if te.unaryServerInt != nil { + sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt)) } - if te.streamInt != nil { - sopts = append(sopts, grpc.StreamInterceptor(te.streamInt)) + if te.streamServerInt != nil { + sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt)) } la := "localhost:0" switch e.network { @@ -494,6 +496,12 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } + if te.unaryClientInt != nil { + opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt)) + } + if te.streamClientInt != nil { + opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt)) + } switch te.e.security { case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") @@ -2064,6 +2072,75 @@ func testCompressOK(t *testing.T, e env) { } } +func TestUnaryClientInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testUnaryClientInterceptor(t, e) + } +} + +func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + err := invoker(ctx, method, req, reply, cc, opts...) + if err == nil { + return grpc.Errorf(codes.NotFound, "") + } + return err +} + +func testUnaryClientInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.unaryClientInt = failOkayRPC + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound) + } +} + +func TestStreamClientInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testStreamClientInterceptor(t, e) + } +} + +func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + s, err := streamer(ctx, desc, cc, method, opts...) + if err == nil { + return nil, grpc.Errorf(codes.NotFound, "") + } + return s, nil +} + +func testStreamClientInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.streamClientInt = failOkayStream + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(1)), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound) + } +} + func TestUnaryServerInterceptor(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -2077,7 +2154,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf func testUnaryServerInterceptor(t *testing.T, e env) { te := newTest(t, e) - te.unaryInt = errInjector + te.unaryServerInt = errInjector te.startServer(&testServer{security: e.security}) defer te.tearDown() @@ -2108,7 +2185,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ func testStreamServerInterceptor(t *testing.T, e env) { te := newTest(t, e) - te.streamInt = fullDuplexOnly + te.streamServerInt = fullDuplexOnly te.startServer(&testServer{security: e.security}) defer te.tearDown()