From af8888dc8d774997f0173bee508daa355dd40e18 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Fri, 29 Jan 2016 14:38:20 -0800 Subject: [PATCH] remove Compressor/DecompressorGenerator --- call.go | 12 ++++-------- clientconn.go | 12 ++++++------ rpc_util.go | 12 +----------- server.go | 40 ++++++++++++++++------------------------ stream.go | 26 +++++++++++--------------- test/end2end_test.go | 16 ++++++++-------- 6 files changed, 46 insertions(+), 72 deletions(-) diff --git a/call.go b/call.go index f29396ae..d4ae68be 100644 --- a/call.go +++ b/call.go @@ -57,7 +57,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s } p := &parser{s: stream} for { - if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err == io.EOF { break } @@ -133,11 +133,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } var ( lastErr error // record the error that happened - cp Compressor ) - if cc.dopts.cg != nil { - cp = cc.dopts.cg() - } for { var ( err error @@ -152,8 +148,8 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Host: cc.authority, Method: method, } - if cp != nil { - callHdr.SendCompress = cp.Type() + if cc.dopts.cp != nil { + callHdr.SendCompress = cc.dopts.cp.Type() } t, err = cc.dopts.picker.Pick(ctx) if err != nil { @@ -166,7 +162,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } - stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts) + stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) if err != nil { if _, ok := err.(transport.ConnectionError); ok { lastErr = err diff --git a/clientconn.go b/clientconn.go index 038ed884..28e74da8 100644 --- a/clientconn.go +++ b/clientconn.go @@ -73,8 +73,8 @@ var ( // values passed to Dial. type dialOptions struct { codec Codec - cg CompressorGenerator - dg DecompressorGenerator + cp Compressor + dc Decompressor picker Picker block bool insecure bool @@ -93,17 +93,17 @@ func WithCodec(c Codec) DialOption { // WithCompressor returns a DialOption which sets a CompressorGenerator for generating message // compressor. -func WithCompressor(f CompressorGenerator) DialOption { +func WithCompressor(cp Compressor) DialOption { return func(o *dialOptions) { - o.cg = f + o.cp = cp } } // WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating // message decompressor. -func WithDecompressor(f DecompressorGenerator) DialOption { +func WithDecompressor(dc Decompressor) DialOption { return func(o *dialOptions) { - o.dg = f + o.dc = dc } } diff --git a/rpc_util.go b/rpc_util.go index 427b49e0..e98ddbcd 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -135,12 +135,6 @@ func (d *gzipDecompressor) Type() string { return "gzip" } -// CompressorGenerator defines the function generating a Compressor. -type CompressorGenerator func() Compressor - -// DecompressorGenerator defines the function generating a Decompressor. -type DecompressorGenerator func() Decompressor - // callInfo contains all related configuration and information about an RPC. type callInfo struct { failFast bool @@ -290,15 +284,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error { +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { pf, d, err := p.recvMsg() if err != nil { return err } - var dc Decompressor - if pf == compressionMade && dg != nil { - dc = dg() - } if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil { return err } diff --git a/server.go b/server.go index dd864275..904a66a7 100644 --- a/server.go +++ b/server.go @@ -93,8 +93,8 @@ type Server struct { type options struct { creds credentials.Credentials codec Codec - cg CompressorGenerator - dg DecompressorGenerator + cp Compressor + dc Decompressor maxConcurrentStreams uint32 } @@ -108,15 +108,15 @@ func CustomCodec(codec Codec) ServerOption { } } -func CompressON(f CompressorGenerator) ServerOption { +func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { - o.cg = f + o.cp = cp } } -func DecompressON(f DecompressorGenerator) ServerOption { +func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { - o.dg = f + o.dc = dc } } @@ -362,11 +362,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } - var dc Decompressor - if pf == compressionMade && s.opts.dg != nil { - dc = s.opts.dg() - } - if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil { + if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { switch err := err.(type) { case transport.StreamError: if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { @@ -385,7 +381,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. df := func(v interface{}) error { if pf == compressionMade { var err error - req, err = dc.Do(bytes.NewReader(req)) + req, err = s.opts.dc.Do(bytes.NewReader(req)) if err != nil { if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) @@ -427,12 +423,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Last: true, Delay: false, } - var cp Compressor - if s.opts.cg != nil { - cp = s.opts.cg() - stream.SetSendCompress(cp.Type()) + if s.opts.cp != nil { + stream.SetSendCompress(s.opts.cp.Type()) } - if err := s.sendResponse(t, stream, reply, cp, opts); err != nil { + if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { switch err := err.(type) { case transport.ConnectionError: // Nothing to do here. @@ -453,21 +447,19 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { - var cp Compressor - if s.opts.cg != nil { - cp = s.opts.cg() - stream.SetSendCompress(cp.Type()) + if s.opts.cp != nil { + stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ t: t, s: stream, p: &parser{s: stream}, codec: s.opts.codec, - cp: cp, - dg: s.opts.dg, + cp: s.opts.cp, + dc: s.opts.dc, trInfo: trInfo, } - if cp != nil { + if ss.cp != nil { ss.cbuf = new(bytes.Buffer) } if trInfo != nil { diff --git a/stream.go b/stream.go index e649c4c1..4974d8a8 100644 --- a/stream.go +++ b/stream.go @@ -105,28 +105,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if err != nil { return nil, toRPCErr(err) } - var cp Compressor - if cc.dopts.cg != nil { - cp = cc.dopts.cg() - } // TODO(zhaoq): CallOption is omitted. Add support when it is needed. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, Flush: desc.ServerStreams&&desc.ClientStreams, } - if cp != nil { - callHdr.SendCompress = cp.Type() + if cc.dopts.cp != nil { + callHdr.SendCompress = cc.dopts.cp.Type() } cs := &clientStream{ desc: desc, codec: cc.dopts.codec, - cp: cp, - dg: cc.dopts.dg, + cp: cc.dopts.cp, + dc: cc.dopts.dc, tracing: EnableTracing, } - if cp != nil { - callHdr.SendCompress = cp.Type() + if cc.dopts.cp != nil { + callHdr.SendCompress = cc.dopts.cp.Type() cs.cbuf = new(bytes.Buffer) } if cs.tracing { @@ -170,7 +166,7 @@ type clientStream struct { codec Codec cp Compressor cbuf *bytes.Buffer - dg DecompressorGenerator + dc Decompressor tracing bool // set to EnableTracing when the clientStream is created. @@ -229,7 +225,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dg, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -248,7 +244,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return } // Special handling for client streaming rpc. - err = recv(cs.p, cs.codec, cs.s, cs.dg, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -334,7 +330,7 @@ type serverStream struct { p *parser codec Codec cp Compressor - dg DecompressorGenerator + dc Decompressor cbuf *bytes.Buffer statusCode codes.Code statusDesc string @@ -402,5 +398,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, ss.s, ss.dg, m) + return recv(ss.p, ss.codec, ss.s, ss.dc, m) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 82f83735..d0e4ea2c 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -327,8 +327,8 @@ func listTestEnv() []env { return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} } -func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) { - sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)} +func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) { + sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)} la := ":0" switch e.network { case "unix": @@ -367,16 +367,16 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u return } -func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) { +func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) { var derr error if e.security == "tls" { creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { t.Fatalf("Failed to create credentials %v", err) } - cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) + cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc)) } else { - cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) + cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc)) } if derr != nil { t.Fatalf("Dial(%q) = %v", addr, derr) @@ -1151,7 +1151,7 @@ func TestCompressServerHasNoSupport(t *testing.T) { func testCompressServerHasNoSupport(t *testing.T, e env) { s, addr := serverSetUp(t, true, nil, math.MaxUint32, nil, nil, e) - cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e) + cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), nil, "", e) // Unary call tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) @@ -1203,8 +1203,8 @@ func TestCompressOK(t *testing.T) { } func testCompressOK(t *testing.T, e env) { - s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e) - cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e) + s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), e) + cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), "", e) // Unary call tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc)