diff --git a/call.go b/call.go index c1588c63..e92a4bc9 100644 --- a/call.go +++ b/call.go @@ -51,7 +51,7 @@ import ( // // TODO(zhaoq): Check whether the received message sequence is valid. // TODO ctx is used for stats collection and processing. It is the context passed from the application. -func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { +func recvResponse(ctx context.Context, dopts dialOptions, msgSizeLimit int, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { // Try to acquire header metadata from the server if there is any. defer func() { if err != nil { @@ -72,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } } for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, msgSizeLimit, inPayload); err != nil { if err == io.EOF { break } @@ -92,7 +92,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } // sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { +func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, msgSizeLimit int, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { stream, err := t.NewStream(ctx, callHdr) if err != nil { return nil, err @@ -121,6 +121,9 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } + if len(outBuf) > msgSizeLimit { + return nil, Errorf(codes.InvalidArgument, "Sent message larger than max (%d vs. %d)", len(outBuf), msgSizeLimit) + } err = t.Write(stream, outBuf, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() @@ -146,15 +149,49 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return invoke(ctx, method, args, reply, cc, opts...) 
} +const defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4 +const defaultClientMaxSendMessageSize = 1024 * 1024 * 4 + +func min(a, b int) int { + if a < b { + return a + } + return b +} + func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { c := defaultCallInfo - if mc, ok := cc.getMethodConfig(method); ok { - c.failFast = !mc.WaitForReady - if mc.Timeout > 0 { + maxReceiveMessageSize := defaultClientMaxReceiveMessageSize + maxSendMessageSize := defaultClientMaxSendMessageSize + if mc, ok := cc.GetMethodConfig(method); ok { + if mc.WaitForReady != nil { + c.failFast = !*mc.WaitForReady + } + + if mc.Timeout != nil && *mc.Timeout >= 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, mc.Timeout) + ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) defer cancel() } + + if mc.MaxReqSize != nil && cc.dopts.maxSendMessageSize >= 0 { + maxSendMessageSize = min(*mc.MaxReqSize, cc.dopts.maxSendMessageSize) + } else if mc.MaxReqSize != nil { + maxSendMessageSize = *mc.MaxReqSize + } + + if mc.MaxRespSize != nil && cc.dopts.maxReceiveMessageSize >= 0 { + maxReceiveMessageSize = min(*mc.MaxRespSize, cc.dopts.maxReceiveMessageSize) + } else if mc.MaxRespSize != nil { + maxReceiveMessageSize = *mc.MaxRespSize + } + } else { + if cc.dopts.maxSendMessageSize >= 0 { + maxSendMessageSize = cc.dopts.maxSendMessageSize + } + if cc.dopts.maxReceiveMessageSize >= 0 { + maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize + } } for _, o := range opts { if err := o.before(&c); err != nil { @@ -245,7 +282,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } - stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts) + stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, maxSendMessageSize, callHdr, t, args, topts) if err != nil { if put != nil { put() @@ -262,7 +299,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) + err = recvResponse(ctx, cc.dopts, maxReceiveMessageSize, t, &c, stream, reply) if err != nil { if put != nil { put() diff --git a/clientconn.go b/clientconn.go index 1ba592c5..c257284f 100644 --- a/clientconn.go +++ b/clientconn.go @@ -88,19 +88,20 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - unaryInt UnaryClientInterceptor - streamInt StreamClientInterceptor - codec Codec - cp Compressor - dc Decompressor - bs backoffStrategy - balancer Balancer - block bool - insecure bool - timeout time.Duration - scChan <-chan ServiceConfig - copts transport.ConnectOptions - maxMsgSize int + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + codec Codec + cp Compressor + dc Decompressor + bs backoffStrategy + balancer Balancer + block bool + insecure bool + timeout time.Duration + scChan <-chan ServiceConfig + copts transport.ConnectOptions + maxReceiveMessageSize int + maxSendMessageSize int } const defaultClientMaxMsgSize = math.MaxInt32 @@ -108,10 +109,24 @@ const defaultClientMaxMsgSize = math.MaxInt32 // DialOption configures how we set up the connection. type DialOption func(*dialOptions) -// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. 
+// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. This function is for backward API compatibility. It has essentially the same functionality as WithMaxReceiveMessageSize. func WithMaxMsgSize(s int) DialOption { return func(o *dialOptions) { - o.maxMsgSize = s + o.maxReceiveMessageSize = s + } +} + +// WithMaxReceiveMessageSize returns a DialOption which sets the maximum message size the client can receive. Negative input is invalid and has the same effect as not setting the field. +func WithMaxReceiveMessageSize(s int) DialOption { + return func(o *dialOptions) { + o.maxReceiveMessageSize = s + } +} + +// WithMaxSendMessageSize returns a DialOption which sets the maximum message size the client can send. Negative input is invalid and has the same effect as not setting the field. +func WithMaxSendMessageSize(s int) DialOption { + return func(o *dialOptions) { + o.maxSendMessageSize = s } } @@ -307,7 +322,11 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * conns: make(map[Address]*addrConn), } cc.ctx, cc.cancel = context.WithCancel(context.Background()) - cc.dopts.maxMsgSize = defaultClientMaxMsgSize + + // Initialize maxReceiveMessageSize and maxSendMessageSize to -1 before applying DialOption functions to distinguish whether the user set the message limit or not. + cc.dopts.maxReceiveMessageSize = -1 + cc.dopts.maxSendMessageSize = -1 + for _, opt := range opts { opt(&cc.dopts) } @@ -609,11 +628,16 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err return nil } +// GetMethodConfig gets the method config of the input method. If there's no exact match for the input method (i.e. /service/method), we will return the default config for all methods under the service (/service/). // TODO: Avoid the locking here. -func (cc *ClientConn) getMethodConfig(method string) (m MethodConfig, ok bool) { +func (cc *ClientConn) GetMethodConfig(method string) (m MethodConfig, ok bool) { cc.mu.RLock() defer cc.mu.RUnlock() m, ok = cc.sc.Methods[method] + if !ok { + i := strings.LastIndex(method, "/") + m, ok = cc.sc.Methods[method[:i+1]] + } return } diff --git a/rpc_util.go b/rpc_util.go index 73c3a966..df232af1 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -239,7 +239,7 @@ type parser struct { // No other error values or types must be returned, which also means // that the underlying io.Reader must not return an incompatible // error. -func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { +func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } @@ -250,8 +250,8 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err erro if length == 0 { return pf, nil, nil } - if length > uint32(maxMsgSize) { - return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) + if length > uint32(maxReceiveMessageSize) { + return 0, nil, Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } // TODO(bradfitz,zhaoq): garbage.
reuse buffer after proto decoding instead // of making it for each message: @@ -335,8 +335,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error { - pf, d, err := p.recvMsg(maxMsgSize) +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error { + pf, d, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return err } @@ -352,10 +352,10 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } - if len(d) > maxMsgSize { + if len(d) > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) + return Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) } if err := c.Unmarshal(d, m); err != nil { return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } @@ -489,24 +489,22 @@ type MethodConfig struct { // WaitForReady indicates whether RPCs sent to this method should wait until // the connection is ready by default (!failfast). The value specified via the // gRPC client API will override the value set here. - WaitForReady bool + WaitForReady *bool // Timeout is the default timeout for RPCs sent to this method. The actual // deadline used will be the minimum of the value specified here and the value // set by the application via the gRPC client API. If either one is not set, // then the other will be used. If neither is set, then the RPC has no deadline. - Timeout time.Duration + Timeout *time.Duration // MaxReqSize is the maximum allowed payload size for an individual request in a // stream (client->server) in bytes. The size which is measured is the serialized // payload after per-message compression (but before stream compression) in bytes. // The actual value used is the minimum of the value specified here and the value set // by the application via the gRPC client API. If either one is not set, then the other // will be used. If neither is set, then the built-in default is used. - // TODO: support this. - MaxReqSize uint32 + MaxReqSize *int // MaxRespSize is the maximum allowed payload size for an individual response in a // stream (server->client) in bytes. - // TODO: support this.
- MaxRespSize uint32 + MaxRespSize *int } // ServiceConfig is provided by the service provider and contains parameters for how diff --git a/server.go b/server.go index 157f35ee..af0b9318 100644 --- a/server.go +++ b/server.go @@ -105,21 +105,23 @@ type Server struct { } type options struct { - creds credentials.TransportCredentials - codec Codec - cp Compressor - dc Decompressor - maxMsgSize int - unaryInt UnaryServerInterceptor - streamInt StreamServerInterceptor - inTapHandle tap.ServerInHandle - statsHandler stats.Handler - maxConcurrentStreams uint32 - useHandlerImpl bool // use http.Handler-based server - unknownStreamDesc *StreamDesc + creds credentials.TransportCredentials + codec Codec + cp Compressor + dc Decompressor + maxReceiveMessageSize int + maxSendMessageSize int + unaryInt UnaryServerInterceptor + streamInt StreamServerInterceptor + inTapHandle tap.ServerInHandle + statsHandler stats.Handler + maxConcurrentStreams uint32 + useHandlerImpl bool // use http.Handler-based server + unknownStreamDesc *StreamDesc } -var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit +const defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 // Use 4MB as the default receive message size limit. +const defaultServerMaxSendMessageSize = 1024 * 1024 * 4 // Use 4MB as the default send message size limit. // A ServerOption sets options. type ServerOption func(*options) @@ -146,10 +148,26 @@ func RPCDecompressor(dc Decompressor) ServerOption { } // MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound messages. -// If this is not set, gRPC uses the default 4MB. +// If this is not set, gRPC uses the default 4MB. This function is for backward compatibility. It has essentially the same functionality as MaxReceiveMessageSize. func MaxMsgSize(m int) ServerOption { return func(o *options) { - o.maxMsgSize = m + o.maxReceiveMessageSize = m + } +} + +// MaxReceiveMessageSize returns a ServerOption to set the max message size in bytes for inbound messages. +// If this is not set, gRPC uses the default 4MB. +func MaxReceiveMessageSize(m int) ServerOption { + return func(o *options) { + o.maxReceiveMessageSize = m + } +} + +// MaxSendMessageSize returns a ServerOption to set the max message size in bytes for outbound messages. +// If this is not set, gRPC uses the default 4MB. +func MaxSendMessageSize(m int) ServerOption { + return func(o *options) { + o.maxSendMessageSize = m } } @@ -211,7 +229,7 @@ func StatsHandler(h stats.Handler) ServerOption { // UnknownServiceHandler returns a ServerOption that allows for adding a custom // unknown service handler. The provided method is a bidi-streaming RPC service -// handler that will be invoked instead of returning the the "unimplemented" gRPC +// handler that will be invoked instead of returning the "unimplemented" gRPC // error whenever a request is received for an unregistered service or method. // The handling function has full access to the Context of the request and the // stream, and the invocation passes through interceptors. @@ -231,7 +249,8 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { var opts options - opts.maxMsgSize = defaultMaxMsgSize + opts.maxReceiveMessageSize = defaultServerMaxReceiveMessageSize + opts.maxSendMessageSize = defaultServerMaxSendMessageSize for _, o := range opt { o(&opts) } @@ -609,6 +628,9 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str // the optimal option. grpclog.Fatalf("grpc: Server failed to encode response %v", err) } + if len(p) > s.opts.maxSendMessageSize { + return Errorf(codes.InvalidArgument, "Sent message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize) + } err = t.Write(stream, p, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() @@ -653,7 +675,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } p := &parser{r: stream} for { - pf, req, err := p.recvMsg(s.opts.maxMsgSize) + pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -715,11 +737,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return Errorf(codes.Internal, err.Error()) } } - if len(req) > s.opts.maxMsgSize { + if len(req) > s.opts.maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. - statusCode = codes.Internal - statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + statusCode = codes.InvalidArgument + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxReceiveMessageSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err @@ -771,7 +793,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
statusCode = codes.Unknown statusDesc = err.Error() } - return err } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) @@ -807,15 +828,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - maxMsgSize: s.opts.maxMsgSize, - trInfo: trInfo, - statsHandler: sh, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxReceiveMessageSize: s.opts.maxReceiveMessageSize, + maxSendMessageSize: s.opts.maxSendMessageSize, + trInfo: trInfo, + statsHandler: sh, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) diff --git a/stream.go b/stream.go index 0ef2077c..034bcf7e 100644 --- a/stream.go +++ b/stream.go @@ -112,12 +112,40 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cancel context.CancelFunc ) c := defaultCallInfo - if mc, ok := cc.getMethodConfig(method); ok { - c.failFast = !mc.WaitForReady - if mc.Timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, mc.Timeout) + maxReceiveMessageSize := defaultClientMaxReceiveMessageSize + maxSendMessageSize := defaultClientMaxSendMessageSize + + if mc, ok := cc.GetMethodConfig(method); ok { + if mc.WaitForReady != nil { + c.failFast = !*mc.WaitForReady + } + + if mc.Timeout != nil && *mc.Timeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) + defer cancel() + } + + if mc.MaxReqSize != nil && cc.dopts.maxSendMessageSize >= 0 { + maxSendMessageSize = min(*mc.MaxReqSize, cc.dopts.maxSendMessageSize) + } else if mc.MaxReqSize != nil { + maxSendMessageSize = *mc.MaxReqSize + } + + if mc.MaxRespSize != nil && cc.dopts.maxReceiveMessageSize >= 0 { + maxReceiveMessageSize = min(*mc.MaxRespSize, cc.dopts.maxReceiveMessageSize) + } else if mc.MaxRespSize != nil { + maxReceiveMessageSize = *mc.MaxRespSize + } + } else { + if cc.dopts.maxSendMessageSize >= 0 { + maxSendMessageSize = cc.dopts.maxSendMessageSize + } + if cc.dopts.maxReceiveMessageSize >= 0 { + maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize } } + for _, o := range opts { if err := o.before(&c); err != nil { return nil, toRPCErr(err) @@ -207,14 +235,15 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth break } cs := &clientStream{ - opts: opts, - c: c, - desc: desc, - codec: cc.dopts.codec, - cp: cc.dopts.cp, - dc: cc.dopts.dc, - maxMsgSize: cc.dopts.maxMsgSize, - cancel: cancel, + opts: opts, + c: c, + desc: desc, + codec: cc.dopts.codec, + cp: cc.dopts.cp, + dc: cc.dopts.dc, + maxReceiveMessageSize: maxReceiveMessageSize, + maxSendMessageSize: maxSendMessageSize, + cancel: cancel, put: put, t: t, @@ -259,18 +288,19 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. 
type clientStream struct { - opts []CallOption - c callInfo - t transport.ClientTransport - s *transport.Stream - p *parser - desc *StreamDesc - codec Codec - cp Compressor - cbuf *bytes.Buffer - dc Decompressor - maxMsgSize int - cancel context.CancelFunc + opts []CallOption + c callInfo + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + codec Codec + cp Compressor + cbuf *bytes.Buffer + dc Decompressor + maxReceiveMessageSize int + maxSendMessageSize int + cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. @@ -353,6 +383,9 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } + if len(out) > cs.maxSendMessageSize { + return Errorf(codes.InvalidArgument, "Sent message larger than max (%d vs. %d)", len(out), cs.maxSendMessageSize) + } err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() } @@ -383,7 +416,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { Client: true, } } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxReceiveMessageSize, inPayload) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -406,7 +439,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } // Special handling for client streaming rpc. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxReceiveMessageSize, nil) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) } @@ -512,17 +545,18 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { - t transport.ServerTransport - s *transport.Stream - p *parser - codec Codec - cp Compressor - dc Decompressor - cbuf *bytes.Buffer - maxMsgSize int - statusCode codes.Code - statusDesc string - trInfo *traceInfo + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + cp Compressor + dc Decompressor + cbuf *bytes.Buffer + maxReceiveMessageSize int + maxSendMessageSize int + statusCode codes.Code + statusDesc string + trInfo *traceInfo statsHandler stats.Handler @@ -581,6 +615,9 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { err = Errorf(codes.Internal, "grpc: %v", err) return err } + if len(out) > ss.maxSendMessageSize { + return Errorf(codes.InvalidArgument, "Sent message larger than max (%d vs.
%d)", len(out), ss.maxSendMessageSize) + } if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } @@ -610,7 +647,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil { if err == io.EOF { return err } diff --git a/test/end2end_test.go b/test/end2end_test.go index 98d590e9..910a3113 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -416,20 +416,25 @@ type test struct { cancel context.CancelFunc // Configurable knobs, after newTest returns: - testServer testpb.TestServiceServer // nil means none - healthServer *health.Server // nil means disabled - maxStream uint32 - tapHandle tap.ServerInHandle - maxMsgSize int - userAgent string - clientCompression bool - serverCompression bool - unaryClientInt grpc.UnaryClientInterceptor - streamClientInt grpc.StreamClientInterceptor - unaryServerInt grpc.UnaryServerInterceptor - streamServerInt grpc.StreamServerInterceptor - unknownHandler grpc.StreamHandler - sc <-chan grpc.ServiceConfig + testServer testpb.TestServiceServer // nil means none + healthServer *health.Server // nil means disabled + maxStream uint32 + tapHandle tap.ServerInHandle + maxMsgSize int + maxClientReceiveMsgSize int + maxClientSendMsgSize int + maxServerReceiveMsgSize int + maxServerSendMsgSize int + userAgent string + clientCompression bool + serverCompression bool + timeout time.Duration + unaryClientInt grpc.UnaryClientInterceptor + streamClientInt grpc.StreamClientInterceptor + unaryServerInt grpc.UnaryServerInterceptor + streamServerInt grpc.StreamServerInterceptor + unknownHandler grpc.StreamHandler + sc <-chan grpc.ServiceConfig // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -465,6 +470,12 @@ func newTest(t *testing.T, e env) *test { t: t, e: e, maxStream: math.MaxUint32, + // Default value 0 is meaningful (0 byte msg size limit), thus using -1 to indiciate the field is unset. 
+ maxClientReceiveMsgSize: -1, + maxClientSendMsgSize: -1, + maxServerReceiveMsgSize: -1, + maxServerSendMsgSize: -1, + maxMsgSize: -1, } te.ctx, te.cancel = context.WithCancel(context.Background()) return te } @@ -476,9 +487,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) { te.testServer = ts te.t.Logf("Running test in %s environment...", te.e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} - if te.maxMsgSize > 0 { + if te.maxMsgSize >= 0 { sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) } + if te.maxServerReceiveMsgSize >= 0 { + sopts = append(sopts, grpc.MaxReceiveMessageSize(te.maxServerReceiveMsgSize)) + } + if te.maxServerSendMsgSize >= 0 { + sopts = append(sopts, grpc.MaxSendMessageSize(te.maxServerSendMsgSize)) + } if te.tapHandle != nil { sopts = append(sopts, grpc.InTapHandle(te.tapHandle)) } @@ -570,9 +587,18 @@ func (te *test) clientConn() *grpc.ClientConn { if te.streamClientInt != nil { opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt)) } - if te.maxMsgSize > 0 { + if te.maxMsgSize >= 0 { opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize)) } + if te.maxClientReceiveMsgSize >= 0 { + opts = append(opts, grpc.WithMaxReceiveMessageSize(te.maxClientReceiveMsgSize)) + } + if te.maxClientSendMsgSize >= 0 { + opts = append(opts, grpc.WithMaxSendMessageSize(te.maxClientSendMsgSize)) + } + if te.timeout > 0 { + opts = append(opts, grpc.WithTimeout(te.timeout)) + } switch te.e.security { case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") @@ -1030,11 +1056,14 @@ func testFailFast(t *testing.T, e env) { func TestServiceConfig(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testServiceConfig(t, e) + testGetMethodConfig(t, e) + testServiceConfigWaitForReady(t, e) + // Timeout logic (min of service config and client API) is implemented implicitly in context.WithTimeout(). No need to test here. + testServiceConfigMaxMsgSize(t, e) } } -func testServiceConfig(t *testing.T, e env) { +func testServiceConfigSetup(t *testing.T, e env) (*test, chan grpc.ServiceConfig) { te := newTest(t, e) ch := make(chan grpc.ServiceConfig) te.sc = ch te.declareLogNoise( "transport: http2Client.notifyError got notified that the client transport was broken EOF", "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", "grpc: addrConn.resetTransport failed to create client transport: connection error", "Failed to dial : context canceled; please retry.", ) + return te, ch +} + +func testGetMethodConfig(t *testing.T, e env) { + te, ch := testServiceConfigSetup(t, e) defer te.tearDown() var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() + w1 := new(bool) + *w1 = true + to := new(time.Duration) + *to = time.Millisecond + mc1 := grpc.MethodConfig{ + WaitForReady: w1, + Timeout: to, + } + w2 := new(bool) + *w2 = false + mc2 := grpc.MethodConfig{ + WaitForReady: w2, + } + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/EmptyCall"] = mc1 + m["/grpc.testing.TestService/"] = mc2 + sc := grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + }() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + wg.Wait() + w1 := new(bool) + *w1 = true + to := new(time.Duration) + *to = time.Millisecond + mc1 := grpc.MethodConfig{ + WaitForReady: w1, + Timeout: to, + } + w2 := new(bool) + *w2 = false + mc2 := grpc.MethodConfig{ + WaitForReady: w2, + } + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/UnaryCall"] = mc1 + m["/grpc.testing.TestService/"] = mc2 + sc := grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + // Wait for the new service config to propagate. + for { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded { + continue + } + break + } + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + } +} + +func testServiceConfigWaitForReady(t *testing.T, e env) { + te, ch := testServiceConfigSetup(t, e) + defer te.tearDown() + + // Case 1: the client API sets failfast to false and the service config sets wait_for_ready to false; the client API should win, and the RPC will wait until the deadline is exceeded. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + w := new(bool) + *w = false + to := new(time.Duration) + *to = time.Millisecond mc := grpc.MethodConfig{ - WaitForReady: true, - Timeout: time.Millisecond, + WaitForReady: w, + Timeout: to, } m := make(map[string]grpc.MethodConfig) m["/grpc.testing.TestService/EmptyCall"] = mc @@ -1066,16 +1173,23 @@ func testServiceConfig(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) // The following RPCs are expected to become non-fail-fast ones with 1ms deadline. - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } - if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.DeadlineExceeded { + if _, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) } wg.Wait() + // Generate a service config update. + // Case 2: the client API does not set failfast, and the service config sets wait_for_ready to true; the RPC will wait until the deadline is exceeded. + w := new(bool) + *w = true + to := new(time.Duration) + *to = time.Millisecond mc := grpc.MethodConfig{ - WaitForReady: false, + WaitForReady: w, + Timeout: to, } m := make(map[string]grpc.MethodConfig) m["/grpc.testing.TestService/EmptyCall"] = mc @@ -1084,19 +1198,504 @@ Methods: m, } ch <- sc - // Loop until the new update becomes effective. + + // Wait for the new service config to take effect.
+ mc, ok := cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") for { - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + if ok && !*mc.WaitForReady { + time.Sleep(100 * time.Millisecond) + mc, ok = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") continue } break } - // The following RPCs are expected to become fail-fast. - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + // The following RPCs are expected to become non-fail-fast ones with 1ms deadline. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } - if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.Unavailable) + if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) + } +} + +func testServiceConfigMaxMsgSize(t *testing.T, e env) { + // Setting up values and objects shared across all test cases. + const smallSize = 1 + const largeSize = 1024 + const extraLargeSize = 2048 + + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + extraLargePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, extraLargeSize) + if err != nil { + t.Fatal(err) + } + + mreq := new(int) + *mreq = extraLargeSize + mresp := new(int) + *mresp = extraLargeSize + mc := grpc.MethodConfig{ + MaxReqSize: mreq, + MaxRespSize: mresp, + } + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/UnaryCall"] = mc + m["/grpc.testing.TestService/FullDuplexCall"] = mc + sc := grpc.ServiceConfig{ + Methods: m, + } + // Case 1: the service config sets maxReqSize to 2048 (send) and maxRespSize to 2048 (recv).
+ te1, ch1 := testServiceConfigSetup(t, e) + te1.startServer(&testServer{security: e.security}) + defer te1.tearDown() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ch1 <- sc + }() + tc := testpb.NewTestServiceClient(te1.clientConn()) + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(extraLargeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for unary RPC send. + req.Payload = extraLargePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for streaming RPC recv. + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(extraLargeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + stream, err := tc.FullDuplexCall(te1.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = extraLargePayload + stream, err = tc.FullDuplexCall(te1.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument) + } + + wg.Wait() + // Case 2: the client API sets maxReqSize to 1024 (send) and maxRespSize to 1024 (recv); the service config sets maxReqSize to 2048 (send) and maxRespSize to 2048 (recv). + te2, ch2 := testServiceConfigSetup(t, e) + te2.maxClientReceiveMsgSize = 1024 + te2.maxClientSendMsgSize = 1024 + te2.startServer(&testServer{security: e.security}) + defer te2.tearDown() + wg.Add(1) + go func() { + defer wg.Done() + ch2 <- sc + }() + tc = testpb.NewTestServiceClient(te2.clientConn()) + + // Test for unary RPC recv. + req.Payload = smallPayload + req.ResponseSize = proto.Int32(int32(largeSize)) + + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for streaming RPC recv.
+ stream, err = tc.FullDuplexCall(te2.ctx) + respParam[0].Size = proto.Int32(int32(largeSize)) + sreq.Payload = smallPayload + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te2.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument) + } + wg.Wait() + + // Case 3: the client API sets maxReqSize to 4096 (send) and maxRespSize to 4096 (recv); the service config sets maxReqSize to 2048 (send) and maxRespSize to 2048 (recv). + te3, ch3 := testServiceConfigSetup(t, e) + te3.maxClientReceiveMsgSize = 4096 + te3.maxClientSendMsgSize = 4096 + te3.startServer(&testServer{security: e.security}) + defer te3.tearDown() + wg.Add(1) + go func() { + defer wg.Done() + ch3 <- sc + }() + tc = testpb.NewTestServiceClient(te3.clientConn()) + + // Test for unary RPC recv. + req.Payload = smallPayload + req.ResponseSize = proto.Int32(int32(largeSize)) + + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want <nil>", err) + } + + req.ResponseSize = proto.Int32(int32(extraLargeSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want <nil>", err) + } + + req.Payload = extraLargePayload + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for streaming RPC recv. + stream, err = tc.FullDuplexCall(te3.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + respParam[0].Size = proto.Int32(int32(largeSize)) + sreq.Payload = smallPayload + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = _, %v, want <nil>", stream, err) + } + + respParam[0].Size = proto.Int32(int32(extraLargeSize)) + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC send.
+ respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te3.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + sreq.Payload = extraLargePayload + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument) + } + wg.Wait() +} + +func TestMsgSizeDefaultAndAPI(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMaxMsgSizeClientDefault(t, e) + testMaxMsgSizeClientAPI(t, e) + testMaxMsgSizeServerAPI(t, e) + } +} + +func testMaxMsgSizeClientDefault(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + // To avoid an error on the server side. + te.maxServerSendMsgSize = 5 * 1024 * 1024 + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 4 * 1024 * 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC recv. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC send.
+ respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument) + } +} + +func testMaxMsgSizeClientAPI(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + // To avoid an error on the server side. + te.maxServerSendMsgSize = 5 * 1024 * 1024 + te.maxClientReceiveMsgSize = 1024 + te.maxClientSendMsgSize = 1024 + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC recv. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC send.
+ respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument) + } +} + +func testMaxMsgSizeServerAPI(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.maxServerReceiveMsgSize = 1024 + te.maxServerSendMsgSize = 1024 + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC send. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Unknown { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Unknown) + } + + // Test for unary RPC recv. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC send. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) + } + + // Test for streaming RPC recv. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) } } @@ -1415,6 +2014,7 @@ func testLargeUnary(t *testing.T, e env) { } } +// Test backward-compatibility API for setting msg size limit.
func TestExceedMsgLimit(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { testExceedMsgLimit(t, e) } } @@ -1441,23 +2041,23 @@ func testExceedMsgLimit(t *testing.T, e env) { t.Fatal(err) } - // test on server side for unary RPC + // Test on server side for unary RPC. req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), ResponseSize: proto.Int32(smallSize), Payload: payload, } - if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) } - // test on client side for unary RPC + // Test on client side for unary RPC. req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1) req.Payload = smallPayload - if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument) } - // test on server side for streaming RPC + // Test on server side for streaming RPC. stream, err := tc.FullDuplexCall(te.ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) } @@ -1481,11 +2081,11 @@ if err := stream.Send(sreq); err != nil { t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) } - if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) } - // test on client side for streaming RPC + // Test on client side for streaming RPC. stream, err = tc.FullDuplexCall(te.ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) } @@ -1495,8 +2095,8 @@ if err := stream.Send(sreq); err != nil { t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) } - if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument) }
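
Usage sketch (reviewer commentary, not part of the patch): the example below shows how the options added in this change are meant to compose. The dial target and the WithInsecure transport choice are placeholders for illustration; the option names (WithMaxReceiveMessageSize, WithMaxSendMessageSize, MaxReceiveMessageSize, MaxSendMessageSize) are the ones introduced above. When a service config also sets MaxReqSize/MaxRespSize, invoke() and newClientStream() take the minimum of the service config value and the dial option, so the tighter limit wins; leaving an option unset (internally -1) falls back to the 4MB defaults.

package main

import (
	"log"

	"google.golang.org/grpc"
)

func main() {
	// Client side: allow responses up to 8MB and cap outgoing requests at 1MB.
	// Per the DialOption docs above, negative values behave as if the option
	// were never set, in which case the 4MB client defaults apply.
	conn, err := grpc.Dial(
		"localhost:50051", // placeholder target
		grpc.WithInsecure(),
		grpc.WithMaxReceiveMessageSize(8*1024*1024),
		grpc.WithMaxSendMessageSize(1*1024*1024),
	)
	if err != nil {
		log.Fatalf("grpc.Dial(_, _) = _, %v", err)
	}
	defer conn.Close()

	// Server side: the mirror-image ServerOptions from server.go. A message
	// over either limit fails the RPC with an error instead of being
	// silently truncated.
	srv := grpc.NewServer(
		grpc.MaxReceiveMessageSize(1*1024*1024),
		grpc.MaxSendMessageSize(8*1024*1024),
	)
	_ = srv // register services and call srv.Serve(listener) as usual
}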