diff --git a/call.go b/call.go index e4e7771f..a8b6dcfd 100644 --- a/call.go +++ b/call.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "time" "golang.org/x/net/context" @@ -64,7 +65,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s } p := &parser{r: stream} for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil { if err == io.EOF { break } diff --git a/call_test.go b/call_test.go index 49349858..97eb9c00 100644 --- a/call_test.go +++ b/call_test.go @@ -81,7 +81,7 @@ type testStreamHandler struct { func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { p := &parser{r: s} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(math.MaxInt32) if err == io.EOF { break } diff --git a/rpc_util.go b/rpc_util.go index d6287175..35ac9cc7 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -227,7 +227,7 @@ type parser struct { // No other error values or types must be returned, which also means // that the underlying io.Reader must not return an incompatible // error. -func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { +func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } @@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { if length == 0 { return pf, nil, nil } + if length > uint32(maxMsgSize) { + return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) + } // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // of making it for each message: msg = make([]byte, int(length)) @@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { - pf, d, err := p.recvMsg() +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error { + pf, d, err := p.recvMsg(maxMsgSize) if err != nil { return err } @@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if pf == compressionMade { d, err = dc.Do(bytes.NewReader(d)) if err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } + if len(d) > maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with java + // implementation. + return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) + } if err := c.Unmarshal(d, m); err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } return nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 830546c6..8a813c62 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "reflect" "testing" @@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) { } { buf := bytes.NewReader(test.p) parser := &parser{r: buf} - pt, b, err := parser.recvMsg() + pt, b, err := parser.recvMsg(math.MaxInt32) if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { - t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) + t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) } } } @@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) { {compressionNone, []byte("d")}, } for i, want := range wantRecvs { - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { - t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, ", + t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, ", i, p, pt, data, err, want.pt, want.data) } } - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != io.EOF { - t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v", + t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v", len(wantRecvs), p, pt, data, err, io.EOF) } } diff --git a/server.go b/server.go index 9e239a46..1ed8aac9 100644 --- a/server.go +++ b/server.go @@ -105,12 +105,15 @@ type options struct { codec Codec cp Compressor dc Decompressor + maxMsgSize int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } +var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit + // A ServerOption sets options. type ServerOption func(*options) @@ -121,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption { } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound message. +// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc } } +// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages. +// If this is not set, gRPC uses the default 4MB. +func MaxMsgSize(m int) ServerOption { + return func(o *options) { + o.maxMsgSize = m + } +} + // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -177,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { var opts options + opts.maxMsgSize = defaultMaxMsgSize for _, o := range opt { o(&opts) } @@ -526,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } p := &parser{r: stream} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -536,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err != nil { switch err := err.(type) { + case *rpcError: + if err := t.WriteStatus(stream, err.code, err.desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: @@ -575,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } } + if len(req) > s.opts.maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + statusCode = codes.Internal + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } @@ -634,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - trInfo: trInfo, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxMsgSize: s.opts.maxMsgSize, + trInfo: trInfo, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) diff --git a/stream.go b/stream.go index 008ad1e2..66bfad81 100644 --- a/stream.go +++ b/stream.go @@ -37,6 +37,7 @@ import ( "bytes" "errors" "io" + "math" "sync" "time" @@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return } // Special handling for client streaming rpc. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -411,6 +412,7 @@ type serverStream struct { cp Compressor dc Decompressor cbuf *bytes.Buffer + maxMsgSize int statusCode codes.Code statusDesc string trInfo *traceInfo @@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, ss.s, ss.dc, m) + return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize) } diff --git a/test/end2end_test.go b/test/end2end_test.go index bc03000c..3b38e00a 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -363,6 +363,7 @@ type test struct { testServer testpb.TestServiceServer // nil means none healthServer *health.Server // nil means disabled maxStream uint32 + maxMsgSize int userAgent string clientCompression bool serverCompression bool @@ -413,6 +414,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) { e := te.e te.t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} + if te.maxMsgSize > 0 { + sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) + } if te.serverCompression { sopts = append(sopts, grpc.RPCCompressor(grpc.NewGZIPCompressor()), @@ -1068,6 +1072,65 @@ func testLargeUnary(t *testing.T, e env) { } } +func TestExceedMsgLimit(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testExceedMsgLimit(t, e) + } +} + +func testExceedMsgLimit(t *testing.T, e env) { + te := newTest(t, e) + te.maxMsgSize = 1024 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + argSize := int32(te.maxMsgSize + 1) + const respSize = 1 + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %d", err, codes.Internal) + } + + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + + spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1)) + if err != nil { + t.Fatal(err) + } + + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: spayload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %d", stream, err, codes.Internal) + } +} + func TestMetadataUnaryRPC(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() {