remove Compressor/DecompressorGenerator

This commit is contained in:
iamqizhao
2016-01-29 14:38:20 -08:00
parent 3f30c980d6
commit af8888dc8d
6 changed files with 46 additions and 72 deletions

12
call.go
View File

@ -57,7 +57,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
} }
p := &parser{s: stream} p := &parser{s: stream}
for { for {
if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -133,11 +133,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
var ( var (
lastErr error // record the error that happened lastErr error // record the error that happened
cp Compressor
) )
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
for { for {
var ( var (
err error err error
@ -152,8 +148,8 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
} }
if cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
t, err = cc.dopts.picker.Pick(ctx) t, err = cc.dopts.picker.Pick(ctx)
if err != nil { if err != nil {
@ -166,7 +162,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }
stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts) stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil { if err != nil {
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
lastErr = err lastErr = err

View File

@ -73,8 +73,8 @@ var (
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
codec Codec codec Codec
cg CompressorGenerator cp Compressor
dg DecompressorGenerator dc Decompressor
picker Picker picker Picker
block bool block bool
insecure bool insecure bool
@ -93,17 +93,17 @@ func WithCodec(c Codec) DialOption {
// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message // WithCompressor returns a DialOption which sets a CompressorGenerator for generating message
// compressor. // compressor.
func WithCompressor(f CompressorGenerator) DialOption { func WithCompressor(cp Compressor) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.cg = f o.cp = cp
} }
} }
// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating // WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating
// message decompressor. // message decompressor.
func WithDecompressor(f DecompressorGenerator) DialOption { func WithDecompressor(dc Decompressor) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.dg = f o.dc = dc
} }
} }

View File

@ -135,12 +135,6 @@ func (d *gzipDecompressor) Type() string {
return "gzip" return "gzip"
} }
// CompressorGenerator defines the function generating a Compressor.
type CompressorGenerator func() Compressor
// DecompressorGenerator defines the function generating a Decompressor.
type DecompressorGenerator func() Decompressor
// callInfo contains all related configuration and information about an RPC. // callInfo contains all related configuration and information about an RPC.
type callInfo struct { type callInfo struct {
failFast bool failFast bool
@ -290,15 +284,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg()
if err != nil { if err != nil {
return err return err
} }
var dc Decompressor
if pf == compressionMade && dg != nil {
dc = dg()
}
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil { if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err return err
} }

View File

@ -93,8 +93,8 @@ type Server struct {
type options struct { type options struct {
creds credentials.Credentials creds credentials.Credentials
codec Codec codec Codec
cg CompressorGenerator cp Compressor
dg DecompressorGenerator dc Decompressor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
} }
@ -108,15 +108,15 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
func CompressON(f CompressorGenerator) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cg = f o.cp = cp
} }
} }
func DecompressON(f DecompressorGenerator) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dg = f o.dc = dc
} }
} }
@ -362,11 +362,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
var dc Decompressor if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if pf == compressionMade && s.opts.dg != nil {
dc = s.opts.dg()
}
if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil {
switch err := err.(type) { switch err := err.(type) {
case transport.StreamError: case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
@ -385,7 +381,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
df := func(v interface{}) error { df := func(v interface{}) error {
if pf == compressionMade { if pf == compressionMade {
var err error var err error
req, err = dc.Do(bytes.NewReader(req)) req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil { if err != nil {
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
@ -427,12 +423,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
Last: true, Last: true,
Delay: false, Delay: false,
} }
var cp Compressor if s.opts.cp != nil {
if s.opts.cg != nil { stream.SetSendCompress(s.opts.cp.Type())
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
} }
if err := s.sendResponse(t, stream, reply, cp, opts); err != nil { if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
switch err := err.(type) { switch err := err.(type) {
case transport.ConnectionError: case transport.ConnectionError:
// Nothing to do here. // Nothing to do here.
@ -453,21 +447,19 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
var cp Compressor if s.opts.cp != nil {
if s.opts.cg != nil { stream.SetSendCompress(s.opts.cp.Type())
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
} }
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{s: stream}, p: &parser{s: stream},
codec: s.opts.codec, codec: s.opts.codec,
cp: cp, cp: s.opts.cp,
dg: s.opts.dg, dc: s.opts.dc,
trInfo: trInfo, trInfo: trInfo,
} }
if cp != nil { if ss.cp != nil {
ss.cbuf = new(bytes.Buffer) ss.cbuf = new(bytes.Buffer)
} }
if trInfo != nil { if trInfo != nil {

View File

@ -105,28 +105,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if err != nil { if err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
var cp Compressor
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. // TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
Flush: desc.ServerStreams&&desc.ClientStreams, Flush: desc.ServerStreams&&desc.ClientStreams,
} }
if cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
cs := &clientStream{ cs := &clientStream{
desc: desc, desc: desc,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cp, cp: cc.dopts.cp,
dg: cc.dopts.dg, dc: cc.dopts.dc,
tracing: EnableTracing, tracing: EnableTracing,
} }
if cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer) cs.cbuf = new(bytes.Buffer)
} }
if cs.tracing { if cs.tracing {
@ -170,7 +166,7 @@ type clientStream struct {
codec Codec codec Codec
cp Compressor cp Compressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
dg DecompressorGenerator dc Decompressor
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
@ -229,7 +225,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dg, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -248,7 +244,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dg, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -334,7 +330,7 @@ type serverStream struct {
p *parser p *parser
codec Codec codec Codec
cp Compressor cp Compressor
dg DecompressorGenerator dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
@ -402,5 +398,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dg, m) return recv(ss.p, ss.codec, ss.s, ss.dc, m)
} }

View File

@ -327,8 +327,8 @@ func listTestEnv() []env {
return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
} }
func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) { func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)} sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
la := ":0" la := ":0"
switch e.network { switch e.network {
case "unix": case "unix":
@ -367,16 +367,16 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u
return return
} }
func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) { func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) {
var derr error var derr error
if e.security == "tls" { if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc))
} else { } else {
cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc))
} }
if derr != nil { if derr != nil {
t.Fatalf("Dial(%q) = %v", addr, derr) t.Fatalf("Dial(%q) = %v", addr, derr)
@ -1151,7 +1151,7 @@ func TestCompressServerHasNoSupport(t *testing.T) {
func testCompressServerHasNoSupport(t *testing.T, e env) { func testCompressServerHasNoSupport(t *testing.T, e env) {
s, addr := serverSetUp(t, true, nil, math.MaxUint32, nil, nil, e) s, addr := serverSetUp(t, true, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e) cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), nil, "", e)
// Unary call // Unary call
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
@ -1203,8 +1203,8 @@ func TestCompressOK(t *testing.T) {
} }
func testCompressOK(t *testing.T, e env) { func testCompressOK(t *testing.T, e env) {
s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e) s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e) cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), "", e)
// Unary call // Unary call
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)