Support compression

This commit is contained in:
iamqizhao
2016-01-22 18:21:41 -08:00
parent 5da22b92e9
commit da3bb0c9f7
12 changed files with 481 additions and 120 deletions

21
call.go
View File

@ -34,6 +34,7 @@
package grpc package grpc
import ( import (
"bytes"
"io" "io"
"time" "time"
@ -47,7 +48,7 @@ import (
// On error, it returns the error and indicates whether the call should be retried. // On error, it returns the error and indicates whether the call should be retried.
// //
// TODO(zhaoq): Check whether the received message sequence is valid. // TODO(zhaoq): Check whether the received message sequence is valid.
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
var err error var err error
c.headerMD, err = stream.Header() c.headerMD, err = stream.Header()
@ -56,7 +57,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
} }
p := &parser{s: stream} p := &parser{s: stream}
for { for {
if err = recv(p, codec, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -68,7 +69,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
} }
// sendRequest writes out various information of an RPC such as Context and Message. // sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr) stream, err := t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -80,8 +81,7 @@ func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t
} }
} }
}() }()
// TODO(zhaoq): Support compression. outBuf, err := encode(codec, args, compressor, new(bytes.Buffer))
outBuf, err := encode(codec, args, compressionNone)
if err != nil { if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
} }
@ -129,7 +129,11 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
var ( var (
lastErr error // record the error that happened lastErr error // record the error that happened
cp Compressor
) )
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
for { for {
var ( var (
err error err error
@ -144,6 +148,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
} }
if cp != nil {
callHdr.SendCompress = cp.Type()
}
t, err = cc.dopts.picker.Pick(ctx) t, err = cc.dopts.picker.Pick(ctx)
if err != nil { if err != nil {
if lastErr != nil { if lastErr != nil {
@ -155,7 +162,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts)
if err != nil { if err != nil {
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
lastErr = err lastErr = err
@ -167,7 +174,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response // Receive the response
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok { if _, ok := lastErr.(transport.ConnectionError); ok {
continue continue
} }

View File

@ -98,7 +98,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
} }
} }
// send a response back to end the stream. // send a response back to end the stream.
reply, err := encode(testCodec{}, &expectedResponse, compressionNone) reply, err := encode(testCodec{}, &expectedResponse, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to encode the response: %v", err) t.Fatalf("Failed to encode the response: %v", err)
} }

View File

@ -73,6 +73,8 @@ var (
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
codec Codec codec Codec
cg CompressorGenerator
dg DecompressorGenerator
picker Picker picker Picker
block bool block bool
insecure bool insecure bool
@ -89,6 +91,18 @@ func WithCodec(c Codec) DialOption {
} }
} }
func WithCompressor(f CompressorGenerator) DialOption {
return func(o *dialOptions) {
o.cg = f
}
}
func WithDecompressor(f DecompressorGenerator) DialOption {
return func(o *dialOptions) {
o.dg = f
}
}
// WithPicker returns a DialOption which sets a picker for connection selection. // WithPicker returns a DialOption which sets a picker for connection selection.
func WithPicker(p Picker) DialOption { func WithPicker(p Picker) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {

View File

@ -34,9 +34,12 @@
package grpc package grpc
import ( import (
"bytes"
"compress/gzip"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"math/rand" "math/rand"
"os" "os"
@ -75,6 +78,59 @@ func (protoCodec) String() string {
return "proto" return "proto"
} }
type Compressor interface {
Do(w io.Writer, p []byte) error
Type() string
}
func NewGZIPCompressor() Compressor {
return &gzipCompressor{}
}
type gzipCompressor struct {
}
func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
z := gzip.NewWriter(w)
if _, err := z.Write(p); err != nil {
return err
}
return z.Close()
}
func (c *gzipCompressor) Type() string {
return "gzip"
}
type Decompressor interface {
Do(r io.Reader) ([]byte, error)
Type() string
}
type gzipDecompressor struct {
}
func NewGZIPDecompressor() Decompressor {
return &gzipDecompressor{}
}
func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
z, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
defer z.Close()
return ioutil.ReadAll(z)
}
func (d *gzipDecompressor) Type() string {
return "gzip"
}
type CompressorGenerator func() Compressor
type DecompressorGenerator func() Decompressor
// callInfo contains all related configuration and information about an RPC. // callInfo contains all related configuration and information about an RPC.
type callInfo struct { type callInfo struct {
failFast bool failFast bool
@ -126,8 +182,7 @@ type payloadFormat uint8
const ( const (
compressionNone payloadFormat = iota // no compression compressionNone payloadFormat = iota // no compression
compressionFlate compressionMade
// More formats
) )
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
@ -166,7 +221,7 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
// encode serializes msg and prepends the message header. If msg is nil, it // encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length. // generates the message header of 0 message length.
func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
var b []byte var b []byte
var length uint var length uint
if msg != nil { if msg != nil {
@ -176,6 +231,12 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
return nil, err
}
b = cbuf.Bytes()
}
length = uint(len(b)) length = uint(len(b))
} }
if length > math.MaxUint32 { if length > math.MaxUint32 {
@ -190,7 +251,11 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
var buf = make([]byte, payloadLen+sizeLen+len(b)) var buf = make([]byte, payloadLen+sizeLen+len(b))
// Write payload format // Write payload format
buf[0] = byte(pf) if cp == nil {
buf[0] = byte(compressionNone)
} else {
buf[0] = byte(compressionMade)
}
// Write length of b into buf // Write length of b into buf
binary.BigEndian.PutUint32(buf[1:], uint32(length)) binary.BigEndian.PutUint32(buf[1:], uint32(length))
// Copy encoded msg to buf // Copy encoded msg to buf
@ -199,22 +264,42 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
return buf, nil return buf, nil
} }
func recv(p *parser, c Codec, m interface{}) error { func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
return nil
}
func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg()
if err != nil { if err != nil {
return err return err
} }
switch pf { var dc Decompressor
case compressionNone: if pf == compressionMade && dg != nil {
if err := c.Unmarshal(d, m); err != nil { dc = dg()
if rErr, ok := err.(rpcError); ok { }
return rErr if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
} else { return err
return Errorf(codes.Internal, "grpc: %v", err) }
} if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
default: }
return Errorf(codes.Internal, "gprc: compression is not supported yet.") if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
} }
return nil return nil
} }

