diff --git a/server.go b/server.go index a2b2b94d..7b05da6f 100644 --- a/server.go +++ b/server.go @@ -101,12 +101,15 @@ type options struct { codec Codec cp Compressor dc Decompressor + msgLimit int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } +var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit + // A ServerOption sets options. type ServerOption func(*options) @@ -131,6 +134,12 @@ func RPCDecompressor(dc Decompressor) ServerOption { } } +func MsgLimit(m int) ServerOption { + return func(o *options) { + o.msgLimit = m + } +} + // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -173,6 +182,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { var opts options + opts.msgLimit = defaultMsgLimit for _, o := range opt { o(&opts) } @@ -569,6 +579,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } } + if len(req) > s.opts.msgLimit { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + statusCode = codes.Internal + statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit) + } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err }