diff --git a/call_test.go b/call_test.go index 48134c4c..c867f235 100644 --- a/call_test.go +++ b/call_test.go @@ -104,7 +104,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) + data, err := encode(testCodec{}, &expectedResponse, nil) if err != nil { t.Errorf("Failed to encode the response: %v", err) return diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index 66b97a6f..9ca00911 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -65,6 +65,9 @@ func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) { } func (codec) Marshal(v interface{}) ([]byte, error) { + if v == nil { + return nil, nil + } if pm, ok := v.(proto.Marshaler); ok { // object can marshal itself, no need for buffer return pm.Marshal() diff --git a/interop/test_utils.go b/interop/test_utils.go index cbc7756d..38a20990 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -573,9 +573,8 @@ func DoUnimplementedService(tc testpb.UnimplementedServiceClient) { // DoUnimplementedMethod attempts to call an unimplemented method. func DoUnimplementedMethod(cc *grpc.ClientConn) { - var req, reply proto.Message - if err := cc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented { - grpclog.Fatalf("ClientConn.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented) + if err := grpc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", &testpb.Empty{}, &testpb.Empty{}, cc); err == nil || status.Code(err) != codes.Unimplemented { + grpclog.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented) } } diff --git a/rpc_util.go b/rpc_util.go index 90d1b849..407f69df 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -435,45 +435,49 @@ func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) { } // encode serializes msg and returns a buffer of msg. -// If msg is nil, it generates an empty buffer. -// TODO(ddyihai): eliminate extra Compressor parameter. -func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) { - var ( - b []byte - cbuf *bytes.Buffer - ) - if msg != nil { - var err error - b, err = c.Marshal(msg) - if err != nil { - return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) +func encode(c baseCodec, msg interface{}, outPayload *stats.OutPayload) ([]byte, error) { + b, err := c.Marshal(msg) + if err != nil { + return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + } + if b == nil { + // If there was no error while marshalling, yet payload was nil, + // we update it to an empty slice, since a nil payload leads to + // an empty data frame(no gRPC message header is added). + b = []byte{} + } + if outPayload != nil { + outPayload.Payload = msg + // TODO truncate large payload. + outPayload.Data = b + outPayload.Length = len(b) + } + return b, nil +} + +// compress the message if there is a compressor registered. +// TODO(mmukhi, dfawley): eliminate extra Compressor parameter. +func compress(b []byte, cp Compressor, compressor encoding.Compressor, outPayload *stats.OutPayload) ([]byte, bool, error) { + if len(b) <= 0 || (compressor == nil && cp == nil) { + return b, false, nil + } + cbuf := new(bytes.Buffer) + // Has compressor, check Compressor is set by UseCompressor first. + if compressor != nil { + z, _ := compressor.Compress(cbuf) + if _, err := z.Write(b); err != nil { + return nil, false, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } - if outPayload != nil { - outPayload.Payload = msg - // TODO truncate large payload. - outPayload.Data = b - outPayload.Length = len(b) - } - if compressor != nil || cp != nil { - cbuf = new(bytes.Buffer) - // Has compressor, check Compressor is set by UseCompressor first. - if compressor != nil { - z, _ := compressor.Compress(cbuf) - if _, err := z.Write(b); err != nil { - return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) - } - z.Close() - } else { - // If Compressor is not set by UseCompressor, use default Compressor - if err := cp.Do(cbuf, b); err != nil { - return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) - } - } - b = cbuf.Bytes() + z.Close() + } else { + // If Compressor is not set by UseCompressor, use default Compressor + if err := cp.Do(cbuf, b); err != nil { + return nil, false, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } } + b = cbuf.Bytes() if uint(len(b)) > math.MaxUint32 { - return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return nil, false, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } if outPayload != nil { @@ -481,7 +485,7 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa // before it's put on wire. outPayload.WireLength = 5 + len(b) } - return b, nil + return b, true, nil } func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status { diff --git a/rpc_util_test.go b/rpc_util_test.go index 2cf2b43a..004c5386 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -48,16 +48,15 @@ func TestEncode(t *testing.T) { for _, test := range []struct { // input msg proto.Message - cp Compressor // outputs data []byte err error }{ - {nil, nil, []byte{}, nil}, + {&perfpb.Buffer{}, []byte{}, nil}, } { - data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) + data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil) if err != test.err || !bytes.Equal(data, test.data) { - t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err) + t.Fatalf("encode(_, _, _, _) = %v, %v\nwant %v, %v", data, err, test.data, test.err) } } } @@ -156,7 +155,7 @@ func TestParseDialTarget(t *testing.T) { func bmEncode(b *testing.B, mSize int) { cdc := encoding.GetCodec(protoenc.Name) msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encodeData, _ := encode(cdc, msg, nil, nil, nil) + encodeData, _ := encode(cdc, msg, nil) // 5 bytes of gRPC-specific message header // is added to the message before it is written // to the wire. @@ -164,7 +163,7 @@ func bmEncode(b *testing.B, mSize int) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - encode(cdc, msg, nil, nil, nil) + encode(cdc, msg, nil) } b.SetBytes(encodedSz) } diff --git a/server.go b/server.go index 9d37a210..7ade17bd 100644 --- a/server.go +++ b/server.go @@ -833,7 +833,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) + data, err := encode(s.getCodec(stream.ContentSubtype()), msg, outPayload) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err @@ -841,7 +841,12 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if len(data) > s.opts.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) } - opts.IsCompressed = cp != nil || comp != nil + data, isCompressed, err := compress(data, cp, comp, outPayload) + opts.IsCompressed = isCompressed + if err != nil { + grpclog.Errorln("grpc: server failed to compress response: ", err) + return err + } err = t.Write(stream, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() diff --git a/stream.go b/stream.go index 1b5dd495..ade358f7 100644 --- a/stream.go +++ b/stream.go @@ -470,19 +470,23 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { Client: true, } } - data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) + data, err := encode(cs.codec, m, outPayload) if err != nil { return err } if len(data) > *cs.c.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } + data, isCompressed, err := compress(data, cs.cp, cs.comp, outPayload) + if err != nil { + return err + } if !cs.desc.ClientStreams { cs.sentLast = true } opts := &transport.Options{ Last: !cs.desc.ClientStreams, - IsCompressed: cs.cp != nil || cs.comp != nil, + IsCompressed: isCompressed, } err = a.t.Write(a.s, data, opts) if err == nil { @@ -701,16 +705,20 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) + data, err := encode(ss.codec, m, outPayload) if err != nil { return err } if len(data) > ss.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) } + data, isCompressed, err := compress(data, ss.cp, ss.comp, outPayload) + if err != nil { + return err + } opts := &transport.Options{ Last: false, - IsCompressed: ss.cp != nil || ss.comp != nil, + IsCompressed: isCompressed, } if err := ss.t.Write(ss.s, data, opts); err != nil { return toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index ace57cb7..a452c44b 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5923,7 +5923,7 @@ func TestMethodFromServerStream(t *testing.T) { te.startServer(nil) defer te.tearDown() - _ = te.clientConn().Invoke(context.Background(), testMethod, nil, nil) + _ = te.clientConn().Invoke(context.Background(), testMethod, &testpb.Empty{}, &testpb.Empty{}) if !ok || method != testMethod { t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod) }