View File

@ -106,16 +106,40 @@ func TestEncode(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
// input // input
msg proto.Message msg proto.Message
pt payloadFormat cp Compressor
// outputs // outputs
b []byte b []byte
err error err error
}{ }{
{nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil}, {nil, nil, []byte{0, 0, 0, 0, 0}, nil},
} { } {
b, err := encode(protoCodec{}, test.msg, test.pt) b, err := encode(protoCodec{}, test.msg, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) { if err != test.err || !bytes.Equal(b, test.b) {
t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err) t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
}
}
}
func TestCompress(t *testing.T) {
for _, test := range []struct {
// input
data []byte
cp Compressor
dc Decompressor
// outputs
err error
}{
{make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil},
} {
b := new(bytes.Buffer)
if err := test.cp.Do(b, test.data); err != test.err {
t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err)
}
if b.Len() >= len(test.data) {
t.Fatalf("The compressor fails to compress data.")
}
if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) {
t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data)
} }
} }
} }
@ -158,12 +182,12 @@ func TestContextErr(t *testing.T) {
// bytes. // bytes.
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)} msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encoded, _ := encode(protoCodec{}, msg, compressionNone) encoded, _ := encode(protoCodec{}, msg, nil, nil)
encodedSz := int64(len(encoded)) encodedSz := int64(len(encoded))
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
encode(protoCodec{}, msg, compressionNone) encode(protoCodec{}, msg, nil, nil)
} }
b.SetBytes(encodedSz) b.SetBytes(encodedSz)
} }

159
server.go
View File

