Support client side interceptor

This commit is contained in:
iamqizhao
2016-08-26 13:50:38 -07:00
parent d8f4ebe77f
commit 1e47e17230
5 changed files with 142 additions and 19 deletions

View File

@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke sends the RPC request on the wire and returns after response is received.
// Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
c := defaultCallInfo
for _, o := range opts {
if err := o.before(&c); err != nil {

View File

@ -83,6 +83,8 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
}
}
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
o.unaryInt = f
}
}
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return func(o *dialOptions) {
o.streamInt = f
}
}
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)

View File

@ -37,6 +37,22 @@ import (
"golang.org/x/net/context"
)
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
// Streamer is called by StreamClientInterceptor to create a ClientStream.
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
// UnaryServerInfo consists of various information about a unary RPC on
// server side. All per-rpc information may be mutated by the interceptor.
type UnaryServerInfo struct {

View File

@ -97,7 +97,14 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
return newClientStream(ctx, desc, cc, method, opts...)
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream

View File

@ -369,8 +369,10 @@ type test struct {
userAgent string
clientCompression bool
serverCompression bool
unaryInt grpc.UnaryServerInterceptor
streamInt grpc.StreamServerInterceptor
unaryClientInt grpc.UnaryClientInterceptor
streamClientInt grpc.StreamClientInterceptor
unaryServerInt grpc.UnaryServerInterceptor
streamServerInt grpc.StreamServerInterceptor
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
@ -425,11 +427,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
if te.unaryServerInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
}
if te.streamInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
if te.streamServerInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
}
la := "localhost:0"
switch e.network {
@ -494,6 +496,12 @@ func (te *test) clientConn() *grpc.ClientConn {
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryClientInt != nil {
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
}
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
@ -2064,6 +2072,75 @@ func testCompressOK(t *testing.T, e env) {
}
}
func TestUnaryClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testUnaryClientInterceptor(t, e)
}
}
func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if err == nil {
return grpc.Errorf(codes.NotFound, "")
}
return err
}
func testUnaryClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.unaryClientInt = failOkayRPC
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}
func TestStreamClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testStreamClientInterceptor(t, e)
}
}
func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
s, err := streamer(ctx, desc, cc, method, opts...)
if err == nil {
return nil, grpc.Errorf(codes.NotFound, "")
}
return s, nil
}
func testStreamClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamClientInt = failOkayStream
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(int32(1)),
},
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
if err != nil {
t.Fatal(err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}
func TestUnaryServerInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
@ -2077,7 +2154,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
func testUnaryServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.unaryInt = errInjector
te.unaryServerInt = errInjector
te.startServer(&testServer{security: e.security})
defer te.tearDown()
@ -2108,7 +2185,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ
func testStreamServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamInt = fullDuplexOnly
te.streamServerInt = fullDuplexOnly
te.startServer(&testServer{security: e.security})
defer te.tearDown()