client: export types implementing CallOptions for access by interceptors (#1902)

This commit is contained in:
Joshua Humphries
2018-03-16 18:57:34 -04:00
committed by dfawley
parent ec9275ba9a
commit fa28bef939
4 changed files with 262 additions and 61 deletions

View File

@ -27,6 +27,10 @@ import (
// //
// All errors returned by Invoke are compatible with the status package. // All errors returned by Invoke are compatible with the status package.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error { func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
// allow interceptor to see all applicable call options, which means those
// configured as defaults from dial option as well as per-call options
opts = append(cc.dopts.callOptions, opts...)
if cc.dopts.unaryInt != nil { if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
} }

View File

@ -160,46 +160,66 @@ type EmptyCallOption struct{}
func (EmptyCallOption) before(*callInfo) error { return nil } func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo) {} func (EmptyCallOption) after(*callInfo) {}
type beforeCall func(c *callInfo) error
func (o beforeCall) before(c *callInfo) error { return o(c) }
func (o beforeCall) after(c *callInfo) {}
type afterCall func(c *callInfo)
func (o afterCall) before(c *callInfo) error { return nil }
func (o afterCall) after(c *callInfo) { o(c) }
// Header returns a CallOptions that retrieves the header metadata // Header returns a CallOptions that retrieves the header metadata
// for a unary RPC. // for a unary RPC.
func Header(md *metadata.MD) CallOption { func Header(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) { return HeaderCallOption{HeaderAddr: md}
if c.stream != nil { }
*md, _ = c.stream.Header()
} // HeaderCallOption is a CallOption for collecting response header metadata.
}) // The metadata field will be populated *after* the RPC completes.
// This is an EXPERIMENTAL API.
type HeaderCallOption struct {
HeaderAddr *metadata.MD
}
func (o HeaderCallOption) before(c *callInfo) error { return nil }
func (o HeaderCallOption) after(c *callInfo) {
if c.stream != nil {
*o.HeaderAddr, _ = c.stream.Header()
}
} }
// Trailer returns a CallOptions that retrieves the trailer metadata // Trailer returns a CallOptions that retrieves the trailer metadata
// for a unary RPC. // for a unary RPC.
func Trailer(md *metadata.MD) CallOption { func Trailer(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) { return TrailerCallOption{TrailerAddr: md}
if c.stream != nil { }
*md = c.stream.Trailer()
} // TrailerCallOption is a CallOption for collecting response trailer metadata.
}) // The metadata field will be populated *after* the RPC completes.
// This is an EXPERIMENTAL API.
type TrailerCallOption struct {
TrailerAddr *metadata.MD
}
func (o TrailerCallOption) before(c *callInfo) error { return nil }
func (o TrailerCallOption) after(c *callInfo) {
if c.stream != nil {
*o.TrailerAddr = c.stream.Trailer()
}
} }
// Peer returns a CallOption that retrieves peer information for a // Peer returns a CallOption that retrieves peer information for a
// unary RPC. // unary RPC.
func Peer(p *peer.Peer) CallOption { func Peer(p *peer.Peer) CallOption {
return afterCall(func(c *callInfo) { return PeerCallOption{PeerAddr: p}
if c.stream != nil { }
if x, ok := peer.FromContext(c.stream.Context()); ok {
*p = *x // PeerCallOption is a CallOption for collecting the identity of the remote
} // peer. The peer field will be populated *after* the RPC completes.
// This is an EXPERIMENTAL API.
type PeerCallOption struct {
PeerAddr *peer.Peer
}
func (o PeerCallOption) before(c *callInfo) error { return nil }
func (o PeerCallOption) after(c *callInfo) {
if c.stream != nil {
if x, ok := peer.FromContext(c.stream.Context()); ok {
*o.PeerAddr = *x
} }
}) }
} }
// FailFast configures the action to take when an RPC is attempted on broken // FailFast configures the action to take when an RPC is attempted on broken
@ -213,49 +233,98 @@ func Peer(p *peer.Peer) CallOption {
// //
// By default, RPCs are "Fail Fast". // By default, RPCs are "Fail Fast".
func FailFast(failFast bool) CallOption { func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error { return FailFastCallOption{FailFast: failFast}
c.failFast = failFast
return nil
})
} }
// FailFastCallOption is a CallOption for indicating whether an RPC should fail
// fast or not.
// This is an EXPERIMENTAL API.
type FailFastCallOption struct {
FailFast bool
}
func (o FailFastCallOption) before(c *callInfo) error {
c.failFast = o.FailFast
return nil
}
func (o FailFastCallOption) after(c *callInfo) { return }
// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive. // MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive.
func MaxCallRecvMsgSize(s int) CallOption { func MaxCallRecvMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error { return MaxRecvMsgSizeCallOption{MaxRecvMsgSize: s}
o.maxReceiveMessageSize = &s
return nil
})
} }
// MaxRecvMsgSizeCallOption is a CallOption that indicates the maximum message
// size the client can receive.
// This is an EXPERIMENTAL API.
type MaxRecvMsgSizeCallOption struct {
MaxRecvMsgSize int
}
func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
c.maxReceiveMessageSize = &o.MaxRecvMsgSize
return nil
}
func (o MaxRecvMsgSizeCallOption) after(c *callInfo) { return }
// MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send. // MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send.
func MaxCallSendMsgSize(s int) CallOption { func MaxCallSendMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error { return MaxSendMsgSizeCallOption{MaxSendMsgSize: s}
o.maxSendMessageSize = &s
return nil
})
} }
// MaxSendMsgSizeCallOption is a CallOption that indicates the maximum message
// size the client can send.
// This is an EXPERIMENTAL API.
type MaxSendMsgSizeCallOption struct {
MaxSendMsgSize int
}
func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
c.maxSendMessageSize = &o.MaxSendMsgSize
return nil
}
func (o MaxSendMsgSizeCallOption) after(c *callInfo) { return }
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials // PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call. // for a call.
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption { func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
return beforeCall(func(c *callInfo) error { return PerRPCCredsCallOption{Creds: creds}
c.creds = creds
return nil
})
} }
// PerRPCCredsCallOption is a CallOption that indicates the the per-RPC
// credentials to use for the call.
// This is an EXPERIMENTAL API.
type PerRPCCredsCallOption struct {
Creds credentials.PerRPCCredentials
}
func (o PerRPCCredsCallOption) before(c *callInfo) error {
c.creds = o.Creds
return nil
}
func (o PerRPCCredsCallOption) after(c *callInfo) { return }
// UseCompressor returns a CallOption which sets the compressor used when // UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has // sending the request. If WithCompressor is also set, UseCompressor has
// higher priority. // higher priority.
// //
// This API is EXPERIMENTAL. // This API is EXPERIMENTAL.
func UseCompressor(name string) CallOption { func UseCompressor(name string) CallOption {
return beforeCall(func(c *callInfo) error { return CompressorCallOption{CompressorType: name}
c.compressorType = name
return nil
})
} }
// CompressorCallOption is a CallOption that indicates the compressor to use.
// This is an EXPERIMENTAL API.
type CompressorCallOption struct {
CompressorType string
}
func (o CompressorCallOption) before(c *callInfo) error {
c.compressorType = o.CompressorType
return nil
}
func (o CompressorCallOption) after(c *callInfo) { return }
// CallContentSubtype returns a CallOption that will set the content-subtype // CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over // for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted // the wire will be "application/grpc+json". The content-subtype is converted
@ -273,13 +342,22 @@ func UseCompressor(name string) CallOption {
// response messages, with the content-subtype set to the given contentSubtype // response messages, with the content-subtype set to the given contentSubtype
// here for requests. // here for requests.
func CallContentSubtype(contentSubtype string) CallOption { func CallContentSubtype(contentSubtype string) CallOption {
contentSubtype = strings.ToLower(contentSubtype) return ContentSubtypeCallOption{ContentSubtype: strings.ToLower(contentSubtype)}
return beforeCall(func(c *callInfo) error {
c.contentSubtype = contentSubtype
return nil
})
} }
// ContentSubtypeCallOption is a CallOption that indicates the content-subtype
// used for marshaling messages.
// This is an EXPERIMENTAL API.
type ContentSubtypeCallOption struct {
ContentSubtype string
}
func (o ContentSubtypeCallOption) before(c *callInfo) error {
c.contentSubtype = o.ContentSubtype
return nil
}
func (o ContentSubtypeCallOption) after(c *callInfo) { return }
// CallCustomCodec returns a CallOption that will set the given Codec to be // CallCustomCodec returns a CallOption that will set the given Codec to be
// used for all request and response messages for a call. The result of calling // used for all request and response messages for a call. The result of calling
// String() will be used as the content-subtype in a case-insensitive manner. // String() will be used as the content-subtype in a case-insensitive manner.
@ -293,12 +371,22 @@ func CallContentSubtype(contentSubtype string) CallOption {
// This function is provided for advanced users; prefer to use only // This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead. // CallContentSubtype to select a registered codec instead.
func CallCustomCodec(codec Codec) CallOption { func CallCustomCodec(codec Codec) CallOption {
return beforeCall(func(c *callInfo) error { return CustomCodecCallOption{Codec: codec}
c.codec = codec
return nil
})
} }
// CustomCodecCallOption is a CallOption that indicates the codec used for
// marshaling messages.
// This is an EXPERIMENTAL API.
type CustomCodecCallOption struct {
Codec Codec
}
func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
return nil
}
func (o CustomCodecCallOption) after(c *callInfo) { return }
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8

View File

@ -102,6 +102,10 @@ type ClientStream interface {
// NewStream creates a new Stream for the client side. This is typically // NewStream creates a new Stream for the client side. This is typically
// called by generated code. // called by generated code.
func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) { func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
// allow interceptor to see all applicable call options, which means those
// configured as defaults from dial option as well as per-call options
opts = append(cc.dopts.callOptions, opts...)
if cc.dopts.streamInt != nil { if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...) return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
} }
@ -140,7 +144,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
}() }()
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts { for _, o := range opts {
if err := o.before(c); err != nil { if err := o.before(c); err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)

View File

@ -627,14 +627,11 @@ func (d *nopDecompressor) Type() string {
return "nop" return "nop"
} }
func (te *test) clientConn() *grpc.ClientConn { func (te *test) clientConn(opts ...grpc.DialOption) *grpc.ClientConn {
if te.cc != nil { if te.cc != nil {
return te.cc return te.cc
} }
opts := []grpc.DialOption{ opts = append(opts, grpc.WithDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent))
grpc.WithDialer(te.e.dialer),
grpc.WithUserAgent(te.userAgent),
}
if te.sc != nil { if te.sc != nil {
opts = append(opts, grpc.WithServiceConfig(te.sc)) opts = append(opts, grpc.WithServiceConfig(te.sc))
@ -5887,6 +5884,115 @@ func TestMethodFromServerStream(t *testing.T) {
} }
} }
func TestInterceptorCanAccessCallOptions(t *testing.T) {
defer leakcheck.Check(t)
e := tcpClearRREnv
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
type observedOptions struct {
headers []*metadata.MD
trailers []*metadata.MD
peer []*peer.Peer
creds []credentials.PerRPCCredentials
failFast []bool
maxRecvSize []int
maxSendSize []int
compressor []string
subtype []string
codec []grpc.Codec
}
var observedOpts observedOptions
populateOpts := func(opts []grpc.CallOption) {
for _, o := range opts {
switch o := o.(type) {
case grpc.HeaderCallOption:
observedOpts.headers = append(observedOpts.headers, o.HeaderAddr)
case grpc.TrailerCallOption:
observedOpts.trailers = append(observedOpts.trailers, o.TrailerAddr)
case grpc.PeerCallOption:
observedOpts.peer = append(observedOpts.peer, o.PeerAddr)
case grpc.PerRPCCredsCallOption:
observedOpts.creds = append(observedOpts.creds, o.Creds)
case grpc.FailFastCallOption:
observedOpts.failFast = append(observedOpts.failFast, o.FailFast)
case grpc.MaxRecvMsgSizeCallOption:
observedOpts.maxRecvSize = append(observedOpts.maxRecvSize, o.MaxRecvMsgSize)
case grpc.MaxSendMsgSizeCallOption:
observedOpts.maxSendSize = append(observedOpts.maxSendSize, o.MaxSendMsgSize)
case grpc.CompressorCallOption:
observedOpts.compressor = append(observedOpts.compressor, o.CompressorType)
case grpc.ContentSubtypeCallOption:
observedOpts.subtype = append(observedOpts.subtype, o.ContentSubtype)
case grpc.CustomCodecCallOption:
observedOpts.codec = append(observedOpts.codec, o.Codec)
}
}
}
te.unaryClientInt = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
populateOpts(opts)
return nil
}
te.streamClientInt = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
populateOpts(opts)
return nil, nil
}
defaults := []grpc.CallOption{
grpc.FailFast(false),
grpc.MaxCallRecvMsgSize(1010),
}
tc := testpb.NewTestServiceClient(te.clientConn(grpc.WithDefaultCallOptions(defaults...)))
var headers metadata.MD
var trailers metadata.MD
var pr peer.Peer
tc.UnaryCall(context.Background(), &testpb.SimpleRequest{},
grpc.MaxCallRecvMsgSize(100),
grpc.MaxCallSendMsgSize(200),
grpc.PerRPCCredentials(testPerRPCCredentials{}),
grpc.Header(&headers),
grpc.Trailer(&trailers),
grpc.Peer(&pr))
expected := observedOptions{
failFast: []bool{false},
maxRecvSize: []int{1010, 100},
maxSendSize: []int{200},
creds: []credentials.PerRPCCredentials{testPerRPCCredentials{}},
headers: []*metadata.MD{&headers},
trailers: []*metadata.MD{&trailers},
peer: []*peer.Peer{&pr},
}
if !reflect.DeepEqual(expected, observedOpts) {
t.Errorf("unary call did not observe expected options: expected %#v, got %#v", expected, observedOpts)
}
observedOpts = observedOptions{} // reset
var codec errCodec
tc.StreamingInputCall(context.Background(),
grpc.FailFast(true),
grpc.MaxCallSendMsgSize(2020),
grpc.UseCompressor("comp-type"),
grpc.CallContentSubtype("json"),
grpc.CallCustomCodec(&codec))
expected = observedOptions{
failFast: []bool{false, true},
maxRecvSize: []int{1010},
maxSendSize: []int{2020},
compressor: []string{"comp-type"},
subtype: []string{"json"},
codec: []grpc.Codec{&codec},
}
if !reflect.DeepEqual(expected, observedOpts) {
t.Errorf("streaming call did not observe expected options: expected %#v, got %#v", expected, observedOpts)
}
}
func TestCompressorRegister(t *testing.T) { func TestCompressorRegister(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
for _, e := range listTestEnv() { for _, e := range listTestEnv() {