@ -34,6 +34,7 @@
package grpc package grpc
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -92,6 +93,8 @@ type Server struct {
type options struct { type options struct {
creds credentials.Credentials creds credentials.Credentials
codec Codec codec Codec
cg CompressorGenerator
dg DecompressorGenerator
maxConcurrentStreams uint32 maxConcurrentStreams uint32
} }
@ -105,6 +108,18 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
func CompressON(f CompressorGenerator) ServerOption {
return func(o *options) {
o.cg = f
}
}
func DecompressON(f DecompressorGenerator) ServerOption {
return func(o *options) {
o.dg = f
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
@ -287,8 +302,8 @@ func (s *Server) Serve(lis net.Listener) error {
} }
} }
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
p, err := encode(s.opts.codec, msg, pf) p, err := encode(s.opts.codec, msg, cp, new(bytes.Buffer))
if err != nil { if err != nil {
// This typically indicates a fatal issue (e.g., memory // This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program // corruption or hardware faults) the application program
@ -327,82 +342,119 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// Nothing to do here. // Nothing to do here.
case transport.StreamError: case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
} }
default: default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
} }
return err return err
} }
switch pf {
case compressionNone: var dc Decompressor
statusCode := codes.OK if pf == compressionMade && s.opts.dg != nil {
statusDesc := "" dc = s.opts.dg()
df := func(v interface{}) error { }
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil {
return err switch err := err.(type) {
case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
} }
if trInfo != nil { default:
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
} grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
return nil
}
reply, appErr := md.Handler(srv.server, stream.Context(), df)
if appErr != nil {
if err, ok := appErr.(rpcError); ok {
statusCode = err.code
statusDesc = err.desc
} else {
statusCode = convertCode(appErr)
statusDesc = appErr.Error()
}
if trInfo != nil && statusCode != codes.OK {
trInfo.tr.LazyLog(stringer(statusDesc), true)
trInfo.tr.SetError()
} }
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { }
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) return err
}
statusCode := codes.OK
statusDesc := ""
df := func(v interface{}) error {
if pf == compressionMade {
var err error
req, err = dc.Do(bytes.NewReader(req))
//req, err = ioutil.ReadAll(dc)
//defer dc.Close()
if err != nil {
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
return err return err
} }
return nil
} }
if trInfo != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
}
opts := &transport.Options{
Last: true,
Delay: false,
}
if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil {
switch err := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
statusCode = err.Code
statusDesc = err.Desc
default:
statusCode = codes.Unknown
statusDesc = err.Error()
}
return err return err
} }
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
} }
return t.WriteStatus(stream, statusCode, statusDesc) return nil
default:
panic(fmt.Sprintf("payload format to be supported: %d", pf))
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df)
if appErr != nil {
if err, ok := appErr.(rpcError); ok {
statusCode = err.code
statusDesc = err.desc
} else {
statusCode = convertCode(appErr)
statusDesc = appErr.Error()
}
if trInfo != nil && statusCode != codes.OK {
trInfo.tr.LazyLog(stringer(statusDesc), true)
trInfo.tr.SetError()
}
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
return err
}
return nil
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
}
opts := &transport.Options{
Last: true,
Delay: false,
}
var cp Compressor
if s.opts.cg != nil {
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
}
if err := s.sendResponse(t, stream, reply, cp, opts); err != nil {
switch err := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
statusCode = err.Code
statusDesc = err.Desc
default:
statusCode = codes.Unknown
statusDesc = err.Error()
}
return err
}
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
return t.WriteStatus(stream, statusCode, statusDesc)
} }
} }
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
var cp Compressor
if s.opts.cg != nil {
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
}
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{s: stream}, p: &parser{s: stream},
codec: s.opts.codec, codec: s.opts.codec,
cp: cp,
dg: s.opts.dg,
trInfo: trInfo, trInfo: trInfo,
} }
if trInfo != nil { if trInfo != nil {
@ -422,6 +474,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(rpcError); ok {
ss.statusCode = err.code ss.statusCode = err.code
ss.statusDesc = err.desc ss.statusDesc = err.desc
} else if err, ok := appErr.(transport.StreamError); ok {
ss.statusCode = err.Code
ss.statusDesc = err.Desc
} else { } else {
ss.statusCode = convertCode(appErr) ss.statusCode = convertCode(appErr)
ss.statusDesc = appErr.Error() ss.statusDesc = appErr.Error()

View File

@ -34,6 +34,7 @@
package grpc package grpc
import ( import (
"bytes"
"errors" "errors"
"io" "io"
"sync" "sync"
@ -104,14 +105,23 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if err != nil { if err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
var cp Compressor
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. // TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
} }
if cp != nil {
callHdr.SendCompress = cp.Type()
}
cs := &clientStream{ cs := &clientStream{
desc: desc, desc: desc,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cp,
dg: cc.dopts.dg,
tracing: EnableTracing, tracing: EnableTracing,
} }
if cs.tracing { if cs.tracing {
@ -153,6 +163,9 @@ type clientStream struct {
p *parser p *parser
desc *StreamDesc desc *StreamDesc
codec Codec codec Codec
cp Compressor
cbuf bytes.Buffer
dg DecompressorGenerator
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
@ -198,7 +211,8 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
err = toRPCErr(err) err = toRPCErr(err)
}() }()
out, err := encode(cs.codec, m, compressionNone) out, err := encode(cs.codec, m, cs.cp, &cs.cbuf)
defer cs.cbuf.Reset()
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: %v", err) return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
} }
@ -206,7 +220,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, m) err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -225,7 +239,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, m) err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -310,6 +324,9 @@ type serverStream struct {
s *transport.Stream s *transport.Stream
p *parser p *parser
codec Codec codec Codec
cp Compressor
dg DecompressorGenerator
cbuf bytes.Buffer
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
trInfo *traceInfo trInfo *traceInfo
@ -348,7 +365,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
out, err := encode(ss.codec, m, compressionNone) out, err := encode(ss.codec, m, ss.cp, &ss.cbuf)
defer ss.cbuf.Reset()
if err != nil { if err != nil {
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
return err return err
@ -371,5 +389,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, m) return recv(ss.p, ss.codec, ss.s, ss.dg, m)
} }

View File

@ -143,7 +143,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &testpb.SimpleResponse{ return &testpb.SimpleResponse{
Payload: payload, Payload: payload,
}, nil }, nil
@ -328,8 +327,8 @@ func listTestEnv() []env {
return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
} }
func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc.Server, cc *grpc.ClientConn) { func serverSetUp(t *testing.T, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) {
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)} sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)}
la := ":0" la := ":0"
switch e.network { switch e.network {
case "unix": case "unix":
@ -353,7 +352,7 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e
} }
testpb.RegisterTestServiceServer(s, &testServer{security: e.security}) testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
go s.Serve(lis) go s.Serve(lis)
addr := la addr = la
switch e.network { switch e.network {
case "unix": case "unix":
default: default:
@ -363,17 +362,22 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e
} }
addr = "localhost:" + port addr = "localhost:" + port
} }
return
}
func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) {
var derr error
if e.security == "tls" { if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
} else { } else {
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua)) cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
} }
if err != nil { if derr != nil {
t.Fatalf("Dial(%q) = %v", addr, err) t.Fatalf("Dial(%q) = %v", addr, derr)
} }
return return
} }
@ -390,7 +394,8 @@ func TestTimeoutOnDeadServer(t *testing.T) {
} }
func testTimeoutOnDeadServer(t *testing.T, e env) { func testTimeoutOnDeadServer(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
ctx, _ := context.WithTimeout(context.Background(), time.Second) ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil { if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@ -443,7 +448,8 @@ func TestHealthCheckOnSuccess(t *testing.T) {
func testHealthCheckOnSuccess(t *testing.T, e env) { func testHealthCheckOnSuccess(t *testing.T, e env) {
hs := health.NewHealthServer() hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.Health", 1) hs.SetServingStatus("grpc.health.v1alpha.Health", 1)
s, cc := setUp(t, hs, math.MaxUint32, "", e) s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc) defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.Health"); err != nil { if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.Health"); err != nil {
t.Fatalf("Health/Check(_, _) = _, %v, want _, <nil>", err) t.Fatalf("Health/Check(_, _) = _, %v, want _, <nil>", err)
@ -459,7 +465,8 @@ func TestHealthCheckOnFailure(t *testing.T) {
func testHealthCheckOnFailure(t *testing.T, e env) { func testHealthCheckOnFailure(t *testing.T, e env) {
hs := health.NewHealthServer() hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1) hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1)
s, cc := setUp(t, hs, math.MaxUint32, "", e) s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc) defer tearDown(s, cc)
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded)
@ -473,7 +480,8 @@ func TestHealthCheckOff(t *testing.T) {
} }
func testHealthCheckOff(t *testing.T, e env) { func testHealthCheckOff(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc) defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") { if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented) t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented)
@ -488,7 +496,8 @@ func TestHealthCheckServingStatus(t *testing.T) {
func testHealthCheckServingStatus(t *testing.T, e env) { func testHealthCheckServingStatus(t *testing.T, e env) {
hs := health.NewHealthServer() hs := health.NewHealthServer()
s, cc := setUp(t, hs, math.MaxUint32, "", e) s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc) defer tearDown(s, cc)
out, err := healthCheck(1*time.Second, cc, "") out, err := healthCheck(1*time.Second, cc, "")
if err != nil { if err != nil {
@ -526,7 +535,8 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) {
} }
func testEmptyUnaryWithUserAgent(t *testing.T, e env) { func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, testAppUA, e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, testAppUA, e)
// Wait until cc is connected. // Wait until cc is connected.
ctx, _ := context.WithTimeout(context.Background(), time.Second) ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil { if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@ -569,7 +579,8 @@ func TestFailedEmptyUnary(t *testing.T) {
} }
func testFailedEmptyUnary(t *testing.T, e env) { func testFailedEmptyUnary(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata) ctx := metadata.NewContext(context.Background(), testMetadata)
@ -585,7 +596,8 @@ func TestLargeUnary(t *testing.T) {
} }
func testLargeUnary(t *testing.T, e env) { func testLargeUnary(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
argSize := 271828 argSize := 271828
@ -619,7 +631,8 @@ func TestMetadataUnaryRPC(t *testing.T) {
} }
func testMetadataUnaryRPC(t *testing.T, e env) { func testMetadataUnaryRPC(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
argSize := 2718 argSize := 2718
@ -684,7 +697,8 @@ func TestRetry(t *testing.T) {
// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy // TODO(zhaoq): Refactor to make this clearer and add more cases to test racy
// and error-prone paths. // and error-prone paths.
func testRetry(t *testing.T, e env) { func testRetry(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
var wg sync.WaitGroup var wg sync.WaitGroup
@ -714,7 +728,8 @@ func TestRPCTimeout(t *testing.T) {
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. // TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func testRPCTimeout(t *testing.T, e env) { func testRPCTimeout(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
argSize := 2718 argSize := 2718
@ -746,7 +761,8 @@ func TestCancel(t *testing.T) {
} }
func testCancel(t *testing.T, e env) { func testCancel(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
argSize := 2718 argSize := 2718
@ -778,7 +794,8 @@ func TestCancelNoIO(t *testing.T) {
func testCancelNoIO(t *testing.T, e env) { func testCancelNoIO(t *testing.T, e env) {
// Only allows 1 live stream per server transport. // Only allows 1 live stream per server transport.
s, cc := setUp(t, nil, 1, "", e) s, addr := serverSetUp(t, nil, 1, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -829,7 +846,8 @@ func TestPingPong(t *testing.T) {
} }
func testPingPong(t *testing.T, e env) { func testPingPong(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
stream, err := tc.FullDuplexCall(context.Background()) stream, err := tc.FullDuplexCall(context.Background())
@ -886,7 +904,8 @@ func TestMetadataStreamingRPC(t *testing.T) {
} }
func testMetadataStreamingRPC(t *testing.T, e env) { func testMetadataStreamingRPC(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata) ctx := metadata.NewContext(context.Background(), testMetadata)
@ -952,7 +971,8 @@ func TestServerStreaming(t *testing.T) {
} }
func testServerStreaming(t *testing.T, e env) { func testServerStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes)) respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -1004,7 +1024,8 @@ func TestFailedServerStreaming(t *testing.T) {
} }
func testFailedServerStreaming(t *testing.T, e env) { func testFailedServerStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes)) respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -1034,7 +1055,8 @@ func TestClientStreaming(t *testing.T) {
} }
func testClientStreaming(t *testing.T, e env) { func testClientStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e) s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
stream, err := tc.StreamingInputCall(context.Background()) stream, err := tc.StreamingInputCall(context.Background())
@ -1074,7 +1096,8 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
func testExceedMaxStreamsLimit(t *testing.T, e env) { func testExceedMaxStreamsLimit(t *testing.T, e env) {
// Only allows 1 live stream per server transport. // Only allows 1 live stream per server transport.
s, cc := setUp(t, nil, 1, "", e) s, addr := serverSetUp(t, nil, 1, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc) defer tearDown(s, cc)
_, err := tc.StreamingInputCall(context.Background()) _, err := tc.StreamingInputCall(context.Background())
@ -1095,3 +1118,109 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded) t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
} }
} }
func TestCompressServerHasNoSupport(t *testing.T) {
for _, e := range listTestEnv() {
testCompressServerHasNoSupport(t, e)
}
}
func testCompressServerHasNoSupport(t *testing.T, e env) {
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
respSize := 314159
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(int32(respSize)),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %d", err, codes.InvalidArgument)
}
// Streaming RPC
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(31415),
},
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
t.Fatalf("%v.Recv() = %v, want error code %d", stream, err, codes.InvalidArgument)
}
}
func TestCompressOK(t *testing.T) {
for _, e := range listTestEnv() {
testCompressOK(t, e)
}
}
func testCompressOK(t *testing.T, e env) {
s, addr := serverSetUp(t, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
respSize := 314159
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(int32(respSize)),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err != nil {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
}
// Streaming RPC
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(31415),
},
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
}
}

View File

@ -208,12 +208,13 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
} }
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
method: callHdr.Method, method: callHdr.Method,
buf: newRecvBuffer(), sendCompress: callHdr.SendCompress,
fc: fc, buf: newRecvBuffer(),
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), fc: fc,
headerChan: make(chan struct{}), sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
headerChan: make(chan struct{}),
} }
t.nextID += 2 t.nextID += 2
s.windowHandler = func(n int) { s.windowHandler = func(n int) {
@ -322,6 +323,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.SendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 { if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
} }
@ -694,8 +698,10 @@ func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
if !endHeaders { if !endHeaders {
return s return s
} }
s.mu.Lock() s.mu.Lock()
if !endStream {
s.recvCompress = hDec.state.encoding
}
if !s.headerDone { if !s.headerDone {
if !endStream && len(hDec.state.mdata) > 0 { if !endStream && len(hDec.state.mdata) > 0 {
s.header = hDec.state.mdata s.header = hDec.state.mdata

View File

@ -164,6 +164,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
if !endHeaders { if !endHeaders {
return s return s
} }
s.recvCompress = hDec.state.encoding
if hDec.state.timeoutSet { if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout) s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else { } else {
@ -190,6 +191,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, recv: s.buf,
} }
s.recvCompress = hDec.state.encoding
s.method = hDec.state.method s.method = hDec.state.method
t.mu.Lock() t.mu.Lock()
if t.state != reachable { if t.state != reachable {
@ -446,6 +448,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
t.hBuf.Reset() t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
for k, v := range md { for k, v := range md {
for _, entry := range v { for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
@ -520,6 +525,9 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.hBuf.Reset() t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
p := http2.HeadersFrameParam{ p := http2.HeadersFrameParam{
StreamID: s.id, StreamID: s.id,
BlockFragment: t.hBuf.Bytes(), BlockFragment: t.hBuf.Bytes(),

View File

@ -89,6 +89,7 @@ var (
// Records the states during HPACK decoding. Must be reset once the // Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished. // decoding of the entire headers are finished.
type decodeState struct { type decodeState struct {
encoding string
// statusCode caches the stream status received from the trailer // statusCode caches the stream status received from the trailer
// the server sent. Client side only. // the server sent. Client side only.
statusCode codes.Code statusCode codes.Code
@ -145,6 +146,8 @@ func newHPACKDecoder() *hpackDecoder {
d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header") d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header")
return return
} }
case "grpc-encoding":
d.state.encoding = f.Value
case "grpc-status": case "grpc-status":
code, err := strconv.Atoi(f.Value) code, err := strconv.Atoi(f.Value)
if err != nil { if err != nil {

View File

@ -171,6 +171,8 @@ type Stream struct {
cancel context.CancelFunc cancel context.CancelFunc
// method records the associated RPC method of the stream. // method records the associated RPC method of the stream.
method string method string
recvCompress string
sendCompress string
buf *recvBuffer buf *recvBuffer
dec io.Reader dec io.Reader
fc *inFlow fc *inFlow
@ -201,6 +203,14 @@ type Stream struct {
statusDesc string statusDesc string
} }
func (s *Stream) RecvCompress() string {
return s.recvCompress
}
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Header acquires the key-value pairs of header metadata once it // Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no // is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired. // header metadata or iii) the stream is cancelled/expired.
@ -350,6 +360,8 @@ type Options struct {
type CallHdr struct { type CallHdr struct {
Host string // peer host Host string // peer host
Method string // the operation to perform on the specified host Method string // the operation to perform on the specified host
RecvCompress string
SendCompress string
} }
// ClientTransport is the common interface for all gRPC client side transport // ClientTransport is the common interface for all gRPC client side transport