Support client side interceptor

This commit is contained in:
iamqizhao
2016-08-26 13:50:38 -07:00
parent d8f4ebe77f
commit 1e47e17230
5 changed files with 142 additions and 19 deletions

View File

@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke sends the RPC request on the wire and returns after response is received.
// Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
c := defaultCallInfo
for _, o := range opts {
if err := o.before(&c); err != nil {

View File

@ -83,15 +83,17 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
copts transport.ConnectOptions
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
copts transport.ConnectOptions
}
// DialOption configures how we set up the connection.
@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
}
}
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
o.unaryInt = f
}
}
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return func(o *dialOptions) {
o.streamInt = f
}
}
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)

View File

@ -37,6 +37,22 @@ import (
"golang.org/x/net/context"
)
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
// Streamer is called by StreamClientInterceptor to create a ClientStream.
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
// UnaryServerInfo consists of various information about a unary RPC on
// server side. All per-rpc information may be mutated by the interceptor.
type UnaryServerInfo struct {

View File

@ -97,7 +97,14 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
return newClientStream(ctx, desc, cc, method, opts...)
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream

View File

@ -369,8 +369,10 @@ type test struct {
userAgent string
clientCompression bool
serverCompression bool
unaryInt grpc.UnaryServerInterceptor
streamInt grpc.StreamServerInterceptor
unaryClientInt grpc.UnaryClientInterceptor
streamClientInt grpc.StreamClientInterceptor
unaryServerInt grpc.UnaryServerInterceptor
streamServerInt grpc.StreamServerInterceptor
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
@ -425,11 +427,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
if te.unaryServerInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
}
if te.streamInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
if te.streamServerInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
}
la := "localhost:0"
switch e.network {
@ -494,6 +496,12 @@ func (te *test) clientConn() *grpc.ClientConn {
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryClientInt != nil {
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
}
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
@ -2064,6 +2072,75 @@ func testCompressOK(t *testing.T, e env) {
}
}
func TestUnaryClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testUnaryClientInterceptor(t, e)
}
}
func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if err == nil {
return grpc.Errorf(codes.NotFound, "")
}
return err
}
func testUnaryClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.unaryClientInt = failOkayRPC
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}
func TestStreamClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testStreamClientInterceptor(t, e)
}
}
func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
s, err := streamer(ctx, desc, cc, method, opts...)
if err == nil {
return nil, grpc.Errorf(codes.NotFound, "")
}
return s, nil
}
func testStreamClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamClientInt = failOkayStream
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(int32(1)),
},
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
if err != nil {
t.Fatal(err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}
func TestUnaryServerInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
@ -2077,7 +2154,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
func testUnaryServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.unaryInt = errInjector
te.unaryServerInt = errInjector
te.startServer(&testServer{security: e.security})
defer te.tearDown()
@ -2108,7 +2185,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ
func testStreamServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamInt = fullDuplexOnly
te.streamServerInt = fullDuplexOnly
te.startServer(&testServer{security: e.security})
defer te.tearDown()