interceptor: new APIs for chaining client interceptors. (#2696)
This commit is contained in:
204
call_test.go
204
call_test.go
@ -123,6 +123,8 @@ type server struct {
|
||||
conns map[transport.ServerTransport]bool
|
||||
}
|
||||
|
||||
type ctxKey string
|
||||
|
||||
func newTestServer() *server {
|
||||
return &server{startedErr: make(chan error, 1)}
|
||||
}
|
||||
@ -202,17 +204,217 @@ func (s *server) stop() {
|
||||
}
|
||||
|
||||
func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
|
||||
return setUpWithOptions(t, port, maxStreams)
|
||||
}
|
||||
|
||||
func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
|
||||
server := newTestServer()
|
||||
go server.start(t, port, maxStreams)
|
||||
server.wait(t, 2*time.Second)
|
||||
addr := "localhost:" + server.port
|
||||
cc, err := Dial(addr, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
|
||||
dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
|
||||
cc, err := Dial(addr, dopts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ClientConn: %v", err)
|
||||
}
|
||||
return server, cc
|
||||
}
|
||||
|
||||
func (s) TestUnaryClientInterceptor(t *testing.T) {
|
||||
parentKey := ctxKey("parentKey")
|
||||
|
||||
interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("interceptor should have %v in context", parentKey)
|
||||
}
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
|
||||
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
|
||||
defer func() {
|
||||
cc.Close()
|
||||
server.stop()
|
||||
}()
|
||||
|
||||
var reply string
|
||||
ctx := context.Background()
|
||||
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
|
||||
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
|
||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestChainUnaryClientInterceptor(t *testing.T) {
|
||||
var (
|
||||
parentKey = ctxKey("parentKey")
|
||||
firstIntKey = ctxKey("firstIntKey")
|
||||
secondIntKey = ctxKey("secondIntKey")
|
||||
)
|
||||
|
||||
firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("first interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) != nil {
|
||||
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) != nil {
|
||||
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
|
||||
}
|
||||
firstCtx := context.WithValue(ctx, firstIntKey, 1)
|
||||
err := invoker(firstCtx, method, req, reply, cc, opts...)
|
||||
*(reply.(*string)) += "1"
|
||||
return err
|
||||
}
|
||||
|
||||
secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("second interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) == nil {
|
||||
t.Fatalf("second interceptor should have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) != nil {
|
||||
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
|
||||
}
|
||||
secondCtx := context.WithValue(ctx, secondIntKey, 2)
|
||||
err := invoker(secondCtx, method, req, reply, cc, opts...)
|
||||
*(reply.(*string)) += "2"
|
||||
return err
|
||||
}
|
||||
|
||||
lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", secondIntKey)
|
||||
}
|
||||
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||
*(reply.(*string)) += "3"
|
||||
return err
|
||||
}
|
||||
|
||||
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
|
||||
defer func() {
|
||||
cc.Close()
|
||||
server.stop()
|
||||
}()
|
||||
|
||||
var reply string
|
||||
ctx := context.Background()
|
||||
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
|
||||
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
|
||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
|
||||
var (
|
||||
parentKey = ctxKey("parentKey")
|
||||
baseIntKey = ctxKey("baseIntKey")
|
||||
)
|
||||
|
||||
baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("base interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(baseIntKey) != nil {
|
||||
t.Fatalf("base interceptor should not have %v in context", baseIntKey)
|
||||
}
|
||||
baseCtx := context.WithValue(ctx, baseIntKey, 1)
|
||||
return invoker(baseCtx, method, req, reply, cc, opts...)
|
||||
}
|
||||
|
||||
chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("chain interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(baseIntKey) == nil {
|
||||
t.Fatalf("chain interceptor should have %v in context", baseIntKey)
|
||||
}
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
|
||||
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
|
||||
defer func() {
|
||||
cc.Close()
|
||||
server.stop()
|
||||
}()
|
||||
|
||||
var reply string
|
||||
ctx := context.Background()
|
||||
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
|
||||
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
|
||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestChainStreamClientInterceptor(t *testing.T) {
|
||||
var (
|
||||
parentKey = ctxKey("parentKey")
|
||||
firstIntKey = ctxKey("firstIntKey")
|
||||
secondIntKey = ctxKey("secondIntKey")
|
||||
)
|
||||
|
||||
firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("first interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) != nil {
|
||||
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) != nil {
|
||||
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
|
||||
}
|
||||
firstCtx := context.WithValue(ctx, firstIntKey, 1)
|
||||
return streamer(firstCtx, desc, cc, method, opts...)
|
||||
}
|
||||
|
||||
secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("second interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) == nil {
|
||||
t.Fatalf("second interceptor should have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) != nil {
|
||||
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
|
||||
}
|
||||
secondCtx := context.WithValue(ctx, secondIntKey, 2)
|
||||
return streamer(secondCtx, desc, cc, method, opts...)
|
||||
}
|
||||
|
||||
lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
|
||||
if ctx.Value(parentKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", parentKey)
|
||||
}
|
||||
if ctx.Value(firstIntKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", firstIntKey)
|
||||
}
|
||||
if ctx.Value(secondIntKey) == nil {
|
||||
t.Fatalf("last interceptor should have %v in context", secondIntKey)
|
||||
}
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
|
||||
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
|
||||
defer func() {
|
||||
cc.Close()
|
||||
server.stop()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
|
||||
_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestInvoke(t *testing.T) {
|
||||
server, cc := setUp(t, 0, math.MaxUint32)
|
||||
var reply string
|
||||
|
@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||
opt.apply(&cc.dopts)
|
||||
}
|
||||
|
||||
chainUnaryClientInterceptors(cc)
|
||||
chainStreamClientInterceptors(cc)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
@ -327,6 +330,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
// chainUnaryClientInterceptors chains all unary client interceptors into one.
|
||||
func chainUnaryClientInterceptors(cc *ClientConn) {
|
||||
interceptors := cc.dopts.chainUnaryInts
|
||||
// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
|
||||
// be executed before any other chained interceptors.
|
||||
if cc.dopts.unaryInt != nil {
|
||||
interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
|
||||
}
|
||||
var chainedInt UnaryClientInterceptor
|
||||
if len(interceptors) == 0 {
|
||||
chainedInt = nil
|
||||
} else if len(interceptors) == 1 {
|
||||
chainedInt = interceptors[0]
|
||||
} else {
|
||||
chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
|
||||
return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
|
||||
}
|
||||
}
|
||||
cc.dopts.unaryInt = chainedInt
|
||||
}
|
||||
|
||||
// getChainUnaryInvoker recursively generate the chained unary invoker.
|
||||
func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
|
||||
if curr == len(interceptors)-1 {
|
||||
return finalInvoker
|
||||
}
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
|
||||
return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// chainStreamClientInterceptors chains all stream client interceptors into one.
|
||||
func chainStreamClientInterceptors(cc *ClientConn) {
|
||||
interceptors := cc.dopts.chainStreamInts
|
||||
// Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will
|
||||
// be executed before any other chained interceptors.
|
||||
if cc.dopts.streamInt != nil {
|
||||
interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...)
|
||||
}
|
||||
var chainedInt StreamClientInterceptor
|
||||
if len(interceptors) == 0 {
|
||||
chainedInt = nil
|
||||
} else if len(interceptors) == 1 {
|
||||
chainedInt = interceptors[0]
|
||||
} else {
|
||||
chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
|
||||
return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...)
|
||||
}
|
||||
}
|
||||
cc.dopts.streamInt = chainedInt
|
||||
}
|
||||
|
||||
// getChainStreamer recursively generate the chained client stream constructor.
|
||||
func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer {
|
||||
if curr == len(interceptors)-1 {
|
||||
return finalStreamer
|
||||
}
|
||||
return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
|
||||
return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// connectivityStateManager keeps the connectivity.State of ClientConn.
|
||||
// This struct will eventually be exported so the balancers can access it.
|
||||
type connectivityStateManager struct {
|
||||
|
@ -39,8 +39,12 @@ import (
|
||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||
// values passed to Dial.
|
||||
type dialOptions struct {
|
||||
unaryInt UnaryClientInterceptor
|
||||
streamInt StreamClientInterceptor
|
||||
unaryInt UnaryClientInterceptor
|
||||
streamInt StreamClientInterceptor
|
||||
|
||||
chainUnaryInts []UnaryClientInterceptor
|
||||
chainStreamInts []StreamClientInterceptor
|
||||
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoff.Strategy
|
||||
@ -414,6 +418,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
|
||||
})
|
||||
}
|
||||
|
||||
// WithChainUnaryInterceptor returns a DialOption that specifies the chained
|
||||
// interceptor for unary RPCs. The first interceptor will be the outer most,
|
||||
// while the last interceptor will be the inner most wrapper around the real call.
|
||||
// All interceptors added by this method will be chained, and the interceptor
|
||||
// defined by WithUnaryInterceptor will always be prepended to the chain.
|
||||
func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption {
|
||||
return newFuncDialOption(func(o *dialOptions) {
|
||||
o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
|
||||
})
|
||||
}
|
||||
|
||||
// WithStreamInterceptor returns a DialOption that specifies the interceptor for
|
||||
// streaming RPCs.
|
||||
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
|
||||
@ -422,6 +437,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
|
||||
})
|
||||
}
|
||||
|
||||
// WithChainStreamInterceptor returns a DialOption that specifies the chained
|
||||
// interceptor for unary RPCs. The first interceptor will be the outer most,
|
||||
// while the last interceptor will be the inner most wrapper around the real call.
|
||||
// All interceptors added by this method will be chained, and the interceptor
|
||||
// defined by WithStreamInterceptor will always be prepended to the chain.
|
||||
func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption {
|
||||
return newFuncDialOption(func(o *dialOptions) {
|
||||
o.chainStreamInts = append(o.chainStreamInts, interceptors...)
|
||||
})
|
||||
}
|
||||
|
||||
// WithAuthority returns a DialOption that specifies the value to be used as the
|
||||
// :authority pseudo-header. This value only works with WithInsecure and has no
|
||||
// effect if TransportCredentials are present.
|
||||
|
Reference in New Issue
Block a user