From 5db344a40acb427635d1a2da686decd1ad6b7835 Mon Sep 17 00:00:00 2001 From: Zhouyihai Ding Date: Tue, 31 Oct 2017 10:21:13 -0700 Subject: [PATCH] Introduce new Compressor/Decompressor API (#1428) --- call.go | 17 +++--- clientconn.go | 16 +++++- encoding/encoding.go | 57 ++++++++++++++++++ encoding/gzip/gzip.go | 93 ++++++++++++++++++++++++++++++ rpc_util.go | 61 +++++++++++++++----- server.go | 56 ++++++++++++------ stream.go | 27 ++++++--- test/end2end_test.go | 131 +++++++++++++++++++++++++++++++++++++----- 8 files changed, 396 insertions(+), 62 deletions(-) create mode 100644 encoding/encoding.go create mode 100644 encoding/gzip/gzip.go diff --git a/call.go b/call.go index 4bd673cd..368980d4 100644 --- a/call.go +++ b/call.go @@ -19,7 +19,6 @@ package grpc import ( - "bytes" "io" "time" @@ -27,6 +26,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" @@ -62,7 +62,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran if c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(c.compressorType)); err != nil { if err == io.EOF { break } @@ -89,18 +89,17 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, } }() var ( - cbuf *bytes.Buffer outPayload *stats.OutPayload ) - if compressor != nil { - cbuf = new(bytes.Buffer) - } if dopts.copts.StatsHandler != nil { outPayload = &stats.OutPayload{ Client: true, } } - hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload) + if c.compressorType != "" && encoding.GetCompressor(c.compressorType) == nil { + return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", c.compressorType) + } + hdr, data, err := encode(dopts.codec, args, compressor, outPayload, encoding.GetCompressor(c.compressorType)) if err != nil { return err } @@ -223,7 +222,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Host: cc.authority, Method: method, } - if cc.dopts.cp != nil { + if c.compressorType != "" { + callHdr.SendCompress = c.compressorType + } else if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } if c.creds != nil { diff --git a/clientconn.go b/clientconn.go index 5f5aac41..2c5d8ee8 100644 --- a/clientconn.go +++ b/clientconn.go @@ -104,6 +104,16 @@ const ( // DialOption configures how we set up the connection. type DialOption func(*dialOptions) +// UseCompressor returns a CallOption which sets the compressor used when sending the request. +// If WithCompressor is set, UseCompressor has higher priority. +// This API is EXPERIMENTAL. +func UseCompressor(name string) CallOption { + return beforeCall(func(c *callInfo) error { + c.compressorType = name + return nil + }) +} + // WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched // before doing a write on the wire. func WithWriteBufferSize(s int) DialOption { @@ -156,7 +166,8 @@ func WithCodec(c Codec) DialOption { } // WithCompressor returns a DialOption which sets a CompressorGenerator for generating message -// compressor. +// compressor. It has lower priority than the compressor set by RegisterCompressor. +// This function is deprecated. func WithCompressor(cp Compressor) DialOption { return func(o *dialOptions) { o.cp = cp @@ -164,7 +175,8 @@ func WithCompressor(cp Compressor) DialOption { } // WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating -// message decompressor. +// message decompressor. It has higher priority than the decompressor set by RegisterCompressor. +// This function is deprecated. func WithDecompressor(dc Decompressor) DialOption { return func(o *dialOptions) { o.dc = dc diff --git a/encoding/encoding.go b/encoding/encoding.go new file mode 100644 index 00000000..f6cc3d66 --- /dev/null +++ b/encoding/encoding.go @@ -0,0 +1,57 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package encoding defines the interface for the compressor and the functions +// to register and get the compossor. +// This package is EXPERIMENTAL. +package encoding + +import ( + "io" +) + +var registerCompressor = make(map[string]Compressor) + +// Compressor is used for compressing and decompressing when sending or receiving messages. +type Compressor interface { + // Compress writes the data written to wc to w after compressing it. If an error + // occurs while initializing the compressor, that error is returned instead. + Compress(w io.Writer) (io.WriteCloser, error) + // Decompress reads data from r, decompresses it, and provides the uncompressed data + // via the returned io.Reader. If an error occurs while initializing the decompressor, that error + // is returned instead. + Decompress(r io.Reader) (io.Reader, error) + // Name is the name of the compression codec and is used to set the content coding header. + Name() string +} + +// RegisterCompressor registers the compressor with gRPC by its name. It can be activated when +// sending an RPC via grpc.UseCompressor(). It will be automatically accessed when receiving a +// message based on the content coding header. Servers also use it to send a response with the +// same encoding as the request. +// +// NOTE: this function must only be called during initialization time (i.e. in an init() function). If +// multiple Compressors are registered with the same name, the one registered last will take effect. +func RegisterCompressor(c Compressor) { + registerCompressor[c.Name()] = c +} + +// GetCompressor returns Compressor for the given compressor name. +func GetCompressor(name string) Compressor { + return registerCompressor[name] +} diff --git a/encoding/gzip/gzip.go b/encoding/gzip/gzip.go new file mode 100644 index 00000000..fb4385eb --- /dev/null +++ b/encoding/gzip/gzip.go @@ -0,0 +1,93 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package gzip implements and registers the gzip compressor +// during the initialization. +// This package is EXPERIMENTAL. +package gzip + +import ( + "compress/gzip" + "io" + "io/ioutil" + "sync" + + "google.golang.org/grpc/encoding" +) + +func init() { + c := &compressor{} + c.poolCompressor.New = func() interface{} { + return &writer{Writer: gzip.NewWriter(ioutil.Discard), pool: &c.poolCompressor} + } + encoding.RegisterCompressor(c) +} + +type writer struct { + *gzip.Writer + pool *sync.Pool +} + +func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) { + z := c.poolCompressor.Get().(*writer) + z.Writer.Reset(w) + return z, nil +} + +func (z *writer) Close() error { + defer z.pool.Put(z) + return z.Writer.Close() +} + +type reader struct { + *gzip.Reader + pool *sync.Pool +} + +func (c *compressor) Decompress(r io.Reader) (io.Reader, error) { + z, inPool := c.poolDecompressor.Get().(*reader) + if !inPool { + newZ, err := gzip.NewReader(r) + if err != nil { + return nil, err + } + return &reader{Reader: newZ, pool: &c.poolDecompressor}, nil + } + if err := z.Reset(r); err != nil { + c.poolDecompressor.Put(z) + return nil, err + } + return z, nil +} + +func (z *reader) Read(p []byte) (n int, err error) { + n, err = z.Reader.Read(p) + if err == io.EOF { + z.pool.Put(z) + } + return n, err +} + +func (c *compressor) Name() string { + return "gzip" +} + +type compressor struct { + poolCompressor sync.Pool + poolDecompressor sync.Pool +} diff --git a/rpc_util.go b/rpc_util.go index 9c8d8819..eccf84de 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -31,6 +31,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -122,6 +123,7 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { + compressorType string failFast bool headerMD metadata.MD trailerMD metadata.MD @@ -294,13 +296,16 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt // encode serializes msg and returns a buffer of message header and a buffer of msg. // If msg is nil, it generates the message header and an empty msg buffer. -func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) { - var b []byte +// TODO(ddyihai): eliminate extra Compressor parameter. +func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { + var ( + b []byte + cbuf *bytes.Buffer + ) const ( payloadLen = 1 sizeLen = 4 ) - if msg != nil { var err error b, err = c.Marshal(msg) @@ -313,24 +318,35 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl outPayload.Data = b outPayload.Length = len(b) } - if cp != nil { - if err := cp.Do(cbuf, b); err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + if compressor != nil || cp != nil { + cbuf = new(bytes.Buffer) + // Has compressor, check Compressor is set by UseCompressor first. + if compressor != nil { + z, _ := compressor.Compress(cbuf) + if _, err := z.Write(b); err != nil { + return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + } + z.Close() + } else { + // If Compressor is not set by UseCompressor, use default Compressor + if err := cp.Do(cbuf, b); err != nil { + return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + } } b = cbuf.Bytes() } } - if uint(len(b)) > math.MaxUint32 { return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } bufHeader := make([]byte, payloadLen+sizeLen) - if cp == nil { - bufHeader[0] = byte(compressionNone) - } else { + if compressor != nil || cp != nil { bufHeader[0] = byte(compressionMade) + } else { + bufHeader[0] = byte(compressionNone) } + // Write length of b into buf binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) if outPayload != nil { @@ -343,7 +359,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er switch pf { case compressionNone: case compressionMade: - if dc == nil || recvCompress != dc.Type() { + if (dc == nil || recvCompress != dc.Type()) && encoding.GetCompressor(recvCompress) == nil { return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } default: @@ -352,7 +368,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error { +// TODO(ddyihai): eliminate extra Compressor parameter. +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, + inPayload *stats.InPayload, compressor encoding.Compressor) error { pf, d, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return err @@ -364,9 +382,22 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ return err } if pf == compressionMade { - d, err = dc.Do(bytes.NewReader(d)) - if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, + // use this decompressor as the default. + if dc != nil { + d, err = dc.Do(bytes.NewReader(d)) + if err != nil { + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } + } else { + dcReader, err := compressor.Decompress(bytes.NewReader(d)) + if err != nil { + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } + d, err = ioutil.ReadAll(dcReader) + if err != nil { + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } } } if len(d) > maxReceiveMessageSize { diff --git a/server.go b/server.go index def301a1..2c26db8c 100644 --- a/server.go +++ b/server.go @@ -32,11 +32,14 @@ import ( "sync" "time" + "io/ioutil" + "golang.org/x/net/context" "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" "google.golang.org/grpc/keepalive" @@ -187,6 +190,8 @@ func CustomCodec(codec Codec) ServerOption { } // RPCCompressor returns a ServerOption that sets a compressor for outbound messages. +// It has lower priority than the compressor set by RegisterCompressor. +// This function is deprecated. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp @@ -194,6 +199,8 @@ func RPCCompressor(cp Compressor) ServerOption { } // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. +// It has higher priority than the decompressor set by RegisterCompressor. +// This function is deprecated. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc @@ -701,16 +708,18 @@ func (s *Server) removeConn(c io.Closer) { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { var ( - cbuf *bytes.Buffer outPayload *stats.OutPayload ) - if cp != nil { - cbuf = new(bytes.Buffer) - } if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) + if stream.RecvCompress() != "" { + // Server receives compressor, check compressor set by register and default. + if encoding.GetCompressor(stream.RecvCompress()) == nil && (cp == nil || cp != nil && cp.Type() != stream.RecvCompress()) { + return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", stream.RecvCompress()) + } + } + hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, encoding.GetCompressor(stream.RecvCompress())) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err @@ -754,7 +763,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } - if s.opts.cp != nil { + if stream.RecvCompress() != "" { + stream.SetSendCompress(stream.RecvCompress()) + } else if s.opts.cp != nil { // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. stream.SetSendCompress(s.opts.cp.Type()) } @@ -786,7 +797,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } - if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { if st, ok := status.FromError(err); ok { if e := t.WriteStatus(stream, st); e != nil { @@ -812,9 +822,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if pf == compressionMade { var err error - req, err = s.opts.dc.Do(bytes.NewReader(req)) - if err != nil { - return Errorf(codes.Internal, err.Error()) + if s.opts.dc != nil { + req, err = s.opts.dc.Do(bytes.NewReader(req)) + if err != nil { + return Errorf(codes.Internal, err.Error()) + } + } else { + dcReader := encoding.GetCompressor(stream.RecvCompress()) + tmp, _ := dcReader.Decompress(bytes.NewReader(req)) + req, err = ioutil.ReadAll(tmp) + if err != nil { + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } } } if len(req) > s.opts.maxReceiveMessageSize { @@ -909,16 +928,19 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp sh.HandleRPC(stream.Context(), end) }() } - if s.opts.cp != nil { + if stream.RecvCompress() != "" { + stream.SetSendCompress(stream.RecvCompress()) + } else if s.opts.cp != nil { stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cpType: stream.RecvCompress(), + cp: s.opts.cp, + dc: s.opts.dc, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, diff --git a/stream.go b/stream.go index b58f7f8d..a659f14e 100644 --- a/stream.go +++ b/stream.go @@ -19,7 +19,6 @@ package grpc import ( - "bytes" "errors" "io" "sync" @@ -29,6 +28,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -151,7 +151,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // time soon, so we ask the transport to flush the header. Flush: desc.ClientStreams, } - if cc.dopts.cp != nil { + if c.compressorType != "" { + callHdr.SendCompress = c.compressorType + } else if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } if c.creds != nil { @@ -242,6 +244,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth c: c, desc: desc, codec: cc.dopts.codec, + cpType: c.compressorType, cp: cc.dopts.cp, dc: cc.dopts.dc, cancel: cancel, @@ -292,6 +295,7 @@ type clientStream struct { p *parser desc *StreamDesc codec Codec + cpType string cp Compressor dc Decompressor cancel context.CancelFunc @@ -369,7 +373,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, } } - hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload) + if cs.cpType != "" && encoding.GetCompressor(cs.cpType) == nil { + return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", cs.cpType) + } + hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, encoding.GetCompressor(cs.cpType)) if err != nil { return err } @@ -397,7 +404,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(cs.cpType)) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -423,7 +430,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, encoding.GetCompressor(cs.cpType)) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -552,6 +559,7 @@ type serverStream struct { s *transport.Stream p *parser codec Codec + cpType string cp Compressor dc Decompressor maxReceiveMessageSize int @@ -609,7 +617,12 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload) + if ss.cpType != "" { + if encoding.GetCompressor(ss.cpType) == nil && (ss.cp == nil || ss.cp != nil && ss.cp.Type() != ss.cpType) { + return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ss.cpType) + } + } + hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, encoding.GetCompressor(ss.cpType)) if err != nil { return err } @@ -649,7 +662,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, encoding.GetCompressor(ss.cpType)); err != nil { if err == io.EOF { return err } diff --git a/test/end2end_test.go b/test/end2end_test.go index a5e81a99..1518dcab 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -51,6 +51,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + _ "google.golang.org/grpc/encoding/gzip" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" @@ -437,18 +438,24 @@ type test struct { cancel context.CancelFunc // Configurable knobs, after newTest returns: - testServer testpb.TestServiceServer // nil means none - healthServer *health.Server // nil means disabled - maxStream uint32 - tapHandle tap.ServerInHandle - maxMsgSize *int - maxClientReceiveMsgSize *int - maxClientSendMsgSize *int - maxServerReceiveMsgSize *int - maxServerSendMsgSize *int - userAgent string - clientCompression bool - serverCompression bool + testServer testpb.TestServiceServer // nil means none + healthServer *health.Server // nil means disabled + maxStream uint32 + tapHandle tap.ServerInHandle + maxMsgSize *int + maxClientReceiveMsgSize *int + maxClientSendMsgSize *int + maxServerReceiveMsgSize *int + maxServerSendMsgSize *int + userAgent string + // clientCompression and serverCompression are set to test the deprecated API + // WithCompressor and WithDecompressor. + clientCompression bool + serverCompression bool + // clientUseCompression is set to test the new compressor registration API UseCompressor. + clientUseCompression bool + // clientNopCompression is set to create a compressor whose type is not supported. + clientNopCompression bool unaryClientInt grpc.UnaryClientInterceptor streamClientInt grpc.StreamClientInterceptor unaryServerInt grpc.UnaryServerInterceptor @@ -594,6 +601,32 @@ func (te *test) startServer(ts testpb.TestServiceServer) { te.srvAddr = addr } +type nopCompressor struct { + grpc.Compressor +} + +// NewNopCompressor creates a compressor to test the case that type is not supported. +func NewNopCompressor() grpc.Compressor { + return &nopCompressor{grpc.NewGZIPCompressor()} +} + +func (c *nopCompressor) Type() string { + return "nop" +} + +type nopDecompressor struct { + grpc.Decompressor +} + +// NewNopDecompressor creates a decompressor to test the case that type is not supported. +func NewNopDecompressor() grpc.Decompressor { + return &nopDecompressor{grpc.NewGZIPDecompressor()} +} + +func (d *nopDecompressor) Type() string { + return "nop" +} + func (te *test) clientConn() *grpc.ClientConn { if te.cc != nil { return te.cc @@ -613,6 +646,15 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } + if te.clientUseCompression { + opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip"))) + } + if te.clientNopCompression { + opts = append(opts, + grpc.WithCompressor(NewNopCompressor()), + grpc.WithDecompressor(NewNopDecompressor()), + ) + } if te.unaryClientInt != nil { opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt)) } @@ -3749,7 +3791,8 @@ func TestCompressServerHasNoSupport(t *testing.T) { func testCompressServerHasNoSupport(t *testing.T, e env) { te := newTest(t, e) te.serverCompression = false - te.clientCompression = true + te.clientCompression = false + te.clientNopCompression = true te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -5572,3 +5615,65 @@ func TestMethodFromServerStream(t *testing.T) { t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod) } } + +func TestCompressorRegister(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + testCompressorRegister(t, e) + } +} + +func testCompressorRegister(t *testing.T, e env) { + te := newTest(t, e) + te.clientCompression = false + te.serverCompression = false + te.clientUseCompression = true + + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + // Unary call + const argSize = 271828 + const respSize = 314159 + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: respSize, + Payload: payload, + } + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something")) + if _, err := tc.UnaryCall(ctx, req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) + } + // Streaming RPC + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: 31415, + }, + } + payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415)) + if err != nil { + t.Fatal(err) + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } +}