Reject over-sized requests on server

This commit is contained in:
iamqizhao
2016-07-26 16:44:49 -07:00
parent f78100723d
commit 8c908a8c1d
5 changed files with 97 additions and 23 deletions

View File

@ -36,6 +36,7 @@ package grpc
import ( import (
"bytes" "bytes"
"io" "io"
"math"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -57,7 +58,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }

View File

@ -308,7 +308,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg()
if err != nil { if err != nil {
return err return err
@ -319,11 +319,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade { if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d)) d, err = dc.Do(bytes.NewReader(d))
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
if len(d) > maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
} }
return nil return nil
} }

View File

@ -105,14 +105,14 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
msgLimit int maxMsgSize int
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
@ -124,23 +124,25 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound message. // RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
} }
} }
func MsgLimit(m int) ServerOption { // MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) { return func(o *options) {
o.msgLimit = m o.maxMsgSize = m
} }
} }
@ -186,7 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options var opts options
opts.msgLimit = defaultMsgLimit opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -585,11 +587,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
} }
if len(req) > s.opts.msgLimit { if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with // TODO: Revisit the error code. Currently keep it consistent with
// java implementation. // java implementation.
statusCode = codes.Internal statusCode = codes.Internal
statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit) statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
} }
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err return err
@ -656,6 +658,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
codec: s.opts.codec, codec: s.opts.codec,
cp: s.opts.cp, cp: s.opts.cp,
dc: s.opts.dc, dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo, trInfo: trInfo,
} }
if ss.cp != nil { if ss.cp != nil {

View File

@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"math"
"sync" "sync"
"time" "time"
@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -411,6 +412,7 @@ type serverStream struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
trInfo *traceInfo trInfo *traceInfo
@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dc, m) return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
} }

View File

@ -373,6 +373,7 @@ type test struct {
testServer testpb.TestServiceServer // nil means none testServer testpb.TestServiceServer // nil means none
healthServer *health.HealthServer // nil means disabled healthServer *health.HealthServer // nil means disabled
maxStream uint32 maxStream uint32
maxMsgSize int
userAgent string userAgent string
clientCompression bool clientCompression bool
serverCompression bool serverCompression bool
@ -423,6 +424,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
e := te.e e := te.e
te.t.Logf("Running test in %s environment...", e.name) te.t.Logf("Running test in %s environment...", e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
if te.maxMsgSize > 0 {
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
}
if te.serverCompression { if te.serverCompression {
sopts = append(sopts, sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()), grpc.RPCCompressor(grpc.NewGZIPCompressor()),
@ -956,6 +960,65 @@ func testLargeUnary(t *testing.T, e env) {
} }
} }
func TestExceedMsgLimit(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testExceedMsgLimit(t, e)
}
}
func testExceedMsgLimit(t *testing.T, e env) {
te := newTest(t, e)
te.maxMsgSize = 1024
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
argSize := int32(te.maxMsgSize + 1)
const respSize = 1
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: ", err, codes.Internal)
}
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(1),
},
}
spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: spayload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: ", stream, err, codes.Internal)
}
}
func TestMetadataUnaryRPC(t *testing.T) { func TestMetadataUnaryRPC(t *testing.T) {
defer leakCheck(t)() defer leakCheck(t)()
for _, e := range listTestEnv() { for _, e := range listTestEnv() {