Merge pull request #780 from iamqizhao/master
Limit the max request size to be accepted
This commit is contained in:
3
call.go
3
call.go
@ -36,6 +36,7 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
@ -64,7 +65,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
|
|||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
for {
|
for {
|
||||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
|
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ type testStreamHandler struct {
|
|||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
||||||
p := &parser{r: s}
|
p := &parser{r: s}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(math.MaxInt32)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
18
rpc_util.go
18
rpc_util.go
@ -227,7 +227,7 @@ type parser struct {
|
|||||||
// No other error values or types must be returned, which also means
|
// No other error values or types must be returned, which also means
|
||||||
// that the underlying io.Reader must not return an incompatible
|
// that the underlying io.Reader must not return an incompatible
|
||||||
// error.
|
// error.
|
||||||
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
|
||||||
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
|||||||
if length == 0 {
|
if length == 0 {
|
||||||
return pf, nil, nil
|
return pf, nil, nil
|
||||||
}
|
}
|
||||||
|
if length > uint32(maxMsgSize) {
|
||||||
|
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
|
||||||
|
}
|
||||||
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
||||||
// of making it for each message:
|
// of making it for each message:
|
||||||
msg = make([]byte, int(length))
|
msg = make([]byte, int(length))
|
||||||
@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
|
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
|
||||||
pf, d, err := p.recvMsg()
|
pf, d, err := p.recvMsg(maxMsgSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
|
|||||||
if pf == compressionMade {
|
if pf == compressionMade {
|
||||||
d, err = dc.Do(bytes.NewReader(d))
|
d, err = dc.Do(bytes.NewReader(d))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(d) > maxMsgSize {
|
||||||
|
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||||
|
// implementation.
|
||||||
|
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
|
||||||
|
}
|
||||||
if err := c.Unmarshal(d, m); err != nil {
|
if err := c.Unmarshal(d, m); err != nil {
|
||||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,7 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) {
|
|||||||
} {
|
} {
|
||||||
buf := bytes.NewReader(test.p)
|
buf := bytes.NewReader(test.p)
|
||||||
parser := &parser{r: buf}
|
parser := &parser{r: buf}
|
||||||
pt, b, err := parser.recvMsg()
|
pt, b, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
||||||
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) {
|
|||||||
{compressionNone, []byte("d")},
|
{compressionNone, []byte("d")},
|
||||||
}
|
}
|
||||||
for i, want := range wantRecvs {
|
for i, want := range wantRecvs {
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
||||||
t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, <nil>",
|
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
||||||
i, p, pt, data, err, want.pt, want.data)
|
i, p, pt, data, err, want.pt, want.data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v",
|
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
|
||||||
len(wantRecvs), p, pt, data, err, io.EOF)
|
len(wantRecvs), p, pt, data, err, io.EOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
29
server.go
29
server.go
@ -105,12 +105,15 @@ type options struct {
|
|||||||
codec Codec
|
codec Codec
|
||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
|
maxMsgSize int
|
||||||
unaryInt UnaryServerInterceptor
|
unaryInt UnaryServerInterceptor
|
||||||
streamInt StreamServerInterceptor
|
streamInt StreamServerInterceptor
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
useHandlerImpl bool // use http.Handler-based server
|
useHandlerImpl bool // use http.Handler-based server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||||
|
|
||||||
// A ServerOption sets options.
|
// A ServerOption sets options.
|
||||||
type ServerOption func(*options)
|
type ServerOption func(*options)
|
||||||
|
|
||||||
@ -121,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
|
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
|
||||||
func RPCCompressor(cp Compressor) ServerOption {
|
func RPCCompressor(cp Compressor) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.cp = cp
|
o.cp = cp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
|
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
|
||||||
func RPCDecompressor(dc Decompressor) ServerOption {
|
func RPCDecompressor(dc Decompressor) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.dc = dc
|
o.dc = dc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
|
||||||
|
// If this is not set, gRPC uses the default 4MB.
|
||||||
|
func MaxMsgSize(m int) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.maxMsgSize = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
|
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
|
||||||
// of concurrent streams to each ServerTransport.
|
// of concurrent streams to each ServerTransport.
|
||||||
func MaxConcurrentStreams(n uint32) ServerOption {
|
func MaxConcurrentStreams(n uint32) ServerOption {
|
||||||
@ -177,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
|
|||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
var opts options
|
var opts options
|
||||||
|
opts.maxMsgSize = defaultMaxMsgSize
|
||||||
for _, o := range opt {
|
for _, o := range opt {
|
||||||
o(&opts)
|
o(&opts)
|
||||||
}
|
}
|
||||||
@ -526,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(s.opts.maxMsgSize)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// The entire stream is done (for unary RPC only).
|
// The entire stream is done (for unary RPC only).
|
||||||
return err
|
return err
|
||||||
@ -536,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
|
case *rpcError:
|
||||||
|
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
|
||||||
|
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
|
||||||
|
}
|
||||||
case transport.ConnectionError:
|
case transport.ConnectionError:
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
case transport.StreamError:
|
case transport.StreamError:
|
||||||
@ -575,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(req) > s.opts.maxMsgSize {
|
||||||
|
// TODO: Revisit the error code. Currently keep it consistent with
|
||||||
|
// java implementation.
|
||||||
|
statusCode = codes.Internal
|
||||||
|
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
|
||||||
|
}
|
||||||
if err := s.opts.codec.Unmarshal(req, v); err != nil {
|
if err := s.opts.codec.Unmarshal(req, v); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -640,6 +662,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
codec: s.opts.codec,
|
codec: s.opts.codec,
|
||||||
cp: s.opts.cp,
|
cp: s.opts.cp,
|
||||||
dc: s.opts.dc,
|
dc: s.opts.dc,
|
||||||
|
maxMsgSize: s.opts.maxMsgSize,
|
||||||
trInfo: trInfo,
|
trInfo: trInfo,
|
||||||
}
|
}
|
||||||
if ss.cp != nil {
|
if ss.cp != nil {
|
||||||
|
@ -37,6 +37,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||||
defer func() {
|
defer func() {
|
||||||
// err != nil indicates the termination of the stream.
|
// err != nil indicates the termination of the stream.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Special handling for client streaming rpc.
|
// Special handling for client streaming rpc.
|
||||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||||
cs.closeTransportStream(err)
|
cs.closeTransportStream(err)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||||
@ -411,6 +412,7 @@ type serverStream struct {
|
|||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
cbuf *bytes.Buffer
|
cbuf *bytes.Buffer
|
||||||
|
maxMsgSize int
|
||||||
statusCode codes.Code
|
statusCode codes.Code
|
||||||
statusDesc string
|
statusDesc string
|
||||||
trInfo *traceInfo
|
trInfo *traceInfo
|
||||||
@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
ss.mu.Unlock()
|
ss.mu.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return recv(ss.p, ss.codec, ss.s, ss.dc, m)
|
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
|
||||||
}
|
}
|
||||||
|
@ -363,6 +363,7 @@ type test struct {
|
|||||||
testServer testpb.TestServiceServer // nil means none
|
testServer testpb.TestServiceServer // nil means none
|
||||||
healthServer *health.Server // nil means disabled
|
healthServer *health.Server // nil means disabled
|
||||||
maxStream uint32
|
maxStream uint32
|
||||||
|
maxMsgSize int
|
||||||
userAgent string
|
userAgent string
|
||||||
clientCompression bool
|
clientCompression bool
|
||||||
serverCompression bool
|
serverCompression bool
|
||||||
@ -413,6 +414,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||||||
e := te.e
|
e := te.e
|
||||||
te.t.Logf("Running test in %s environment...", e.name)
|
te.t.Logf("Running test in %s environment...", e.name)
|
||||||
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
|
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
|
||||||
|
if te.maxMsgSize > 0 {
|
||||||
|
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
|
||||||
|
}
|
||||||
if te.serverCompression {
|
if te.serverCompression {
|
||||||
sopts = append(sopts,
|
sopts = append(sopts,
|
||||||
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
|
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
|
||||||
@ -1068,6 +1072,65 @@ func testLargeUnary(t *testing.T, e env) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExceedMsgLimit(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testExceedMsgLimit(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testExceedMsgLimit(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.maxMsgSize = 1024
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
|
||||||
|
argSize := int32(te.maxMsgSize + 1)
|
||||||
|
const respSize = 1
|
||||||
|
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &testpb.SimpleRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseSize: proto.Int32(respSize),
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
|
||||||
|
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %d", err, codes.Internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := tc.FullDuplexCall(te.ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||||
|
}
|
||||||
|
respParam := []*testpb.ResponseParameters{
|
||||||
|
{
|
||||||
|
Size: proto.Int32(1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sreq := &testpb.StreamingOutputCallRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseParameters: respParam,
|
||||||
|
Payload: spayload,
|
||||||
|
}
|
||||||
|
if err := stream.Send(sreq); err != nil {
|
||||||
|
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
|
||||||
|
}
|
||||||
|
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
|
||||||
|
t.Fatalf("%v.Recv() = _, %v, want _, error code: %d", stream, err, codes.Internal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMetadataUnaryRPC(t *testing.T) {
|
func TestMetadataUnaryRPC(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
for _, e := range listTestEnv() {
|
for _, e := range listTestEnv() {
|
||||||
|
Reference in New Issue
Block a user