Support client side interceptor
This commit is contained in:
9
call.go
9
call.go
@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
// Invoke sends the RPC request on the wire and returns after response is received.
|
// Invoke sends the RPC request on the wire and returns after response is received.
|
||||||
// Invoke is called by generated code. Also users can call Invoke directly when it
|
// Invoke is called by generated code. Also users can call Invoke directly when it
|
||||||
// is really needed in their use cases.
|
// is really needed in their use cases.
|
||||||
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
|
||||||
|
if cc.dopts.unaryInt != nil {
|
||||||
|
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
|
||||||
|
}
|
||||||
|
return invoke(ctx, method, args, reply, cc, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
||||||
c := defaultCallInfo
|
c := defaultCallInfo
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
if err := o.before(&c); err != nil {
|
if err := o.before(&c); err != nil {
|
||||||
|
@ -83,6 +83,8 @@ var (
|
|||||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||||
// values passed to Dial.
|
// values passed to Dial.
|
||||||
type dialOptions struct {
|
type dialOptions struct {
|
||||||
|
unaryInt UnaryClientInterceptor
|
||||||
|
streamInt StreamClientInterceptor
|
||||||
codec Codec
|
codec Codec
|
||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
|
||||||
|
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
|
||||||
|
return func(o *dialOptions) {
|
||||||
|
o.unaryInt = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
|
||||||
|
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
|
||||||
|
return func(o *dialOptions) {
|
||||||
|
o.streamInt = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Dial creates a client connection to the given target.
|
// Dial creates a client connection to the given target.
|
||||||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||||
return DialContext(context.Background(), target, opts...)
|
return DialContext(context.Background(), target, opts...)
|
||||||
|
@ -37,6 +37,22 @@ import (
|
|||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
|
||||||
|
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
|
||||||
|
|
||||||
|
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
|
||||||
|
// and it is the responsibility of the interceptor to call it.
|
||||||
|
// This is the EXPERIMENTAL API.
|
||||||
|
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
|
||||||
|
|
||||||
|
// Streamer is called by StreamClientInterceptor to create a ClientStream.
|
||||||
|
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
|
||||||
|
|
||||||
|
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
|
||||||
|
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
|
||||||
|
// This is the EXPERIMENTAL API.
|
||||||
|
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
|
||||||
|
|
||||||
// UnaryServerInfo consists of various information about a unary RPC on
|
// UnaryServerInfo consists of various information about a unary RPC on
|
||||||
// server side. All per-rpc information may be mutated by the interceptor.
|
// server side. All per-rpc information may be mutated by the interceptor.
|
||||||
type UnaryServerInfo struct {
|
type UnaryServerInfo struct {
|
||||||
|
@ -97,7 +97,14 @@ type ClientStream interface {
|
|||||||
|
|
||||||
// NewClientStream creates a new Stream for the client side. This is called
|
// NewClientStream creates a new Stream for the client side. This is called
|
||||||
// by generated code.
|
// by generated code.
|
||||||
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
|
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
|
||||||
|
if cc.dopts.streamInt != nil {
|
||||||
|
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
|
||||||
|
}
|
||||||
|
return newClientStream(ctx, desc, cc, method, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
|
||||||
var (
|
var (
|
||||||
t transport.ClientTransport
|
t transport.ClientTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
|
@ -369,8 +369,10 @@ type test struct {
|
|||||||
userAgent string
|
userAgent string
|
||||||
clientCompression bool
|
clientCompression bool
|
||||||
serverCompression bool
|
serverCompression bool
|
||||||
unaryInt grpc.UnaryServerInterceptor
|
unaryClientInt grpc.UnaryClientInterceptor
|
||||||
streamInt grpc.StreamServerInterceptor
|
streamClientInt grpc.StreamClientInterceptor
|
||||||
|
unaryServerInt grpc.UnaryServerInterceptor
|
||||||
|
streamServerInt grpc.StreamServerInterceptor
|
||||||
|
|
||||||
// srv and srvAddr are set once startServer is called.
|
// srv and srvAddr are set once startServer is called.
|
||||||
srv *grpc.Server
|
srv *grpc.Server
|
||||||
@ -425,11 +427,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||||||
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
|
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if te.unaryInt != nil {
|
if te.unaryServerInt != nil {
|
||||||
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
|
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
|
||||||
}
|
}
|
||||||
if te.streamInt != nil {
|
if te.streamServerInt != nil {
|
||||||
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
|
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
|
||||||
}
|
}
|
||||||
la := "localhost:0"
|
la := "localhost:0"
|
||||||
switch e.network {
|
switch e.network {
|
||||||
@ -494,6 +496,12 @@ func (te *test) clientConn() *grpc.ClientConn {
|
|||||||
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if te.unaryClientInt != nil {
|
||||||
|
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
|
||||||
|
}
|
||||||
|
if te.streamClientInt != nil {
|
||||||
|
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
|
||||||
|
}
|
||||||
switch te.e.security {
|
switch te.e.security {
|
||||||
case "tls":
|
case "tls":
|
||||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||||
@ -2064,6 +2072,75 @@ func testCompressOK(t *testing.T, e env) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnaryClientInterceptor(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testUnaryClientInterceptor(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||||
|
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||||
|
if err == nil {
|
||||||
|
return grpc.Errorf(codes.NotFound, "")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnaryClientInterceptor(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.userAgent = testAppUA
|
||||||
|
te.unaryClientInt = failOkayRPC
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
|
||||||
|
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamClientInterceptor(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testStreamClientInterceptor(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||||
|
s, err := streamer(ctx, desc, cc, method, opts...)
|
||||||
|
if err == nil {
|
||||||
|
return nil, grpc.Errorf(codes.NotFound, "")
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStreamClientInterceptor(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.streamClientInt = failOkayStream
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
respParam := []*testpb.ResponseParameters{
|
||||||
|
{
|
||||||
|
Size: proto.Int32(int32(1)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req := &testpb.StreamingOutputCallRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseParameters: respParam,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
|
||||||
|
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnaryServerInterceptor(t *testing.T) {
|
func TestUnaryServerInterceptor(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
for _, e := range listTestEnv() {
|
for _, e := range listTestEnv() {
|
||||||
@ -2077,7 +2154,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
|
|||||||
|
|
||||||
func testUnaryServerInterceptor(t *testing.T, e env) {
|
func testUnaryServerInterceptor(t *testing.T, e env) {
|
||||||
te := newTest(t, e)
|
te := newTest(t, e)
|
||||||
te.unaryInt = errInjector
|
te.unaryServerInt = errInjector
|
||||||
te.startServer(&testServer{security: e.security})
|
te.startServer(&testServer{security: e.security})
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
|
|
||||||
@ -2108,7 +2185,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ
|
|||||||
|
|
||||||
func testStreamServerInterceptor(t *testing.T, e env) {
|
func testStreamServerInterceptor(t *testing.T, e env) {
|
||||||
te := newTest(t, e)
|
te := newTest(t, e)
|
||||||
te.streamInt = fullDuplexOnly
|
te.streamServerInt = fullDuplexOnly
|
||||||
te.startServer(&testServer{security: e.security})
|
te.startServer(&testServer{security: e.security})
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user