diff --git a/call_test.go b/call_test.go index 2bcea807..38ffc31d 100644 --- a/call_test.go +++ b/call_test.go @@ -105,12 +105,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) + data, err := encode(testCodec{}, &expectedResponse) if err != nil { t.Errorf("Failed to encode the response: %v", err) return } - h.t.Write(s, hdr, data, &transport.Options{}) + hdr, payload := msgHeader(data, nil) + h.t.Write(s, hdr, payload, &transport.Options{}) h.t.WriteStatus(s, status.New(codes.OK, "")) } diff --git a/rpc_util.go b/rpc_util.go index c783ee50..6b1e8cfd 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -419,8 +419,8 @@ func (o CustomCodecCallOption) after(c *callInfo) {} type payloadFormat uint8 const ( - compressionNone payloadFormat = iota // no compression - compressionMade + compressionNone payloadFormat = 0 // no compression + compressionMade payloadFormat = 1 // compressed ) // parser reads complete gRPC messages from the underlying reader. @@ -477,65 +477,82 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt return pf, msg, nil } -// encode serializes msg and returns a buffer of message header and a buffer of msg. -// If msg is nil, it generates the message header and an empty msg buffer. -// TODO(ddyihai): eliminate extra Compressor parameter. -func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { - var ( - b []byte - cbuf *bytes.Buffer - ) - const ( - payloadLen = 1 - sizeLen = 4 - ) - if msg != nil { - var err error - b, err = c.Marshal(msg) - if err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) - } - if outPayload != nil { - outPayload.Payload = msg - // TODO truncate large payload. - outPayload.Data = b - outPayload.Length = len(b) - } - if compressor != nil || cp != nil { - cbuf = new(bytes.Buffer) - // Has compressor, check Compressor is set by UseCompressor first. - if compressor != nil { - z, _ := compressor.Compress(cbuf) - if _, err := z.Write(b); err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) - } - z.Close() - } else { - // If Compressor is not set by UseCompressor, use default Compressor - if err := cp.Do(cbuf, b); err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) - } - } - b = cbuf.Bytes() - } +// encode serializes msg and returns a buffer containing the message, or an +// error if it is too large to be transmitted by grpc. If msg is nil, it +// generates an empty message. +func encode(c baseCodec, msg interface{}) ([]byte, error) { + if msg == nil { // NOTE: typed nils will not be caught by this check + return nil, nil + } + b, err := c.Marshal(msg) + if err != nil { + return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if uint(len(b)) > math.MaxUint32 { - return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } + return b, nil +} - bufHeader := make([]byte, payloadLen+sizeLen) - if compressor != nil || cp != nil { - bufHeader[0] = byte(compressionMade) +// compress returns the input bytes compressed by compressor or cp. If both +// compressors are nil, returns nil. +// +// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. +func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) { + if compressor == nil && cp == nil { + return nil, nil + } + wrapErr := func(err error) error { + return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + } + cbuf := &bytes.Buffer{} + if compressor != nil { + z, _ := compressor.Compress(cbuf) + if _, err := z.Write(in); err != nil { + return nil, wrapErr(err) + } + if err := z.Close(); err != nil { + return nil, wrapErr(err) + } } else { - bufHeader[0] = byte(compressionNone) + if err := cp.Do(cbuf, in); err != nil { + return nil, wrapErr(err) + } + } + return cbuf.Bytes(), nil +} + +const ( + payloadLen = 1 + sizeLen = 4 + headerLen = payloadLen + sizeLen +) + +// msgHeader returns a 5-byte header for the message being transmitted and the +// payload, which is compData if non-nil or data otherwise. +func msgHeader(data, compData []byte) (hdr []byte, payload []byte) { + hdr = make([]byte, headerLen) + if compData != nil { + hdr[0] = byte(compressionMade) + data = compData + } else { + hdr[0] = byte(compressionNone) } - // Write length of b into buf - binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) - if outPayload != nil { - outPayload.WireLength = payloadLen + sizeLen + len(b) + // Write length of payload into buf + binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data))) + return hdr, data +} + +func outPayload(client bool, msg interface{}, data, payload []byte, t time.Time) *stats.OutPayload { + return &stats.OutPayload{ + Client: client, + Payload: msg, + Data: data, + Length: len(data), + WireLength: len(payload) + headerLen, + SentTime: t, } - return bufHeader, b, nil } func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { diff --git a/rpc_util_test.go b/rpc_util_test.go index 770e850c..f28cff23 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -105,23 +105,25 @@ func TestEncode(t *testing.T) { for _, test := range []struct { // input msg proto.Message - cp Compressor // outputs hdr []byte data []byte err error }{ - {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, + {nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, } { - hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) - if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { - t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) + data, err := encode(encoding.GetCodec(protoenc.Name), test.msg) + if err != test.err || !bytes.Equal(data, test.data) { + t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err) + continue + } + if hdr, _ := msgHeader(data, nil); !bytes.Equal(hdr, test.hdr) { + t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr) } } } func TestCompress(t *testing.T) { - bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression) if err != nil { t.Fatalf("Could not initialize gzip compressor with best compression.") @@ -214,12 +216,12 @@ func TestParseDialTarget(t *testing.T) { func bmEncode(b *testing.B, mSize int) { cdc := encoding.GetCodec(protoenc.Name) msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil) - encodedSz := int64(len(encodeHdr) + len(encodeData)) + encodeData, _ := encode(cdc, msg) + encodedSz := int64(len(encodeData)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - encode(cdc, msg, nil, nil, nil) + encode(cdc, msg) } b.SetBytes(encodedSz) } diff --git a/server.go b/server.go index 4969331c..c76bb535 100644 --- a/server.go +++ b/server.go @@ -827,24 +827,24 @@ func (s *Server) incrCallsFailed() { } func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { - var ( - outPayload *stats.OutPayload - ) - if s.opts.statsHandler != nil { - outPayload = &stats.OutPayload{} - } - hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) + data, err := encode(s.getCodec(stream.ContentSubtype()), msg) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err } - if len(data) > s.opts.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) + compData, err := compress(data, cp, comp) + if err != nil { + grpclog.Errorln("grpc: server failed to compress response: ", err) + return err } - err = t.Write(stream, hdr, data, opts) - if err == nil && outPayload != nil { - outPayload.SentTime = time.Now() - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) + } + err = t.Write(stream, hdr, payload, opts) + if err == nil && s.opts.statsHandler != nil { + s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) } return err } diff --git a/stream.go b/stream.go index 11d8f67c..bb8c75a8 100644 --- a/stream.go +++ b/stream.go @@ -476,27 +476,27 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { } a.mu.Unlock() } - var outPayload *stats.OutPayload - if a.statsHandler != nil { - outPayload = &stats.OutPayload{ - Client: true, - } - } - hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) + data, err := encode(cs.codec, m) if err != nil { return err } - if len(data) > *cs.c.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) + compData, err := compress(data, cs.cp, cs.comp) + if err != nil { + return err } + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > *cs.c.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.c.maxSendMessageSize) + } + if !cs.desc.ClientStreams { cs.sentLast = true } - err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams}) + err = a.t.Write(a.s, hdr, payload, &transport.Options{Last: !cs.desc.ClientStreams}) if err == nil { - if outPayload != nil { - outPayload.SentTime = time.Now() - a.statsHandler.HandleRPC(a.ctx, outPayload) + if a.statsHandler != nil { + a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payload, time.Now())) } if channelz.IsOn() { a.t.IncrMsgSent() @@ -706,23 +706,24 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.t.IncrMsgSent() } }() - var outPayload *stats.OutPayload - if ss.statsHandler != nil { - outPayload = &stats.OutPayload{} - } - hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) + data, err := encode(ss.codec, m) if err != nil { return err } - if len(data) > ss.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) + compData, err := compress(data, ss.cp, ss.comp) + if err != nil { + return err } - if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > ss.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize) + } + if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } - if outPayload != nil { - outPayload.SentTime = time.Now() - ss.statsHandler.HandleRPC(ss.s.Context(), outPayload) + if ss.statsHandler != nil { + ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) } return nil }