client: export types implementing CallOptions for access by interceptors (#1902)
This commit is contained in:
committed by
dfawley
parent
ec9275ba9a
commit
fa28bef939
4
call.go
4
call.go
@ -27,6 +27,10 @@ import (
|
||||
//
|
||||
// All errors returned by Invoke are compatible with the status package.
|
||||
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
|
||||
// allow interceptor to see all applicable call options, which means those
|
||||
// configured as defaults from dial option as well as per-call options
|
||||
opts = append(cc.dopts.callOptions, opts...)
|
||||
|
||||
if cc.dopts.unaryInt != nil {
|
||||
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
|
||||
}
|
||||
|
||||
188
rpc_util.go
188
rpc_util.go
@ -160,46 +160,66 @@ type EmptyCallOption struct{}
|
||||
func (EmptyCallOption) before(*callInfo) error { return nil }
|
||||
func (EmptyCallOption) after(*callInfo) {}
|
||||
|
||||
type beforeCall func(c *callInfo) error
|
||||
|
||||
func (o beforeCall) before(c *callInfo) error { return o(c) }
|
||||
func (o beforeCall) after(c *callInfo) {}
|
||||
|
||||
type afterCall func(c *callInfo)
|
||||
|
||||
func (o afterCall) before(c *callInfo) error { return nil }
|
||||
func (o afterCall) after(c *callInfo) { o(c) }
|
||||
|
||||
// Header returns a CallOptions that retrieves the header metadata
|
||||
// for a unary RPC.
|
||||
func Header(md *metadata.MD) CallOption {
|
||||
return afterCall(func(c *callInfo) {
|
||||
if c.stream != nil {
|
||||
*md, _ = c.stream.Header()
|
||||
return HeaderCallOption{HeaderAddr: md}
|
||||
}
|
||||
|
||||
// HeaderCallOption is a CallOption for collecting response header metadata.
|
||||
// The metadata field will be populated *after* the RPC completes.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type HeaderCallOption struct {
|
||||
HeaderAddr *metadata.MD
|
||||
}
|
||||
|
||||
func (o HeaderCallOption) before(c *callInfo) error { return nil }
|
||||
func (o HeaderCallOption) after(c *callInfo) {
|
||||
if c.stream != nil {
|
||||
*o.HeaderAddr, _ = c.stream.Header()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Trailer returns a CallOptions that retrieves the trailer metadata
|
||||
// for a unary RPC.
|
||||
func Trailer(md *metadata.MD) CallOption {
|
||||
return afterCall(func(c *callInfo) {
|
||||
if c.stream != nil {
|
||||
*md = c.stream.Trailer()
|
||||
return TrailerCallOption{TrailerAddr: md}
|
||||
}
|
||||
|
||||
// TrailerCallOption is a CallOption for collecting response trailer metadata.
|
||||
// The metadata field will be populated *after* the RPC completes.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type TrailerCallOption struct {
|
||||
TrailerAddr *metadata.MD
|
||||
}
|
||||
|
||||
func (o TrailerCallOption) before(c *callInfo) error { return nil }
|
||||
func (o TrailerCallOption) after(c *callInfo) {
|
||||
if c.stream != nil {
|
||||
*o.TrailerAddr = c.stream.Trailer()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Peer returns a CallOption that retrieves peer information for a
|
||||
// unary RPC.
|
||||
func Peer(p *peer.Peer) CallOption {
|
||||
return afterCall(func(c *callInfo) {
|
||||
return PeerCallOption{PeerAddr: p}
|
||||
}
|
||||
|
||||
// PeerCallOption is a CallOption for collecting the identity of the remote
|
||||
// peer. The peer field will be populated *after* the RPC completes.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type PeerCallOption struct {
|
||||
PeerAddr *peer.Peer
|
||||
}
|
||||
|
||||
func (o PeerCallOption) before(c *callInfo) error { return nil }
|
||||
func (o PeerCallOption) after(c *callInfo) {
|
||||
if c.stream != nil {
|
||||
if x, ok := peer.FromContext(c.stream.Context()); ok {
|
||||
*p = *x
|
||||
*o.PeerAddr = *x
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FailFast configures the action to take when an RPC is attempted on broken
|
||||
@ -213,49 +233,98 @@ func Peer(p *peer.Peer) CallOption {
|
||||
//
|
||||
// By default, RPCs are "Fail Fast".
|
||||
func FailFast(failFast bool) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.failFast = failFast
|
||||
return nil
|
||||
})
|
||||
return FailFastCallOption{FailFast: failFast}
|
||||
}
|
||||
|
||||
// FailFastCallOption is a CallOption for indicating whether an RPC should fail
|
||||
// fast or not.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type FailFastCallOption struct {
|
||||
FailFast bool
|
||||
}
|
||||
|
||||
func (o FailFastCallOption) before(c *callInfo) error {
|
||||
c.failFast = o.FailFast
|
||||
return nil
|
||||
}
|
||||
func (o FailFastCallOption) after(c *callInfo) { return }
|
||||
|
||||
// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive.
|
||||
func MaxCallRecvMsgSize(s int) CallOption {
|
||||
return beforeCall(func(o *callInfo) error {
|
||||
o.maxReceiveMessageSize = &s
|
||||
return nil
|
||||
})
|
||||
return MaxRecvMsgSizeCallOption{MaxRecvMsgSize: s}
|
||||
}
|
||||
|
||||
// MaxRecvMsgSizeCallOption is a CallOption that indicates the maximum message
|
||||
// size the client can receive.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type MaxRecvMsgSizeCallOption struct {
|
||||
MaxRecvMsgSize int
|
||||
}
|
||||
|
||||
func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
|
||||
c.maxReceiveMessageSize = &o.MaxRecvMsgSize
|
||||
return nil
|
||||
}
|
||||
func (o MaxRecvMsgSizeCallOption) after(c *callInfo) { return }
|
||||
|
||||
// MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send.
|
||||
func MaxCallSendMsgSize(s int) CallOption {
|
||||
return beforeCall(func(o *callInfo) error {
|
||||
o.maxSendMessageSize = &s
|
||||
return nil
|
||||
})
|
||||
return MaxSendMsgSizeCallOption{MaxSendMsgSize: s}
|
||||
}
|
||||
|
||||
// MaxSendMsgSizeCallOption is a CallOption that indicates the maximum message
|
||||
// size the client can send.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type MaxSendMsgSizeCallOption struct {
|
||||
MaxSendMsgSize int
|
||||
}
|
||||
|
||||
func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
|
||||
c.maxSendMessageSize = &o.MaxSendMsgSize
|
||||
return nil
|
||||
}
|
||||
func (o MaxSendMsgSizeCallOption) after(c *callInfo) { return }
|
||||
|
||||
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
|
||||
// for a call.
|
||||
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.creds = creds
|
||||
return nil
|
||||
})
|
||||
return PerRPCCredsCallOption{Creds: creds}
|
||||
}
|
||||
|
||||
// PerRPCCredsCallOption is a CallOption that indicates the the per-RPC
|
||||
// credentials to use for the call.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type PerRPCCredsCallOption struct {
|
||||
Creds credentials.PerRPCCredentials
|
||||
}
|
||||
|
||||
func (o PerRPCCredsCallOption) before(c *callInfo) error {
|
||||
c.creds = o.Creds
|
||||
return nil
|
||||
}
|
||||
func (o PerRPCCredsCallOption) after(c *callInfo) { return }
|
||||
|
||||
// UseCompressor returns a CallOption which sets the compressor used when
|
||||
// sending the request. If WithCompressor is also set, UseCompressor has
|
||||
// higher priority.
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
func UseCompressor(name string) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.compressorType = name
|
||||
return nil
|
||||
})
|
||||
return CompressorCallOption{CompressorType: name}
|
||||
}
|
||||
|
||||
// CompressorCallOption is a CallOption that indicates the compressor to use.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type CompressorCallOption struct {
|
||||
CompressorType string
|
||||
}
|
||||
|
||||
func (o CompressorCallOption) before(c *callInfo) error {
|
||||
c.compressorType = o.CompressorType
|
||||
return nil
|
||||
}
|
||||
func (o CompressorCallOption) after(c *callInfo) { return }
|
||||
|
||||
// CallContentSubtype returns a CallOption that will set the content-subtype
|
||||
// for a call. For example, if content-subtype is "json", the Content-Type over
|
||||
// the wire will be "application/grpc+json". The content-subtype is converted
|
||||
@ -273,13 +342,22 @@ func UseCompressor(name string) CallOption {
|
||||
// response messages, with the content-subtype set to the given contentSubtype
|
||||
// here for requests.
|
||||
func CallContentSubtype(contentSubtype string) CallOption {
|
||||
contentSubtype = strings.ToLower(contentSubtype)
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.contentSubtype = contentSubtype
|
||||
return nil
|
||||
})
|
||||
return ContentSubtypeCallOption{ContentSubtype: strings.ToLower(contentSubtype)}
|
||||
}
|
||||
|
||||
// ContentSubtypeCallOption is a CallOption that indicates the content-subtype
|
||||
// used for marshaling messages.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type ContentSubtypeCallOption struct {
|
||||
ContentSubtype string
|
||||
}
|
||||
|
||||
func (o ContentSubtypeCallOption) before(c *callInfo) error {
|
||||
c.contentSubtype = o.ContentSubtype
|
||||
return nil
|
||||
}
|
||||
func (o ContentSubtypeCallOption) after(c *callInfo) { return }
|
||||
|
||||
// CallCustomCodec returns a CallOption that will set the given Codec to be
|
||||
// used for all request and response messages for a call. The result of calling
|
||||
// String() will be used as the content-subtype in a case-insensitive manner.
|
||||
@ -293,12 +371,22 @@ func CallContentSubtype(contentSubtype string) CallOption {
|
||||
// This function is provided for advanced users; prefer to use only
|
||||
// CallContentSubtype to select a registered codec instead.
|
||||
func CallCustomCodec(codec Codec) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.codec = codec
|
||||
return nil
|
||||
})
|
||||
return CustomCodecCallOption{Codec: codec}
|
||||
}
|
||||
|
||||
// CustomCodecCallOption is a CallOption that indicates the codec used for
|
||||
// marshaling messages.
|
||||
// This is an EXPERIMENTAL API.
|
||||
type CustomCodecCallOption struct {
|
||||
Codec Codec
|
||||
}
|
||||
|
||||
func (o CustomCodecCallOption) before(c *callInfo) error {
|
||||
c.codec = o.Codec
|
||||
return nil
|
||||
}
|
||||
func (o CustomCodecCallOption) after(c *callInfo) { return }
|
||||
|
||||
// The format of the payload: compressed or not?
|
||||
type payloadFormat uint8
|
||||
|
||||
|
||||
@ -102,6 +102,10 @@ type ClientStream interface {
|
||||
// NewStream creates a new Stream for the client side. This is typically
|
||||
// called by generated code.
|
||||
func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
|
||||
// allow interceptor to see all applicable call options, which means those
|
||||
// configured as defaults from dial option as well as per-call options
|
||||
opts = append(cc.dopts.callOptions, opts...)
|
||||
|
||||
if cc.dopts.streamInt != nil {
|
||||
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
|
||||
}
|
||||
@ -140,7 +144,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
||||
}
|
||||
}()
|
||||
|
||||
opts = append(cc.dopts.callOptions, opts...)
|
||||
for _, o := range opts {
|
||||
if err := o.before(c); err != nil {
|
||||
return nil, toRPCErr(err)
|
||||
|
||||
@ -627,14 +627,11 @@ func (d *nopDecompressor) Type() string {
|
||||
return "nop"
|
||||
}
|
||||
|
||||
func (te *test) clientConn() *grpc.ClientConn {
|
||||
func (te *test) clientConn(opts ...grpc.DialOption) *grpc.ClientConn {
|
||||
if te.cc != nil {
|
||||
return te.cc
|
||||
}
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithDialer(te.e.dialer),
|
||||
grpc.WithUserAgent(te.userAgent),
|
||||
}
|
||||
opts = append(opts, grpc.WithDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent))
|
||||
|
||||
if te.sc != nil {
|
||||
opts = append(opts, grpc.WithServiceConfig(te.sc))
|
||||
@ -5887,6 +5884,115 @@ func TestMethodFromServerStream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptorCanAccessCallOptions(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
e := tcpClearRREnv
|
||||
te := newTest(t, e)
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
type observedOptions struct {
|
||||
headers []*metadata.MD
|
||||
trailers []*metadata.MD
|
||||
peer []*peer.Peer
|
||||
creds []credentials.PerRPCCredentials
|
||||
failFast []bool
|
||||
maxRecvSize []int
|
||||
maxSendSize []int
|
||||
compressor []string
|
||||
subtype []string
|
||||
codec []grpc.Codec
|
||||
}
|
||||
var observedOpts observedOptions
|
||||
populateOpts := func(opts []grpc.CallOption) {
|
||||
for _, o := range opts {
|
||||
switch o := o.(type) {
|
||||
case grpc.HeaderCallOption:
|
||||
observedOpts.headers = append(observedOpts.headers, o.HeaderAddr)
|
||||
case grpc.TrailerCallOption:
|
||||
observedOpts.trailers = append(observedOpts.trailers, o.TrailerAddr)
|
||||
case grpc.PeerCallOption:
|
||||
observedOpts.peer = append(observedOpts.peer, o.PeerAddr)
|
||||
case grpc.PerRPCCredsCallOption:
|
||||
observedOpts.creds = append(observedOpts.creds, o.Creds)
|
||||
case grpc.FailFastCallOption:
|
||||
observedOpts.failFast = append(observedOpts.failFast, o.FailFast)
|
||||
case grpc.MaxRecvMsgSizeCallOption:
|
||||
observedOpts.maxRecvSize = append(observedOpts.maxRecvSize, o.MaxRecvMsgSize)
|
||||
case grpc.MaxSendMsgSizeCallOption:
|
||||
observedOpts.maxSendSize = append(observedOpts.maxSendSize, o.MaxSendMsgSize)
|
||||
case grpc.CompressorCallOption:
|
||||
observedOpts.compressor = append(observedOpts.compressor, o.CompressorType)
|
||||
case grpc.ContentSubtypeCallOption:
|
||||
observedOpts.subtype = append(observedOpts.subtype, o.ContentSubtype)
|
||||
case grpc.CustomCodecCallOption:
|
||||
observedOpts.codec = append(observedOpts.codec, o.Codec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
te.unaryClientInt = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
populateOpts(opts)
|
||||
return nil
|
||||
}
|
||||
te.streamClientInt = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
populateOpts(opts)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
defaults := []grpc.CallOption{
|
||||
grpc.FailFast(false),
|
||||
grpc.MaxCallRecvMsgSize(1010),
|
||||
}
|
||||
tc := testpb.NewTestServiceClient(te.clientConn(grpc.WithDefaultCallOptions(defaults...)))
|
||||
|
||||
var headers metadata.MD
|
||||
var trailers metadata.MD
|
||||
var pr peer.Peer
|
||||
tc.UnaryCall(context.Background(), &testpb.SimpleRequest{},
|
||||
grpc.MaxCallRecvMsgSize(100),
|
||||
grpc.MaxCallSendMsgSize(200),
|
||||
grpc.PerRPCCredentials(testPerRPCCredentials{}),
|
||||
grpc.Header(&headers),
|
||||
grpc.Trailer(&trailers),
|
||||
grpc.Peer(&pr))
|
||||
expected := observedOptions{
|
||||
failFast: []bool{false},
|
||||
maxRecvSize: []int{1010, 100},
|
||||
maxSendSize: []int{200},
|
||||
creds: []credentials.PerRPCCredentials{testPerRPCCredentials{}},
|
||||
headers: []*metadata.MD{&headers},
|
||||
trailers: []*metadata.MD{&trailers},
|
||||
peer: []*peer.Peer{&pr},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, observedOpts) {
|
||||
t.Errorf("unary call did not observe expected options: expected %#v, got %#v", expected, observedOpts)
|
||||
}
|
||||
|
||||
observedOpts = observedOptions{} // reset
|
||||
|
||||
var codec errCodec
|
||||
tc.StreamingInputCall(context.Background(),
|
||||
grpc.FailFast(true),
|
||||
grpc.MaxCallSendMsgSize(2020),
|
||||
grpc.UseCompressor("comp-type"),
|
||||
grpc.CallContentSubtype("json"),
|
||||
grpc.CallCustomCodec(&codec))
|
||||
expected = observedOptions{
|
||||
failFast: []bool{false, true},
|
||||
maxRecvSize: []int{1010},
|
||||
maxSendSize: []int{2020},
|
||||
compressor: []string{"comp-type"},
|
||||
subtype: []string{"json"},
|
||||
codec: []grpc.Codec{&codec},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, observedOpts) {
|
||||
t.Errorf("streaming call did not observe expected options: expected %#v, got %#v", expected, observedOpts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressorRegister(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
for _, e := range listTestEnv() {
|
||||
|
||||
Reference in New Issue
Block a user