Add proper support for 'identity' encoding type (#1664)

This commit is contained in:
dfawley
2017-11-17 09:24:54 -08:00
committed by GitHub
parent c1fc29613d
commit 816fa5b06f
11 changed files with 370 additions and 118 deletions

View File

@ -0,0 +1,80 @@
# Compression
The preferred method for configuring message compression on both clients and
servers is to use
[`encoding.RegisterCompressor`](https://godoc.org/google.golang.org/grpc/encoding#RegisterCompressor)
to register an implementation of a compression algorithm. See
`grpc/encoding/gzip/gzip.go` for an example of how to implement one.
Once a compressor has been registered on the client-side, RPCs may be sent using
it via the
[`UseCompressor`](https://godoc.org/google.golang.org/grpc#UseCompressor)
`CallOption`. Remember that `CallOption`s may be turned into defaults for all
calls from a `ClientConn` by using the
[`WithDefaultCallOptions`](https://godoc.org/google.golang.org/grpc#WithDefaultCallOptions)
`DialOption`. If `UseCompressor` is used and the corresponding compressor has
not been installed, an `Internal` error will be returned to the application
before the RPC is sent.
Server-side, registered compressors will be used automatically to decode request
messages and encode the responses. Servers currently always respond using the
same compression method specified by the client. If the corresponding
compressor has not been registered, an `Unimplemented` status will be returned
to the client.
## Deprecated API
There is a deprecated API for setting compression as well. It is not
recommended for use. However, if you were previously using it, the following
section may be helpful in understanding how it works in combination with the new
API.
### Client-Side
There are two legacy functions and one new function to configure compression:
```go
func WithCompressor(grpc.Compressor) DialOption {}
func WithDecompressor(grpc.Decompressor) DialOption {}
func UseCompressor(name) CallOption {}
```
For outgoing requests, the following rules are applied in order:
1. If `UseCompressor` is used, messages will be compressed using the compressor
named.
* If the compressor named is not registered, an Internal error is returned
back to the client before sending the RPC.
* If UseCompressor("identity"), no compressor will be used, but "identity"
will be sent in the header to the server.
1. If `WithCompressor` is used, messages will be compressed using that
compressor implementation.
1. Otherwise, outbound messages will be uncompressed.
For incoming responses, the following rules are applied in order:
1. If `WithDecompressor` is used and it matches the message's encoding, it will
be used.
1. If a registered compressor matches the response's encoding, it will be used.
1. Otherwise, the stream will be closed and an `Unimplemented` status error will
be returned to the application.
### Server-Side
There are two legacy functions to configure compression:
```go
func RPCCompressor(grpc.Compressor) ServerOption {}
func RPCDecompressor(grpc.Decompressor) ServerOption {}
```
For incoming requests, the following rules are applied in order:
1. If `RPCDecompressor` is used and that decompressor matches the request's
encoding: it will be used.
1. If a registered compressor matches the request's encoding, it will be used.
1. Otherwise, an `Unimplemented` status will be returned to the client.
For outgoing responses, the following rules are applied in order:
1. If `RPCCompressor` is used, that compressor will be used to compress all
response messages.
1. If compression was used for the incoming request and a registered compressor
supports it, that same compression method will be used for the outgoing
response.
1. Otherwise, no compression will be used for the outgoing response.

29
call.go
View File

@ -61,7 +61,17 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
if c.maxReceiveMessageSize == nil { if c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
} }
if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(c.compressorType)); err != nil {
// Set dc if it exists and matches the message compression type used,
// otherwise set comp if a registered compressor exists for it.
var comp encoding.Compressor
var dc Decompressor
if rc := stream.RecvCompress(); dopts.dc != nil && dopts.dc.Type() == rc {
dc = dopts.dc
} else if rc != "" && rc != encoding.Identity {
comp = encoding.GetCompressor(rc)
}
if err = recv(p, dopts.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -95,10 +105,18 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
Client: true, Client: true,
} }
} }
if c.compressorType != "" && encoding.GetCompressor(c.compressorType) == nil { // Set comp and clear compressor if a registered compressor matches the type
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", c.compressorType) // specified via UseCompressor. (And error if a matching compressor is not
// registered.)
var comp encoding.Compressor
if ct := c.compressorType; ct != "" && ct != encoding.Identity {
compressor = nil // Disable the legacy compressor.
comp = encoding.GetCompressor(ct)
if comp == nil {
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct)
}
} }
hdr, data, err := encode(dopts.codec, args, compressor, outPayload, encoding.GetCompressor(c.compressorType)) hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp)
if err != nil { if err != nil {
return err return err
} }
@ -211,9 +229,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
} }
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil { if c.creds != nil {
callHdr.Creds = c.creds callHdr.Creds = c.creds
} }

View File

