diff --git a/Documentation/compression.md b/Documentation/compression.md new file mode 100644 index 00000000..204f880a --- /dev/null +++ b/Documentation/compression.md @@ -0,0 +1,80 @@ +# Compression + +The preferred method for configuring message compression on both clients and +servers is to use +[`encoding.RegisterCompressor`](https://godoc.org/google.golang.org/grpc/encoding#RegisterCompressor) +to register an implementation of a compression algorithm. See +`grpc/encoding/gzip/gzip.go` for an example of how to implement one. + +Once a compressor has been registered on the client-side, RPCs may be sent using +it via the +[`UseCompressor`](https://godoc.org/google.golang.org/grpc#UseCompressor) +`CallOption`. Remember that `CallOption`s may be turned into defaults for all +calls from a `ClientConn` by using the +[`WithDefaultCallOptions`](https://godoc.org/google.golang.org/grpc#WithDefaultCallOptions) +`DialOption`. If `UseCompressor` is used and the corresponding compressor has +not been installed, an `Internal` error will be returned to the application +before the RPC is sent. + +Server-side, registered compressors will be used automatically to decode request +messages and encode the responses. Servers currently always respond using the +same compression method specified by the client. If the corresponding +compressor has not been registered, an `Unimplemented` status will be returned +to the client. + +## Deprecated API + +There is a deprecated API for setting compression as well. It is not +recommended for use. However, if you were previously using it, the following +section may be helpful in understanding how it works in combination with the new +API. + +### Client-Side + +There are two legacy functions and one new function to configure compression: + +```go +func WithCompressor(grpc.Compressor) DialOption {} +func WithDecompressor(grpc.Decompressor) DialOption {} +func UseCompressor(name) CallOption {} +``` + +For outgoing requests, the following rules are applied in order: +1. If `UseCompressor` is used, messages will be compressed using the compressor + named. + * If the compressor named is not registered, an Internal error is returned + back to the client before sending the RPC. + * If UseCompressor("identity"), no compressor will be used, but "identity" + will be sent in the header to the server. +1. If `WithCompressor` is used, messages will be compressed using that + compressor implementation. +1. Otherwise, outbound messages will be uncompressed. + +For incoming responses, the following rules are applied in order: +1. If `WithDecompressor` is used and it matches the message's encoding, it will + be used. +1. If a registered compressor matches the response's encoding, it will be used. +1. Otherwise, the stream will be closed and an `Unimplemented` status error will + be returned to the application. + +### Server-Side + +There are two legacy functions to configure compression: +```go +func RPCCompressor(grpc.Compressor) ServerOption {} +func RPCDecompressor(grpc.Decompressor) ServerOption {} +``` + +For incoming requests, the following rules are applied in order: +1. If `RPCDecompressor` is used and that decompressor matches the request's + encoding: it will be used. +1. If a registered compressor matches the request's encoding, it will be used. +1. Otherwise, an `Unimplemented` status will be returned to the client. + +For outgoing responses, the following rules are applied in order: +1. If `RPCCompressor` is used, that compressor will be used to compress all + response messages. +1. If compression was used for the incoming request and a registered compressor + supports it, that same compression method will be used for the outgoing + response. +1. Otherwise, no compression will be used for the outgoing response. diff --git a/call.go b/call.go index 744514c3..0854f84b 100644 --- a/call.go +++ b/call.go @@ -61,7 +61,17 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran if c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(c.compressorType)); err != nil { + + // Set dc if it exists and matches the message compression type used, + // otherwise set comp if a registered compressor exists for it. + var comp encoding.Compressor + var dc Decompressor + if rc := stream.RecvCompress(); dopts.dc != nil && dopts.dc.Type() == rc { + dc = dopts.dc + } else if rc != "" && rc != encoding.Identity { + comp = encoding.GetCompressor(rc) + } + if err = recv(p, dopts.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil { if err == io.EOF { break } @@ -95,10 +105,18 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, Client: true, } } - if c.compressorType != "" && encoding.GetCompressor(c.compressorType) == nil { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", c.compressorType) + // Set comp and clear compressor if a registered compressor matches the type + // specified via UseCompressor. (And error if a matching compressor is not + // registered.) + var comp encoding.Compressor + if ct := c.compressorType; ct != "" && ct != encoding.Identity { + compressor = nil // Disable the legacy compressor. + comp = encoding.GetCompressor(ct) + if comp == nil { + return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) + } } - hdr, data, err := encode(dopts.codec, args, compressor, outPayload, encoding.GetCompressor(c.compressorType)) + hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp) if err != nil { return err } @@ -211,9 +229,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Host: cc.authority, Method: method, } - if cc.dopts.cp != nil { - callHdr.SendCompress = cc.dopts.cp.Type() - } if c.creds != nil { callHdr.Creds = c.creds } diff --git a/clientconn.go b/clientconn.go index e4e72368..94e3bcf9 100644 --- a/clientconn.go +++ b/clientconn.go @@ -107,16 +107,6 @@ const ( // DialOption configures how we set up the connection. type DialOption func(*dialOptions) -// UseCompressor returns a CallOption which sets the compressor used when sending the request. -// If WithCompressor is set, UseCompressor has higher priority. -// This API is EXPERIMENTAL. -func UseCompressor(name string) CallOption { - return beforeCall(func(c *callInfo) error { - c.compressorType = name - return nil - }) -} - // WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched // before doing a write on the wire. func WithWriteBufferSize(s int) DialOption { @@ -168,18 +158,26 @@ func WithCodec(c Codec) DialOption { } } -// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message -// compressor. It has lower priority than the compressor set by RegisterCompressor. -// This function is deprecated. +// WithCompressor returns a DialOption which sets a Compressor to use for +// message compression. It has lower priority than the compressor set by +// the UseCompressor CallOption. +// +// Deprecated: use UseCompressor instead. func WithCompressor(cp Compressor) DialOption { return func(o *dialOptions) { o.cp = cp } } -// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating -// message decompressor. It has higher priority than the decompressor set by RegisterCompressor. -// This function is deprecated. +// WithDecompressor returns a DialOption which sets a Decompressor to use for +// incoming message decompression. If incoming response messages are encoded +// using the decompressor's Type(), it will be used. Otherwise, the message +// encoding will be used to look up the compressor registered via +// encoding.RegisterCompressor, which will then be used to decompress the +// message. If no compressor is registered for the encoding, an Unimplemented +// status error will be returned. +// +// Deprecated: use encoding.RegisterCompressor instead. func WithDecompressor(dc Decompressor) DialOption { return func(o *dialOptions) { o.dc = dc diff --git a/encoding/encoding.go b/encoding/encoding.go index f6cc3d66..47d10b07 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -55,3 +55,7 @@ func RegisterCompressor(c Compressor) { func GetCompressor(name string) Compressor { return registerCompressor[name] } + +// Identity specifies the optional encoding for uncompressed streams. +// It is intended for grpc internal use only. +const Identity = "identity" diff --git a/rpc_util.go b/rpc_util.go index 7c39ed15..eae2264b 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -236,6 +236,18 @@ func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption { }) } +// UseCompressor returns a CallOption which sets the compressor used when +// sending the request. If WithCompressor is also set, UseCompressor has +// higher priority. +// +// This API is EXPERIMENTAL. +func UseCompressor(name string) CallOption { + return beforeCall(func(c *callInfo) error { + c.compressorType = name + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 @@ -359,22 +371,26 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa return bufHeader, b, nil } -func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error { +func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { switch pf { case compressionNone: case compressionMade: - if (dc == nil || recvCompress != dc.Type()) && encoding.GetCompressor(recvCompress) == nil { - return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + if recvCompress == "" || recvCompress == encoding.Identity { + return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") + } + if !haveCompressor { + return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } default: - return Errorf(codes.Internal, "grpc: received unexpected payload format %d", pf) + return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) } return nil } -// TODO(ddyihai): eliminate extra Compressor parameter. -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, - inPayload *stats.InPayload, compressor encoding.Compressor) error { +// For the two compressor parameters, both should not be set, but if they are, +// dc takes precedence over compressor. +// TODO(dfawley): wrap the old compressor/decompressor using the new API? +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { pf, d, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return err @@ -382,9 +398,11 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if inPayload != nil { inPayload.WireLength = len(d) } - if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil { - return err + + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { + return st.Err() } + if pf == compressionMade { // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // use this decompressor as the default. diff --git a/server.go b/server.go index e2d11ce9..e9737fc4 100644 --- a/server.go +++ b/server.go @@ -191,18 +191,24 @@ func CustomCodec(codec Codec) ServerOption { } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. -// It has lower priority than the compressor set by RegisterCompressor. -// This function is deprecated. +// RPCCompressor returns a ServerOption that sets a compressor for outbound +// messages. For backward compatibility, all outbound messages will be sent +// using this compressor, regardless of incoming message compression. By +// default, server messages will be sent using the same compressor with which +// request messages were sent. +// +// Deprecated: use encoding.RegisterCompressor instead. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. -// It has higher priority than the decompressor set by RegisterCompressor. -// This function is deprecated. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound +// messages. It has higher priority than decompressors registered via +// encoding.RegisterCompressor. +// +// Deprecated: use encoding.RegisterCompressor instead. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc @@ -725,20 +731,14 @@ func (s *Server) removeConn(c io.Closer) { } } -func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { +func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { var ( outPayload *stats.OutPayload ) if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - if stream.RecvCompress() != "" { - // Server receives compressor, check compressor set by register and default. - if encoding.GetCompressor(stream.RecvCompress()) == nil && (cp == nil || cp != nil && cp.Type() != stream.RecvCompress()) { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", stream.RecvCompress()) - } - } - hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, encoding.GetCompressor(stream.RecvCompress())) + hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err @@ -782,12 +782,43 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } - if stream.RecvCompress() != "" { - stream.SetSendCompress(stream.RecvCompress()) - } else if s.opts.cp != nil { - // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. - stream.SetSendCompress(s.opts.cp.Type()) + + // comp and cp are used for compression. decomp and dc are used for + // decompression. If comp and decomp are both set, they are the same; + // however they are kept separate to ensure that at most one of the + // compressor/decompressor variable pairs are set for use later. + var comp, decomp encoding.Compressor + var cp Compressor + var dc Decompressor + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + decomp = encoding.GetCompressor(rc) + if decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(stream, st) + return st.Err() + } } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + if s.opts.cp != nil { + cp = s.opts.cp + stream.SetSendCompress(cp.Type()) + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. + comp = encoding.GetCompressor(rc) + if comp != nil { + stream.SetSendCompress(rc) + } + } + p := &parser{r: stream} pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) if err == io.EOF { @@ -816,18 +847,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } - if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - if st, ok := status.FromError(err); ok { - if e := t.WriteStatus(stream, st); e != nil { - grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - return err - } - if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { + if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil { + if e := t.WriteStatus(stream, st); e != nil { grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) } - - // TODO checkRecvPayload always return RPC error. Add a return here if necessary. + return st.Err() } var inPayload *stats.InPayload if sh != nil { @@ -841,14 +865,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if pf == compressionMade { var err error - if s.opts.dc != nil { - req, err = s.opts.dc.Do(bytes.NewReader(req)) + if dc != nil { + req, err = dc.Do(bytes.NewReader(req)) if err != nil { return Errorf(codes.Internal, err.Error()) } } else { - dcReader := encoding.GetCompressor(stream.RecvCompress()) - tmp, _ := dcReader.Decompress(bytes.NewReader(req)) + tmp, _ := decomp.Decompress(bytes.NewReader(req)) req, err = ioutil.ReadAll(tmp) if err != nil { return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) @@ -898,7 +921,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Last: true, Delay: false, } - if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { + + if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -947,24 +971,45 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp sh.HandleRPC(stream.Context(), end) }() } - if stream.RecvCompress() != "" { - stream.SetSendCompress(stream.RecvCompress()) - } else if s.opts.cp != nil { - stream.SetSendCompress(s.opts.cp.Type()) - } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cpType: stream.RecvCompress(), - cp: s.opts.cp, - dc: s.opts.dc, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, statsHandler: sh, } + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + ss.dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + ss.decomp = encoding.GetCompressor(rc) + if ss.decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(ss.s, st) + return st.Err() + } + } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + if s.opts.cp != nil { + ss.cp = s.opts.cp + stream.SetSendCompress(s.opts.cp.Type()) + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. + ss.comp = encoding.GetCompressor(rc) + if ss.comp != nil { + stream.SetSendCompress(rc) + } + } + if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) defer func() { diff --git a/service_config_test.go b/service_config_test.go index 737cc2c4..7e985457 100644 --- a/service_config_test.go +++ b/service_config_test.go @@ -83,7 +83,7 @@ func TestParseLoadBalancer(t *testing.T) { } } -func TestPraseWaitForReady(t *testing.T) { +func TestParseWaitForReady(t *testing.T) { testcases := []struct { scjs string wantSC ServiceConfig diff --git a/stream.go b/stream.go index 44547b79..a0e5a84b 100644 --- a/stream.go +++ b/stream.go @@ -151,10 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // time soon, so we ask the transport to flush the header. Flush: desc.ClientStreams, } - if c.compressorType != "" { - callHdr.SendCompress = c.compressorType + + // Set our outgoing compression according to the UseCompressor CallOption, if + // set. In that case, also find the compressor from the encoding package. + // Otherwise, use the compressor configured by the WithCompressor DialOption, + // if set. + var cp Compressor + var comp encoding.Compressor + if ct := c.compressorType; ct != "" { + callHdr.SendCompress = ct + if ct != encoding.Identity { + comp = encoding.GetCompressor(ct) + if comp == nil { + return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) + } + } } else if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() + cp = cc.dopts.cp } if c.creds != nil { callHdr.Creds = c.creds @@ -241,9 +255,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth c: c, desc: desc, codec: cc.dopts.codec, - cpType: c.compressorType, - cp: cc.dopts.cp, + cp: cp, dc: cc.dopts.dc, + comp: comp, cancel: cancel, done: done, @@ -285,16 +299,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { - opts []CallOption - c *callInfo - t transport.ClientTransport - s *transport.Stream - p *parser - desc *StreamDesc - codec Codec - cpType string - cp Compressor - dc Decompressor + opts []CallOption + c *callInfo + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + + codec Codec + cp Compressor + dc Decompressor + comp encoding.Compressor + decomp encoding.Compressor + decompSet bool + cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. @@ -370,10 +388,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, } } - if cs.cpType != "" && encoding.GetCompressor(cs.cpType) == nil { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", cs.cpType) - } - hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, encoding.GetCompressor(cs.cpType)) + hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) if err != nil { return err } @@ -401,7 +416,23 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(cs.cpType)) + if !cs.decompSet { + // Block until we receive headers containing received message encoding. + if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity { + if cs.dc == nil || cs.dc.Type() != ct { + // No configured decompressor, or it does not match the incoming + // message encoding; attempt to find a registered compressor that does. + cs.dc = nil + cs.decomp = encoding.GetCompressor(ct) + } + } else { + // No compression is used; disable our decompressor. + cs.dc = nil + } + // Only initialize this state once per stream. + cs.decompSet = true + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -427,7 +458,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, encoding.GetCompressor(cs.cpType)) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -552,13 +583,16 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { - t transport.ServerTransport - s *transport.Stream - p *parser - codec Codec - cpType string - cp Compressor - dc Decompressor + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + + cp Compressor + dc Decompressor + comp encoding.Compressor + decomp encoding.Compressor + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -614,12 +648,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - if ss.cpType != "" { - if encoding.GetCompressor(ss.cpType) == nil && (ss.cp == nil || ss.cp != nil && ss.cp.Type() != ss.cpType) { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ss.cpType) - } - } - hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, encoding.GetCompressor(ss.cpType)) + hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) if err != nil { return err } @@ -659,7 +688,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, encoding.GetCompressor(ss.cpType)); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil { if err == io.EOF { return err } diff --git a/test/end2end_test.go b/test/end2end_test.go index e7a1f1bd..7a455e6d 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3886,9 +3886,68 @@ func testCompressOK(t *testing.T, e env) { if err := stream.Send(sreq); err != nil { t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } + stream.CloseSend() if _, err := stream.Recv(); err != nil { t.Fatalf("%v.Recv() = %v, want ", stream, err) } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err) + } +} + +func TestIdentityEncoding(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + testIdentityEncoding(t, e) + } +} + +func testIdentityEncoding(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + // Unary call + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 5) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: 10, + Payload: payload, + } + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something")) + if _, err := tc.UnaryCall(ctx, req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) + } + // Streaming RPC + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := tc.FullDuplexCall(ctx, grpc.UseCompressor("identity")) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415)) + if err != nil { + t.Fatal(err) + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseParameters: []*testpb.ResponseParameters{{Size: 10}}, + Payload: payload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + stream.CloseSend() + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err) + } } func TestUnaryClientInterceptor(t *testing.T) { diff --git a/transport/transport.go b/transport/transport.go index d48e0611..b7a5dbe4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -252,6 +252,9 @@ type Stream struct { // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. func (s *Stream) RecvCompress() string { + if s.headerChan != nil { + <-s.headerChan + } return s.recvCompress } @@ -528,10 +531,6 @@ type CallHdr struct { // Method specifies the operation to perform. Method string - // RecvCompress specifies the compression algorithm applied on - // inbound messages. - RecvCompress string - // SendCompress specifies the compression algorithm applied on // outbound message. SendCompress string diff --git a/vet.sh b/vet.sh index e0e202f0..02d4bae3 100755 --- a/vet.sh +++ b/vet.sh @@ -76,5 +76,10 @@ if [[ "$check_proto" = "true" ]]; then fi # TODO(menghanl): fix errors in transport_test. -staticcheck -ignore google.golang.org/grpc/transport/transport_test.go:SA2002 ./... +staticcheck -ignore ' +google.golang.org/grpc/transport/transport_test.go:SA2002 +google.golang.org/grpc/benchmark/benchmain/main.go:SA1019 +google.golang.org/grpc/stats/stats_test.go:SA1019 +google.golang.org/grpc/test/end2end_test.go:SA1019 +' ./... misspell -error .