Support client side interceptor
This commit is contained in:
9
call.go
9
call.go
@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
||||
// Invoke sends the RPC request on the wire and returns after response is received.
|
||||
// Invoke is called by generated code. Also users can call Invoke directly when it
|
||||
// is really needed in their use cases.
|
||||
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
||||
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
|
||||
if cc.dopts.unaryInt != nil {
|
||||
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
|
||||
}
|
||||
return invoke(ctx, method, args, reply, cc, opts...)
|
||||
}
|
||||
|
||||
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
||||
c := defaultCallInfo
|
||||
for _, o := range opts {
|
||||
if err := o.before(&c); err != nil {
|
||||
|
@ -83,15 +83,17 @@ var (
|
||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||
// values passed to Dial.
|
||||
type dialOptions struct {
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoffStrategy
|
||||
balancer Balancer
|
||||
block bool
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
copts transport.ConnectOptions
|
||||
unaryInt UnaryClientInterceptor
|
||||
streamInt StreamClientInterceptor
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoffStrategy
|
||||
balancer Balancer
|
||||
block bool
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
copts transport.ConnectOptions
|
||||
}
|
||||
|
||||
// DialOption configures how we set up the connection.
|
||||
@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
|
||||
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.unaryInt = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
|
||||
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.streamInt = f
|
||||
}
|
||||
}
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
return DialContext(context.Background(), target, opts...)
|
||||
|
@ -37,6 +37,22 @@ import (
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
|
||||
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
|
||||
|
||||
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
|
||||
// and it is the responsibility of the interceptor to call it.
|
||||
// This is the EXPERIMENTAL API.
|
||||
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
|
||||
|
||||
// Streamer is called by StreamClientInterceptor to create a ClientStream.
|
||||
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
|
||||
|
||||
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
|
||||
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
|
||||
// This is the EXPERIMENTAL API.
|
||||
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
|
||||
|
||||
// UnaryServerInfo consists of various information about a unary RPC on
|
||||
// server side. All per-rpc information may be mutated by the interceptor.
|
||||
type UnaryServerInfo struct {
|
||||
|
@ -97,7 +97,14 @@ type ClientStream interface {
|
||||
|
||||
// NewClientStream creates a new Stream for the client side. This is called
|
||||
// by generated code.
|
||||
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
|
||||
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
|
||||
if cc.dopts.streamInt != nil {
|
||||
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
|
||||
}
|
||||
return newClientStream(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
|
||||
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
|
||||
var (
|
||||
t transport.ClientTransport
|
||||
s *transport.Stream
|
||||
|
@ -369,8 +369,10 @@ type test struct {
|
||||
userAgent string
|
||||
clientCompression bool
|
||||
serverCompression bool
|
||||
unaryInt grpc.UnaryServerInterceptor
|
||||
streamInt grpc.StreamServerInterceptor
|
||||
unaryClientInt grpc.UnaryClientInterceptor
|
||||
streamClientInt grpc.StreamClientInterceptor
|
||||
unaryServerInt grpc.UnaryServerInterceptor
|
||||
streamServerInt grpc.StreamServerInterceptor
|
||||
|
||||
// srv and srvAddr are set once startServer is called.
|
||||
srv *grpc.Server
|
||||
@ -425,11 +427,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
||||
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
|
||||
)
|
||||
}
|
||||
if te.unaryInt != nil {
|
||||
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
|
||||
if te.unaryServerInt != nil {
|
||||
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
|
||||
}
|
||||
if te.streamInt != nil {
|
||||
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
|
||||
if te.streamServerInt != nil {
|
||||
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
|
||||
}
|
||||
la := "localhost:0"
|
||||
switch e.network {
|
||||
@ -494,6 +496,12 @@ func (te *test) clientConn() *grpc.ClientConn {
|
||||
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
||||
)
|
||||
}
|
||||
if te.unaryClientInt != nil {
|
||||
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
|
||||
}
|
||||
if te.streamClientInt != nil {
|
||||
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
|
||||
}
|
||||
switch te.e.security {
|
||||
case "tls":
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||
@ -2064,6 +2072,75 @@ func testCompressOK(t *testing.T, e env) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryClientInterceptor(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testUnaryClientInterceptor(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||
if err == nil {
|
||||
return grpc.Errorf(codes.NotFound, "")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func testUnaryClientInterceptor(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.userAgent = testAppUA
|
||||
te.unaryClientInt = failOkayRPC
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
|
||||
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamClientInterceptor(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testStreamClientInterceptor(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
s, err := streamer(ctx, desc, cc, method, opts...)
|
||||
if err == nil {
|
||||
return nil, grpc.Errorf(codes.NotFound, "")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func testStreamClientInterceptor(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.streamClientInt = failOkayStream
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
respParam := []*testpb.ResponseParameters{
|
||||
{
|
||||
Size: proto.Int32(int32(1)),
|
||||
},
|
||||
}
|
||||
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := &testpb.StreamingOutputCallRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseParameters: respParam,
|
||||
Payload: payload,
|
||||
}
|
||||
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
|
||||
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryServerInterceptor(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
@ -2077,7 +2154,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
|
||||
|
||||
func testUnaryServerInterceptor(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.unaryInt = errInjector
|
||||
te.unaryServerInt = errInjector
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
@ -2108,7 +2185,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ
|
||||
|
||||
func testStreamServerInterceptor(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.streamInt = fullDuplexOnly
|
||||
te.streamServerInt = fullDuplexOnly
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
|
Reference in New Issue
Block a user