diff --git a/rpc_util.go b/rpc_util.go index 78c5f2ea..f7d0ea50 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -38,6 +38,7 @@ import ( "encoding/binary" "fmt" "io" + "math" "math/rand" "os" "time" @@ -139,42 +140,36 @@ type msgFixedHeader struct { // EOF is returned with nil msg and 0 pf if the entire stream is done. Other // non-nil error is returned if something is wrong on reading. func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { - const ( - headerSize = 5 - formatIndex = 1 - ) - - var hdr msgFixedHeader - var buf [headerSize]byte + var buf [5]byte // see msgFixedHeader if _, err := io.ReadFull(p.s, buf[:]); err != nil { return 0, nil, err } - hdr.T = payloadFormat(buf[formatIndex]) - hdr.Length = binary.BigEndian.Uint32(buf[formatIndex:]) + pf = payloadFormat(buf[0]) + length := binary.BigEndian.Uint32(buf[1:]) - if hdr.Length == 0 { - return hdr.T, nil, nil + if length == 0 { + return pf, nil, nil } - msg = make([]byte, int(hdr.Length)) + msg = make([]byte, int(length)) if _, err := io.ReadFull(p.s, msg); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } return 0, nil, err } - return hdr.T, msg, nil + return pf, msg, nil } // encode serializes msg and prepends the message header. If msg is nil, it // generates the message header of 0 message length. func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { var buf bytes.Buffer - // Write message fixed header. + // Write message into the fixed header. buf.WriteByte(uint8(pf)) var b []byte - var length uint32 + var length int if msg != nil { var err error // TODO(zhaoq): optimize to reduce memory alloc and copying. @@ -182,10 +177,13 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { if err != nil { return nil, err } - length = uint32(len(b)) + length = len(b) + } + if length > math.MaxUint32 { + return nil, Errorf(codes.InvalidArgument, "grpc: message too large (%d bytes)", length) } var szHdr [4]byte - binary.BigEndian.PutUint32(szHdr[:], length) + binary.BigEndian.PutUint32(szHdr[:], uint32(length)) buf.Write(szHdr[:]) buf.Write(b) return buf.Bytes(), nil diff --git a/rpc_util_test.go b/rpc_util_test.go index 1a296841..2673cd05 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -47,6 +47,7 @@ import ( ) func TestSimpleParsing(t *testing.T) { + bigMsg := bytes.Repeat([]byte{'x'}, 1<<24) for _, test := range []struct { // input p []byte @@ -60,6 +61,8 @@ func TestSimpleParsing(t *testing.T) { {[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone}, {[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone}, {[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone}, + // Check that messages with length >= 2^24 are parsed. + {append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone}, } { buf := bytes.NewReader(test.p) parser := &parser{buf}