Reject over-sized requests on server

This commit is contained in:
iamqizhao
2016-07-26 16:44:49 -07:00
parent f78100723d
commit 8c908a8c1d
5 changed files with 97 additions and 23 deletions

View File

@ -105,14 +105,14 @@ type options struct {
codec Codec
cp Compressor
dc Decompressor
msgLimit int
maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options.
type ServerOption func(*options)
@ -124,23 +124,25 @@ func CustomCodec(codec Codec) ServerOption {
}
}
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
}
}
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
}
}
func MsgLimit(m int) ServerOption {
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.msgLimit = m
o.maxMsgSize = m
}
}
@ -186,7 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
opts.msgLimit = defaultMsgLimit
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt {
o(&opts)
}
@ -585,11 +587,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}
}
if len(req) > s.opts.msgLimit {
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit)
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
@ -650,13 +652,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
trInfo: trInfo,
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)