limit max msg size on server

This commit is contained in:
iamqizhao
2016-07-26 14:04:19 -07:00
parent afdabe5738
commit b071b83367

View File

@ -101,12 +101,15 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
msgLimit int
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
@ -131,6 +134,12 @@ func RPCDecompressor(dc Decompressor) ServerOption {
} }
} }
func MsgLimit(m int) ServerOption {
return func(o *options) {
o.msgLimit = m
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
@ -173,6 +182,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options var opts options
opts.msgLimit = defaultMsgLimit
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -569,6 +579,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
} }
if len(req) > s.opts.msgLimit {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err return err
} }