@ -107,16 +107,6 @@ const (
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
type DialOption func(*dialOptions) type DialOption func(*dialOptions)
// UseCompressor returns a CallOption which sets the compressor used when sending the request.
// If WithCompressor is set, UseCompressor has higher priority.
// This API is EXPERIMENTAL.
func UseCompressor(name string) CallOption {
return beforeCall(func(c *callInfo) error {
c.compressorType = name
return nil
})
}
// WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched // WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched
// before doing a write on the wire. // before doing a write on the wire.
func WithWriteBufferSize(s int) DialOption { func WithWriteBufferSize(s int) DialOption {
@ -168,18 +158,26 @@ func WithCodec(c Codec) DialOption {
} }
} }
// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message // WithCompressor returns a DialOption which sets a Compressor to use for
// compressor. It has lower priority than the compressor set by RegisterCompressor. // message compression. It has lower priority than the compressor set by
// This function is deprecated. // the UseCompressor CallOption.
//
// Deprecated: use UseCompressor instead.
func WithCompressor(cp Compressor) DialOption { func WithCompressor(cp Compressor) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.cp = cp o.cp = cp
} }
} }
// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating // WithDecompressor returns a DialOption which sets a Decompressor to use for
// message decompressor. It has higher priority than the decompressor set by RegisterCompressor. // incoming message decompression. If incoming response messages are encoded
// This function is deprecated. // using the decompressor's Type(), it will be used. Otherwise, the message
// encoding will be used to look up the compressor registered via
// encoding.RegisterCompressor, which will then be used to decompress the
// message. If no compressor is registered for the encoding, an Unimplemented
// status error will be returned.
//
// Deprecated: use encoding.RegisterCompressor instead.
func WithDecompressor(dc Decompressor) DialOption { func WithDecompressor(dc Decompressor) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.dc = dc o.dc = dc

View File

@ -55,3 +55,7 @@ func RegisterCompressor(c Compressor) {
func GetCompressor(name string) Compressor { func GetCompressor(name string) Compressor {
return registerCompressor[name] return registerCompressor[name]
} }
// Identity specifies the optional encoding for uncompressed streams.
// It is intended for grpc internal use only.
const Identity = "identity"

View File

@ -236,6 +236,18 @@ func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
}) })
} }
// UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has
// higher priority.
//
// This API is EXPERIMENTAL.
func UseCompressor(name string) CallOption {
return beforeCall(func(c *callInfo) error {
c.compressorType = name
return nil
})
}
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8
@ -359,22 +371,26 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa
return bufHeader, b, nil return bufHeader, b, nil
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error { func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
switch pf { switch pf {
case compressionNone: case compressionNone:
case compressionMade: case compressionMade:
if (dc == nil || recvCompress != dc.Type()) && encoding.GetCompressor(recvCompress) == nil { if recvCompress == "" || recvCompress == encoding.Identity {
return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} }
default: default:
return Errorf(codes.Internal, "grpc: received unexpected payload format %d", pf) return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
} }
return nil return nil
} }
// TODO(ddyihai): eliminate extra Compressor parameter. // For the two compressor parameters, both should not be set, but if they are,
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, // dc takes precedence over compressor.
inPayload *stats.InPayload, compressor encoding.Compressor) error { // TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize) pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return err return err
@ -382,9 +398,11 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if inPayload != nil { if inPayload != nil {
inPayload.WireLength = len(d) inPayload.WireLength = len(d)
} }
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return st.Err()
} }
if pf == compressionMade { if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default. // use this decompressor as the default.

137
server.go
View File

@ -191,18 +191,24 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. // RPCCompressor returns a ServerOption that sets a compressor for outbound
// It has lower priority than the compressor set by RegisterCompressor. // messages. For backward compatibility, all outbound messages will be sent
// This function is deprecated. // using this compressor, regardless of incoming message compression. By
// default, server messages will be sent using the same compressor with which
// request messages were sent.
//
// Deprecated: use encoding.RegisterCompressor instead.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. // RPCDecompressor returns a ServerOption that sets a decompressor for inbound
// It has higher priority than the decompressor set by RegisterCompressor. // messages. It has higher priority than decompressors registered via
// This function is deprecated. // encoding.RegisterCompressor.
//
// Deprecated: use encoding.RegisterCompressor instead.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
@ -725,20 +731,14 @@ func (s *Server) removeConn(c io.Closer) {
} }
} }
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
var ( var (
outPayload *stats.OutPayload outPayload *stats.OutPayload
) )
if s.opts.statsHandler != nil { if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
if stream.RecvCompress() != "" { hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp)
// Server receives compressor, check compressor set by register and default.
if encoding.GetCompressor(stream.RecvCompress()) == nil && (cp == nil || cp != nil && cp.Type() != stream.RecvCompress()) {
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", stream.RecvCompress())
}
}
hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, encoding.GetCompressor(stream.RecvCompress()))
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err) grpclog.Errorln("grpc: server failed to encode response: ", err)
return err return err
@ -782,12 +782,43 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
}() }()
} }
if stream.RecvCompress() != "" {
stream.SetSendCompress(stream.RecvCompress()) // comp and cp are used for compression. decomp and dc are used for
} else if s.opts.cp != nil { // decompression. If comp and decomp are both set, they are the same;
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. // however they are kept separate to ensure that at most one of the
stream.SetSendCompress(s.opts.cp.Type()) // compressor/decompressor variable pairs are set for use later.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
dc = s.opts.dc
} else if rc != "" && rc != encoding.Identity {
decomp = encoding.GetCompressor(rc)
if decomp == nil {
st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
t.WriteStatus(stream, st)
return st.Err()
}
} }
// If cp is set, use it. Otherwise, attempt to compress the response using
// the incoming message compression method.
//
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
}
}
p := &parser{r: stream} p := &parser{r: stream}
pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
if err == io.EOF { if err == io.EOF {
@ -816,18 +847,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
return err return err
} }
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
if st, ok := status.FromError(err); ok { if e := t.WriteStatus(stream, st); e != nil {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return err
}
if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
} }
return st.Err()
// TODO checkRecvPayload always return RPC error. Add a return here if necessary.
} }
var inPayload *stats.InPayload var inPayload *stats.InPayload
if sh != nil { if sh != nil {
@ -841,14 +865,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
if pf == compressionMade { if pf == compressionMade {
var err error var err error
if s.opts.dc != nil { if dc != nil {
req, err = s.opts.dc.Do(bytes.NewReader(req)) req, err = dc.Do(bytes.NewReader(req))
if err != nil { if err != nil {
return Errorf(codes.Internal, err.Error()) return Errorf(codes.Internal, err.Error())
} }
} else { } else {
dcReader := encoding.GetCompressor(stream.RecvCompress()) tmp, _ := decomp.Decompress(bytes.NewReader(req))
tmp, _ := dcReader.Decompress(bytes.NewReader(req))
req, err = ioutil.ReadAll(tmp) req, err = ioutil.ReadAll(tmp)
if err != nil { if err != nil {
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
@ -898,7 +921,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -947,24 +971,45 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
sh.HandleRPC(stream.Context(), end) sh.HandleRPC(stream.Context(), end)
}() }()
} }
if stream.RecvCompress() != "" {
stream.SetSendCompress(stream.RecvCompress())
} else if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream}, p: &parser{r: stream},
codec: s.opts.codec, codec: s.opts.codec,
cpType: stream.RecvCompress(),
cp: s.opts.cp,
dc: s.opts.dc,
maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize,
trInfo: trInfo, trInfo: trInfo,
statsHandler: sh, statsHandler: sh,
} }
// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
ss.dc = s.opts.dc
} else if rc != "" && rc != encoding.Identity {
ss.decomp = encoding.GetCompressor(rc)
if ss.decomp == nil {
st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
t.WriteStatus(ss.s, st)
return st.Err()
}
}
// If cp is set, use it. Otherwise, attempt to compress the response using
// the incoming message compression method.
//
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
}
}
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&trInfo.firstLine, false) trInfo.tr.LazyLog(&trInfo.firstLine, false)
defer func() { defer func() {

View File

@ -83,7 +83,7 @@ func TestParseLoadBalancer(t *testing.T) {
} }
} }
func TestPraseWaitForReady(t *testing.T) { func TestParseWaitForReady(t *testing.T) {
testcases := []struct { testcases := []struct {
scjs string scjs string
wantSC ServiceConfig wantSC ServiceConfig

View File

@ -151,10 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// time soon, so we ask the transport to flush the header. // time soon, so we ask the transport to flush the header.
Flush: desc.ClientStreams, Flush: desc.ClientStreams,
} }
if c.compressorType != "" {
callHdr.SendCompress = c.compressorType // Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package.
// Otherwise, use the compressor configured by the WithCompressor DialOption,
// if set.
var cp Compressor
var comp encoding.Compressor
if ct := c.compressorType; ct != "" {
callHdr.SendCompress = ct
if ct != encoding.Identity {
comp = encoding.GetCompressor(ct)
if comp == nil {
return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct)
}
}
} else if cc.dopts.cp != nil { } else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
cp = cc.dopts.cp
} }
if c.creds != nil { if c.creds != nil {
callHdr.Creds = c.creds callHdr.Creds = c.creds
@ -241,9 +255,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
c: c, c: c,
desc: desc, desc: desc,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cpType: c.compressorType, cp: cp,
cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
comp: comp,
cancel: cancel, cancel: cancel,
done: done, done: done,
@ -285,16 +299,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption opts []CallOption
c *callInfo c *callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
desc *StreamDesc desc *StreamDesc
codec Codec
cpType string codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
comp encoding.Compressor
decomp encoding.Compressor
decompSet bool
cancel context.CancelFunc cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
@ -370,10 +388,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
if cs.cpType != "" && encoding.GetCompressor(cs.cpType) == nil { hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", cs.cpType)
}
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, encoding.GetCompressor(cs.cpType))
if err != nil { if err != nil {
return err return err
} }
@ -401,7 +416,23 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
if cs.c.maxReceiveMessageSize == nil { if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
} }
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(cs.cpType)) if !cs.decompSet {
// Block until we receive headers containing received message encoding.
if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
if cs.dc == nil || cs.dc.Type() != ct {
// No configured decompressor, or it does not match the incoming
// message encoding; attempt to find a registered compressor that does.
cs.dc = nil
cs.decomp = encoding.GetCompressor(ct)
}
} else {
// No compression is used; disable our decompressor.
cs.dc = nil
}
// Only initialize this state once per stream.
cs.decompSet = true
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -427,7 +458,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
if cs.c.maxReceiveMessageSize == nil { if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
} }
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, encoding.GetCompressor(cs.cpType)) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -552,13 +583,16 @@ type ServerStream interface {
// serverStream implements a server side Stream. // serverStream implements a server side Stream.
type serverStream struct { type serverStream struct {
t transport.ServerTransport t transport.ServerTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
codec Codec codec Codec
cpType string
cp Compressor cp Compressor
dc Decompressor dc Decompressor
comp encoding.Compressor
decomp encoding.Compressor
maxReceiveMessageSize int maxReceiveMessageSize int
maxSendMessageSize int maxSendMessageSize int
trInfo *traceInfo trInfo *traceInfo
@ -614,12 +648,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
if ss.cpType != "" { hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
if encoding.GetCompressor(ss.cpType) == nil && (ss.cp == nil || ss.cp != nil && ss.cp.Type() != ss.cpType) {
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ss.cpType)
}
}
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, encoding.GetCompressor(ss.cpType))
if err != nil { if err != nil {
return err return err
} }
@ -659,7 +688,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
inPayload = &stats.InPayload{} inPayload = &stats.InPayload{}
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, encoding.GetCompressor(ss.cpType)); err != nil { if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
if err == io.EOF { if err == io.EOF {
return err return err
} }

View File

@ -3886,9 +3886,68 @@ func testCompressOK(t *testing.T, e env) {
if err := stream.Send(sreq); err != nil { if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err) t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
} }
stream.CloseSend()
if _, err := stream.Recv(); err != nil { if _, err := stream.Recv(); err != nil {
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err) t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
} }
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
}
}
func TestIdentityEncoding(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
testIdentityEncoding(t, e)
}
}
func testIdentityEncoding(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
// Unary call
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 5)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseSize: 10,
Payload: payload,
}
ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something"))
if _, err := tc.UnaryCall(ctx, req); err != nil {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
}
// Streaming RPC
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := tc.FullDuplexCall(ctx, grpc.UseCompressor("identity"))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseParameters: []*testpb.ResponseParameters{{Size: 10}},
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
stream.CloseSend()
if _, err := stream.Recv(); err != nil {
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
}
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
}
} }
func TestUnaryClientInterceptor(t *testing.T) { func TestUnaryClientInterceptor(t *testing.T) {

View File

@ -252,6 +252,9 @@ type Stream struct {
// RecvCompress returns the compression algorithm applied to the inbound // RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied. // message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string { func (s *Stream) RecvCompress() string {
if s.headerChan != nil {
<-s.headerChan
}
return s.recvCompress return s.recvCompress
} }
@ -528,10 +531,6 @@ type CallHdr struct {
// Method specifies the operation to perform. // Method specifies the operation to perform.
Method string Method string
// RecvCompress specifies the compression algorithm applied on
// inbound messages.
RecvCompress string
// SendCompress specifies the compression algorithm applied on // SendCompress specifies the compression algorithm applied on
// outbound message. // outbound message.
SendCompress string SendCompress string

7
vet.sh
View File

@ -76,5 +76,10 @@ if [[ "$check_proto" = "true" ]]; then
fi fi
# TODO(menghanl): fix errors in transport_test. # TODO(menghanl): fix errors in transport_test.
staticcheck -ignore google.golang.org/grpc/transport/transport_test.go:SA2002 ./... staticcheck -ignore '
google.golang.org/grpc/transport/transport_test.go:SA2002
google.golang.org/grpc/benchmark/benchmain/main.go:SA1019
google.golang.org/grpc/stats/stats_test.go:SA1019
google.golang.org/grpc/test/end2end_test.go:SA1019
' ./...
misspell -error . misspell -error .