From 4013f8d559720d4d0087cca0a3f18e0fdf25093d Mon Sep 17 00:00:00 2001
From: apolcyn
Date: Wed, 10 May 2017 10:55:38 -0700
Subject: [PATCH 01/11] tentative fix to a flow control over-give-back bug
 (#1170)

---
 transport/http2_client.go | 8 +++-----
 transport/http2_server.go | 8 +++-----
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/transport/http2_client.go b/transport/http2_client.go
index bc202df2..736a4b35 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -802,11 +802,6 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
 		return
 	}
 	if size > 0 {
-		if f.Header().Flags.Has(http2.FlagDataPadded) {
-			if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
-				t.controlBuf.put(&windowUpdate{0, w})
-			}
-		}
 		s.mu.Lock()
 		if s.state == streamDone {
 			s.mu.Unlock()
@@ -825,6 +820,9 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
 			return
 		}
 		if f.Header().Flags.Has(http2.FlagDataPadded) {
+			if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
+				t.controlBuf.put(&windowUpdate{0, w})
+			}
 			if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
 				t.controlBuf.put(&windowUpdate{s.id, w})
 			}
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 4e950363..4f719451 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -465,11 +465,6 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 		return
 	}
 	if size > 0 {
-		if f.Header().Flags.Has(http2.FlagDataPadded) {
-			if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
-				t.controlBuf.put(&windowUpdate{0, w})
-			}
-		}
 		s.mu.Lock()
 		if s.state == streamDone {
 			s.mu.Unlock()
@@ -486,6 +481,9 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 			return
 		}
 		if f.Header().Flags.Has(http2.FlagDataPadded) {
+			if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
+				t.controlBuf.put(&windowUpdate{0, w})
+			}
 			if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
 				t.controlBuf.put(&windowUpdate{s.id, w})
 			}

From 600406e696bb27288ec8a4b501184f948f8bb222 Mon Sep 17 00:00:00 2001
From: Steeve Morin
Date: Thu, 11 May 2017 01:40:25 +0200
Subject: [PATCH 02/11] Use pooled gzip.{Writer,Reader} in
 gzip{Compressor,Decompressor} (#1217)

This change saves a lot of memory by reusing the underlying
gzip.{Writer,Reader}, which allocates up to 1.4MB at every instantiation
according to [1]. This was fixed by adding a Reset method to the object
at [2].
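As an illustration of the pattern this commit relies on, here is a minimal,
self-contained sketch (not the patched grpc-go code itself; the names are
illustrative): a sync.Pool hands back previously allocated *gzip.Writer
values, and Reset retargets them at a new destination, so the large flate
buffers are allocated only once.

    package main

    import (
    	"bytes"
    	"compress/gzip"
    	"fmt"
    	"io/ioutil"
    	"sync"
    )

    // writerPool reuses gzip writers across calls; New runs only when
    // the pool is empty.
    var writerPool = sync.Pool{
    	New: func() interface{} { return gzip.NewWriter(ioutil.Discard) },
    }

    func compress(dst *bytes.Buffer, p []byte) error {
    	z := writerPool.Get().(*gzip.Writer)
    	defer writerPool.Put(z) // return the writer to the pool for reuse
    	z.Reset(dst)            // retarget the pooled writer; no reallocation
    	if _, err := z.Write(p); err != nil {
    		return err
    	}
    	return z.Close() // flush; Reset makes the writer reusable later
    }

    func main() {
    	var buf bytes.Buffer
    	if err := compress(&buf, []byte("hello, pooled gzip")); err != nil {
    		panic(err)
    	}
    	fmt.Printf("%d compressed bytes\n", buf.Len())
    }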
The amount of memory (and GC time) saved is pretty high, as reported by
pprof:

       flat  flat%   sum%        cum   cum%
    28.33GB 85.70% 85.70%    32.74GB 99.05%  compress/flate.NewWriter

       flat  flat%   sum%        cum   cum%
    19.39MB 16.74% 16.74%    22.07MB 19.05%  compress/flate.NewWriter

And the benchmarks:

    benchmark                          old ns/op     new ns/op     delta
    BenchmarkGZIPCompressor1B-4        215170        22291         -89.64%
    BenchmarkGZIPCompressor1KiB-4      225971        27213         -87.96%
    BenchmarkGZIPCompressor8KiB-4      246696        54785         -77.79%
    BenchmarkGZIPCompressor64KiB-4     444851        286924        -35.50%
    BenchmarkGZIPCompressor512KiB-4    2279043       2115863       -7.16%
    BenchmarkGZIPCompressor1MiB-4      4412989       4258635       -3.50%

    benchmark                          old allocs    new allocs    delta
    BenchmarkGZIPCompressor1B-4        17            0             -100.00%
    BenchmarkGZIPCompressor1KiB-4      17            0             -100.00%
    BenchmarkGZIPCompressor8KiB-4      17            0             -100.00%
    BenchmarkGZIPCompressor64KiB-4     17            0             -100.00%
    BenchmarkGZIPCompressor512KiB-4    17            0             -100.00%
    BenchmarkGZIPCompressor1MiB-4      17            0             -100.00%

    benchmark                          old bytes     new bytes     delta
    BenchmarkGZIPCompressor1B-4        813872        8             -100.00%
    BenchmarkGZIPCompressor1KiB-4      813872        16            -100.00%
    BenchmarkGZIPCompressor8KiB-4      813875        27            -100.00%
    BenchmarkGZIPCompressor64KiB-4     813918        190           -99.98%
    BenchmarkGZIPCompressor512KiB-4    814928        1871          -99.77%
    BenchmarkGZIPCompressor1MiB-4      820889        9735          -98.81%

[1] https://github.com/golang/go/issues/6138
[2] https://github.com/golang/go/commit/db12f9d4e406dcab81b476e955c8e119112522fa

Signed-off-by: Steeve Morin
---
 rpc_util.go      | 43 ++++++++++++++++++++++++++++++++++---------
 rpc_util_test.go | 39 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/rpc_util.go b/rpc_util.go
index bd8379c2..31a87325 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -41,6 +41,7 @@ import (
 	"io/ioutil"
 	"math"
 	"os"
+	"sync"
 	"time"
 
 	"golang.org/x/net/context"
@@ -60,16 +61,24 @@ type Compressor interface {
 	Type() string
 }
 
-// NewGZIPCompressor creates a Compressor based on GZIP.
-func NewGZIPCompressor() Compressor {
-	return &gzipCompressor{}
+type gzipCompressor struct {
+	pool sync.Pool
 }
 
-type gzipCompressor struct {
+// NewGZIPCompressor creates a Compressor based on GZIP.
+func NewGZIPCompressor() Compressor {
+	return &gzipCompressor{
+		pool: sync.Pool{
+			New: func() interface{} {
+				return gzip.NewWriter(ioutil.Discard)
+			},
+		},
+	}
 }
 
 func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
-	z := gzip.NewWriter(w)
+	z := c.pool.Get().(*gzip.Writer)
+	z.Reset(w)
 	if _, err := z.Write(p); err != nil {
 		return err
 	}
@@ -89,6 +98,7 @@ type Decompressor interface {
 }
 
 type gzipDecompressor struct {
+	pool sync.Pool
 }
 
 // NewGZIPDecompressor creates a Decompressor based on GZIP.
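For context, a sketch of how callers of this era wire these constructors into
a connection, so the pooling is picked up with no call-site changes. The
endpoint address is hypothetical; WithCompressor and WithDecompressor are the
dial options of this period (later deprecated in favor of the encoding
package).

    package main

    import (
    	"log"

    	"google.golang.org/grpc"
    )

    func main() {
    	// Both options accept the Compressor/Decompressor interfaces, so
    	// the pooled implementations drop in transparently.
    	conn, err := grpc.Dial("localhost:50051", // hypothetical address
    		grpc.WithInsecure(),
    		grpc.WithCompressor(grpc.NewGZIPCompressor()),
    		grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
    	)
    	if err != nil {
    		log.Fatalf("dial failed: %v", err)
    	}
    	defer conn.Close()
    }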
@@ -97,11 +107,26 @@ func NewGZIPDecompressor() Decompressor {
 }
 
 func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
-	z, err := gzip.NewReader(r)
-	if err != nil {
-		return nil, err
+	var z *gzip.Reader
+	switch maybeZ := d.pool.Get().(type) {
+	case nil:
+		newZ, err := gzip.NewReader(r)
+		if err != nil {
+			return nil, err
+		}
+		z = newZ
+	case *gzip.Reader:
+		z = maybeZ
+		if err := z.Reset(r); err != nil {
+			d.pool.Put(z)
+			return nil, err
+		}
 	}
-	defer z.Close()
+
+	defer func() {
+		z.Close()
+		d.pool.Put(z)
+	}()
 	return ioutil.ReadAll(z)
 }
 
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 8c92b963..b2b85c73 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -130,7 +130,7 @@ func TestCompress(t *testing.T) {
 		// outputs
 		err error
 	}{
-		{make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil},
+		{make([]byte, 1024), NewGZIPCompressor(), NewGZIPDecompressor(), nil},
 	} {
 		b := new(bytes.Buffer)
 		if err := test.cp.Do(b, test.data); err != test.err {
@@ -202,3 +202,40 @@ func BenchmarkEncode512KiB(b *testing.B) {
 func BenchmarkEncode1MiB(b *testing.B) {
 	bmEncode(b, 1024*1024)
 }
+
+// bmCompressor benchmarks a compressor of a Protocol Buffer message containing
+// mSize bytes.
+func bmCompressor(b *testing.B, mSize int, cp Compressor) {
+	payload := make([]byte, mSize)
+	cBuf := bytes.NewBuffer(make([]byte, mSize))
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		cp.Do(cBuf, payload)
+		cBuf.Reset()
+	}
+}
+
+func BenchmarkGZIPCompressor1B(b *testing.B) {
+	bmCompressor(b, 1, NewGZIPCompressor())
+}
+
+func BenchmarkGZIPCompressor1KiB(b *testing.B) {
+	bmCompressor(b, 1024, NewGZIPCompressor())
+}
+
+func BenchmarkGZIPCompressor8KiB(b *testing.B) {
+	bmCompressor(b, 8*1024, NewGZIPCompressor())
+}
+
+func BenchmarkGZIPCompressor64KiB(b *testing.B) {
+	bmCompressor(b, 64*1024, NewGZIPCompressor())
+}
+
+func BenchmarkGZIPCompressor512KiB(b *testing.B) {
+	bmCompressor(b, 512*1024, NewGZIPCompressor())
+}
+
+func BenchmarkGZIPCompressor1MiB(b *testing.B) {
+	bmCompressor(b, 1024*1024, NewGZIPCompressor())
+}

From 17760cfd5b60decdc7ca12e5ae74fcdffe9ef0d2 Mon Sep 17 00:00:00 2001
From: Menghan Li
Date: Wed, 10 May 2017 17:54:49 -0700
Subject: [PATCH 03/11] Calling handleRPC with context derived from the
 original (#1227)

* Calling handleRPC with different context derived from the original
context

* change comment for tagRPC and stats fields
---
 call.go                   | 2 +-
 stats/handlers.go         | 7 +++++--
 stats/stats.go            | 4 ++--
 stats/stats_test.go       | 9 +++++----
 stream.go                 | 2 +-
 transport/http2_client.go | 8 +++-----
 transport/transport.go    | 5 -----
 7 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/call.go b/call.go
index b3937312..0eb5f5cf 100644
--- a/call.go
+++ b/call.go
@@ -182,7 +182,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 	ctx = newContextWithRPCInfo(ctx)
 	sh := cc.dopts.copts.StatsHandler
 	if sh != nil {
-		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
+		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
 		begin := &stats.Begin{
 			Client:    true,
 			BeginTime: time.Now(),
diff --git a/stats/handlers.go b/stats/handlers.go
index 26e1a8e2..5fdce2f5 100644
--- a/stats/handlers.go
+++ b/stats/handlers.go
@@ -45,19 +45,22 @@ type ConnTagInfo struct {
 	RemoteAddr net.Addr
 	// LocalAddr is the local address of the corresponding connection.
 	LocalAddr net.Addr
-	// TODO add QOS related fields.
 }
 
 // RPCTagInfo defines the relevant information needed by RPC context tagger.
 type RPCTagInfo struct {
 	// FullMethodName is the RPC method in the format of /package.service/method.
 	FullMethodName string
+	// FailFast indicates if this RPC is failfast.
+	// This field is only valid on the client side; it is always false on the
+	// server side.
+	FailFast bool
 }
 
 // Handler defines the interface for the related stats handling (e.g., RPCs, connections).
 type Handler interface {
 	// TagRPC can attach some information to the given context.
-	// The returned context is used in the rest lifetime of the RPC.
+	// The context used for the remaining lifetime of the RPC will be derived
+	// from the returned context.
 	TagRPC(context.Context, *RPCTagInfo) context.Context
 	// HandleRPC processes the RPC stats.
 	HandleRPC(context.Context, RPCStats)
diff --git a/stats/stats.go b/stats/stats.go
index 75bdd816..6c406c7b 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -86,13 +86,13 @@ func (s *InPayload) IsClient() bool { return s.Client }
 func (s *InPayload) isRPCStats() {}
 
 // InHeader contains stats when a header is received.
-// FullMethod, addresses and Compression are only valid if Client is false.
 type InHeader struct {
 	// Client is true if this InHeader is from client side.
 	Client bool
 	// WireLength is the wire length of header.
 	WireLength int
 
+	// The following fields are valid only if Client is false.
 	// FullMethod is the full RPC method string, i.e., /package.service/method.
 	FullMethod string
 	// RemoteAddr is the remote address of the corresponding connection.
@@ -143,13 +143,13 @@ func (s *OutPayload) IsClient() bool { return s.Client }
 func (s *OutPayload) isRPCStats() {}
 
 // OutHeader contains stats when a header is sent.
-// FullMethod, addresses and Compression are only valid if Client is true.
 type OutHeader struct {
 	// Client is true if this OutHeader is from client side.
 	Client bool
 	// WireLength is the wire length of header.
 	WireLength int
 
+	// The following fields are valid only if Client is true.
 	// FullMethod is the full RPC method string, i.e., /package.service/method.
 	FullMethod string
 	// RemoteAddr is the remote address of the corresponding connection.
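To make the new contract concrete, a minimal sketch of a client-side stats
handler that reads the new FailFast field and relies on the derived-context
guarantee documented above. The key type and the logging are illustrative,
not part of this patch.

    package main

    import (
    	"log"

    	"golang.org/x/net/context"
    	"google.golang.org/grpc/stats"
    )

    type rpcTagKey struct{}

    type loggingHandler struct{}

    // TagRPC attaches the tag info; every later HandleRPC call for this RPC
    // receives a context derived from the one returned here.
    func (loggingHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
    	return context.WithValue(ctx, rpcTagKey{}, info)
    }

    func (loggingHandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
    	if info, ok := ctx.Value(rpcTagKey{}).(*stats.RPCTagInfo); ok {
    		log.Printf("%s (failfast=%v): %T", info.FullMethodName, info.FailFast, s)
    	}
    }

    func (loggingHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
    	return ctx
    }

    func (loggingHandler) HandleConn(context.Context, stats.ConnStats) {}

    // Compile-time check that loggingHandler satisfies stats.Handler; it
    // would be installed with grpc.WithStatsHandler(loggingHandler{}).
    var _ stats.Handler = loggingHandler{}

    func main() {}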
diff --git a/stats/stats_test.go b/stats/stats_test.go
index c770c151..35d60a45 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -800,13 +800,14 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF
 		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
 	}
 
-	var rpcctx context.Context
+	var tagInfoInCtx *stats.RPCTagInfo
 	for i := 0; i < len(got); i++ {
 		if _, ok := got[i].s.(stats.RPCStats); ok {
-			if rpcctx != nil && got[i].ctx != rpcctx {
-				t.Fatalf("got different contexts with stats %T", got[i].s)
+			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
+			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
+				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
 			}
-			rpcctx = got[i].ctx
+			tagInfoInCtx = tagInfoInCtxNew
 		}
 	}
 
diff --git a/stream.go b/stream.go
index 42049679..25764330 100644
--- a/stream.go
+++ b/stream.go
@@ -154,7 +154,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	ctx = newContextWithRPCInfo(ctx)
 	sh := cc.dopts.copts.StatsHandler
 	if sh != nil {
-		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
+		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
 		begin := &stats.Begin{
 			Client:    true,
 			BeginTime: time.Now(),
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 736a4b35..80583ab7 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -334,7 +334,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	if t.authInfo != nil {
 		pr.AuthInfo = t.authInfo
 	}
-	userCtx := ctx
 	ctx = peer.NewContext(ctx, pr)
 	authData := make(map[string]string)
 	for _, c := range t.creds {
@@ -401,7 +400,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		return nil, ErrConnClosing
 	}
 	s := t.newStream(ctx, callHdr)
-	s.clientStatsCtx = userCtx
 	t.activeStreams[s.id] = s
 	// If the number of active streams change from 0 to 1, then check if keepalive
 	// has gone dormant. If so, wake it up.
@@ -514,7 +512,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			LocalAddr:   t.localAddr,
 			Compression: callHdr.SendCompress,
 		}
-		t.statsHandler.HandleRPC(s.clientStatsCtx, outHeader)
+		t.statsHandler.HandleRPC(s.ctx, outHeader)
 	}
 	t.writableChan <- 0
 	return s, nil
@@ -993,13 +991,13 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 				Client:     true,
 				WireLength: int(frame.Header().Length),
 			}
-			t.statsHandler.HandleRPC(s.clientStatsCtx, inHeader)
+			t.statsHandler.HandleRPC(s.ctx, inHeader)
 		} else {
 			inTrailer := &stats.InTrailer{
 				Client:     true,
 				WireLength: int(frame.Header().Length),
 			}
-			t.statsHandler.HandleRPC(s.clientStatsCtx, inTrailer)
+			t.statsHandler.HandleRPC(s.ctx, inTrailer)
 		}
 	}
 }()
diff --git a/transport/transport.go b/transport/transport.go
index c22333cf..4bd4dc44 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -171,11 +171,6 @@ type Stream struct {
 	id uint32
 	// nil for client side Stream.
 	st ServerTransport
-	// clientStatsCtx keeps the user context for stats handling.
-	// It's only valid on client side. Server side stats context is same as s.ctx.
-	// All client side stats collection should use the clientStatsCtx (instead of the stream context)
-	// so that all the generated stats for a particular RPC can be associated in the processing phase.
-	clientStatsCtx context.Context
 	// ctx is the associated context of the stream.
 	ctx context.Context
 	// cancel is always nil for client side Stream.

From 3dd14ccc71c1ebd2df4fa4588da252fdbe7bf9bd Mon Sep 17 00:00:00 2001
From: MakMukhi
Date: Thu, 11 May 2017 09:40:46 -0700
Subject: [PATCH 04/11] Http status to grpc status conversion (#1195)

---
 transport/http2_client.go   |  20 ++--
 transport/http_util.go      |  70 ++++++++++++-
 transport/transport_test.go | 191 ++++++++++++++++++++++++++++++++++++
 3 files changed, 267 insertions(+), 14 deletions(-)

diff --git a/transport/http2_client.go b/transport/http2_client.go
index 80583ab7..3e5ff731 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -968,18 +968,16 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 	}
 	s.bytesReceived = true
 	var state decodeState
-	for _, hf := range frame.Fields {
-		if err := state.processHeaderField(hf); err != nil {
-			s.mu.Lock()
-			if !s.headerDone {
-				close(s.headerChan)
-				s.headerDone = true
-			}
-			s.mu.Unlock()
-			s.write(recvMsg{err: err})
-			// Something wrong. Stops reading even when there is remaining.
-			return
+	if err := state.decodeResponseHeader(frame); err != nil {
+		s.mu.Lock()
+		if !s.headerDone {
+			close(s.headerChan)
+			s.headerDone = true
 		}
+		s.mu.Unlock()
+		s.write(recvMsg{err: err})
+		// Something wrong. Stops reading even when there is remaining.
+		return
 	}
 
 	endStream := frame.StreamEnded()
diff --git a/transport/http_util.go b/transport/http_util.go
index 795d5d18..9b31717c 100644
--- a/transport/http_util.go
+++ b/transport/http_util.go
@@ -40,6 +40,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"net/http"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -88,6 +89,24 @@ var (
 		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
 		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
 	}
+	httpStatusConvTab = map[int]codes.Code{
+		// 400 Bad Request - INTERNAL.
+		http.StatusBadRequest: codes.Internal,
+		// 401 Unauthorized - UNAUTHENTICATED.
+		http.StatusUnauthorized: codes.Unauthenticated,
+		// 403 Forbidden - PERMISSION_DENIED.
+		http.StatusForbidden: codes.PermissionDenied,
+		// 404 Not Found - UNIMPLEMENTED.
+		http.StatusNotFound: codes.Unimplemented,
+		// 429 Too Many Requests - UNAVAILABLE.
+		http.StatusTooManyRequests: codes.Unavailable,
+		// 502 Bad Gateway - UNAVAILABLE.
+		http.StatusBadGateway: codes.Unavailable,
+		// 503 Service Unavailable - UNAVAILABLE.
+		http.StatusServiceUnavailable: codes.Unavailable,
+		// 504 Gateway timeout - UNAVAILABLE.
+		http.StatusGatewayTimeout: codes.Unavailable,
+	}
 )
 
 // Records the states during HPACK decoding. Must be reset once the
@@ -100,8 +119,9 @@ type decodeState struct {
 	statusGen *status.Status
 	// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
 	// intended for direct access outside of parsing.
-	rawStatusCode int32
+	rawStatusCode *int
 	rawStatusMsg  string
+	httpStatus    *int
 	// Server side only fields.
 	timeoutSet bool
 	timeout    time.Duration
@@ -159,7 +179,7 @@ func validContentType(t string) bool {
 func (d *decodeState) status() *status.Status {
 	if d.statusGen == nil {
 		// No status-details were provided; generate status using code/msg.
-		d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)
+		d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
 	}
 	return d.statusGen
 }
@@ -193,6 +213,44 @@ func decodeMetadataHeader(k, v string) (string, error) {
 	return v, nil
 }
 
+func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error {
+	for _, hf := range frame.Fields {
+		if err := d.processHeaderField(hf); err != nil {
+			return err
+		}
+	}
+
+	// If grpc status exists, no need to check further.
+	if d.rawStatusCode != nil || d.statusGen != nil {
+		return nil
+	}
+
+	// If grpc status doesn't exist and http status doesn't exist,
+	// then it's a malformed header.
+	if d.httpStatus == nil {
+		return streamErrorf(codes.Internal, "malformed header: doesn't contain status (gRPC or HTTP)")
+	}
+
+	if *(d.httpStatus) != http.StatusOK {
+		code, ok := httpStatusConvTab[*(d.httpStatus)]
+		if !ok {
+			code = codes.Unknown
+		}
+		return streamErrorf(code, http.StatusText(*(d.httpStatus)))
+	}
+
+	// The gRPC status doesn't exist and the HTTP status is OK.
+	// Set rawStatusCode to Unknown and return a nil error, so that if the
+	// stream has ended, this Unknown status is propagated to the user.
+	// Otherwise it is ignored, and the status from a later trailer (the one
+	// with the StreamEnded flag set) is propagated instead.
+	code := int(codes.Unknown)
+	d.rawStatusCode = &code
+	return nil
+}
+
 func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
 	switch f.Name {
 	case "content-type":
@@ -206,7 +264,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
 		if err != nil {
 			return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
 		}
-		d.rawStatusCode = int32(code)
+		d.rawStatusCode = &code
 	case "grpc-message":
 		d.rawStatusMsg = decodeGrpcMessage(f.Value)
 	case "grpc-status-details-bin":
@@ -227,6 +285,12 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
 		}
 	case ":path":
 		d.method = f.Value
+	case ":status":
+		code, err := strconv.Atoi(f.Value)
+		if err != nil {
+			return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err)
+		}
+		d.httpStatus = &code
 	default:
 		if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
 			if d.mdata == nil {
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 7429f2e2..0b534d2e 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -34,11 +34,13 @@
 package transport
 
 import (
+	"bufio"
 	"bytes"
 	"fmt"
 	"io"
 	"math"
 	"net"
+	"net/http"
 	"reflect"
 	"strconv"
 	"strings"
@@ -1416,3 +1418,192 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
 		break
 	}
 }
+
+// A function of type writeHeaders writes out
+// http status with the given stream ID using the given framer.
+type writeHeaders func(*http2.Framer, uint32, int) error
+
+func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error {
+	var buf bytes.Buffer
+	henc := hpack.NewEncoder(&buf)
+	henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)})
+	if err := framer.WriteHeaders(http2.HeadersFrameParam{
+		StreamID:      sid,
+		BlockFragment: buf.Bytes(),
+		EndStream:     true,
+		EndHeaders:    true,
+	}); err != nil {
+		return err
+	}
+	return nil
+}
+
+func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error {
+	var buf bytes.Buffer
+	henc := hpack.NewEncoder(&buf)
+	henc.WriteField(hpack.HeaderField{
+		Name:  ":status",
+		Value: fmt.Sprint(http.StatusOK),
+	})
+	if err := framer.WriteHeaders(http2.HeadersFrameParam{
+		StreamID:      sid,
+		BlockFragment: buf.Bytes(),
+		EndHeaders:    true,
+	}); err != nil {
+		return err
+	}
+	buf.Reset()
+	henc.WriteField(hpack.HeaderField{
+		Name:  ":status",
+		Value: fmt.Sprint(httpStatus),
+	})
+	if err := framer.WriteHeaders(http2.HeadersFrameParam{
+		StreamID:      sid,
+		BlockFragment: buf.Bytes(),
+		EndStream:     true,
+		EndHeaders:    true,
+	}); err != nil {
+		return err
+	}
+	return nil
+}
+
+type httpServer struct {
+	conn       net.Conn
+	httpStatus int
+	wh         writeHeaders
+}
+
+func (s *httpServer) start(t *testing.T, lis net.Listener) {
+	// Launch an HTTP server to send back header with httpStatus.
+	go func() {
+		var err error
+		s.conn, err = lis.Accept()
+		if err != nil {
+			t.Errorf("Error accepting connection: %v", err)
+			return
+		}
+		defer s.conn.Close()
+		// Read preface sent by client.
+		if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil {
+			t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
+			return
+		}
+		reader := bufio.NewReaderSize(s.conn, http2IOBufSize)
+		writer := bufio.NewWriterSize(s.conn, http2IOBufSize)
+		framer := http2.NewFramer(writer, reader)
+		if err = framer.WriteSettingsAck(); err != nil {
+			t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
+			return
+		}
+		var sid uint32
+		// Read frames until a header is received.
+		for {
+			frame, err := framer.ReadFrame()
+			if err != nil {
+				t.Errorf("Error at server-side while reading frame. Err: %v", err)
+				return
+			}
+			if hframe, ok := frame.(*http2.HeadersFrame); ok {
+				sid = hframe.Header().StreamID
+				break
+			}
+		}
+		if err = s.wh(framer, sid, s.httpStatus); err != nil {
+			t.Errorf("Error at server-side while writing headers. Err: %v", err)
+			return
+		}
+		writer.Flush()
+	}()
+}
+
+func (s *httpServer) cleanUp() {
+	if s.conn != nil {
+		s.conn.Close()
+	}
+}
+
+func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream *Stream, cleanUp func()) {
+	var (
+		err    error
+		lis    net.Listener
+		server *httpServer
+		client ClientTransport
+	)
+	cleanUp = func() {
+		if lis != nil {
+			lis.Close()
+		}
+		if server != nil {
+			server.cleanUp()
+		}
+		if client != nil {
+			client.Close()
+		}
+	}
+	defer func() {
+		if err != nil {
+			cleanUp()
+		}
+	}()
+	lis, err = net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen. Err: %v", err)
+	}
+	server = &httpServer{
+		httpStatus: httpStatus,
+		wh:         wh,
+	}
+	server.start(t, lis)
+	client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{})
+	if err != nil {
+		t.Fatalf("Error creating client. Err: %v", err)
+	}
+	stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true})
+	if err != nil {
+		t.Fatalf("Error creating stream at client-side. Err: %v", err)
+	}
+	return
+}
+
+func TestHTTPToGRPCStatusMapping(t *testing.T) {
+	for k := range httpStatusConvTab {
+		testHTTPToGRPCStatusMapping(t, k, writeOneHeader)
+	}
+}
+
+func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) {
+	stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
+	defer cleanUp()
+	want := httpStatusConvTab[httpStatus]
+	_, err := stream.Read([]byte{})
+	if err == nil {
+		t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
+	}
+	serr, ok := err.(StreamError)
+	if !ok {
+		t.Fatalf("err.(Type) = %T, want StreamError", err)
+	}
+	if want != serr.Code {
+		t.Fatalf("Want error code: %v, got: %v", want, serr.Code)
+	}
+}
+
+func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
+	stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
+	defer cleanUp()
+	_, err := stream.Read([]byte{})
+	if err != io.EOF {
+		t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
+	}
+	want := codes.Unknown
+	stream.mu.Lock()
+	defer stream.mu.Unlock()
+	if stream.status.Code() != want {
+		t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want)
+	}
+}
+
+func TestHTTPStatusNotOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
+	testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
+}

From 07bd9434faf350ba0fbf5cd774a354ae7501cac8 Mon Sep 17 00:00:00 2001
From: Menghan Li
Date: Thu, 11 May 2017 10:10:19 -0700
Subject: [PATCH 05/11] Pass custom dialer to balancer (#1205)

* Pass custom dialer to balancer

* add test for passing custom dialer

* add grpclb package comment
---
 balancer.go           |  5 +++++
 clientconn.go         |  1 +
 grpclb.go             | 21 ++++++++++++------
 grpclb/grpclb_test.go | 51 +++++++++++++++++++++++++++++--------------
 4 files changed, 55 insertions(+), 23 deletions(-)

diff --git a/balancer.go b/balancer.go
index 44db9317..9af4ee70 100644
--- a/balancer.go
+++ b/balancer.go
@@ -35,6 +35,7 @@ package grpc
 
 import (
 	"fmt"
+	"net"
 	"sync"
 
 	"golang.org/x/net/context"
@@ -60,6 +61,10 @@ type BalancerConfig struct {
 	// use to dial to a remote load balancer server. The Balancer implementations
 	// can ignore this if it does not need to talk to another party securely.
 	DialCreds credentials.TransportCredentials
+	// Dialer is the custom dialer the Balancer implementation can use to dial
+	// to a remote load balancer server. The Balancer implementations
+	// can ignore this if it doesn't need to talk to remote balancer.
+	Dialer func(context.Context, string) (net.Conn, error)
 }
 
 // BalancerGetOptions configures a Get call.
diff --git a/clientconn.go b/clientconn.go
index 41a52937..be511f93 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -398,6 +398,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 		}
 		config := BalancerConfig{
 			DialCreds: credsClone,
+			Dialer:    cc.dopts.copts.Dialer,
 		}
 		if err := cc.dopts.balancer.Start(target, config); err != nil {
 			waitC <- err
diff --git a/grpclb.go b/grpclb.go
index 896e26bf..b9d70738 100644
--- a/grpclb.go
+++ b/grpclb.go
@@ -490,20 +490,27 @@ func (b *balancer) Start(target string, config BalancerConfig) error {
 			cc.Close()
 		}
 		// Talk to the remote load balancer to get the server list.
-		var err error
-		creds := config.DialCreds
-		ccError = make(chan struct{})
-		if creds == nil {
-			cc, err = Dial(rb.addr, WithInsecure())
-		} else {
+		var (
+			err   error
+			dopts []DialOption
+		)
+		if creds := config.DialCreds; creds != nil {
 			if rb.name != "" {
 				if err := creds.OverrideServerName(rb.name); err != nil {
 					grpclog.Printf("Failed to override the server name in the credentials: %v", err)
 					continue
 				}
 			}
-			cc, err = Dial(rb.addr, WithTransportCredentials(creds))
+			dopts = append(dopts, WithTransportCredentials(creds))
+		} else {
+			dopts = append(dopts, WithInsecure())
 		}
+		if dialer := config.Dialer; dialer != nil {
+			// WithDialer takes a different type of function, so we instead use a special DialOption here.
+			dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
+		}
+		ccError = make(chan struct{})
+		cc, err = Dial(rb.addr, dopts...)
 		if err != nil {
 			grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
 			close(ccError)
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 29c89092..bc37785e 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -31,6 +31,7 @@
  *
  */
 
+// Package grpclb is currently used only for grpclb testing.
 package grpclb
 
 import (
@@ -59,6 +60,11 @@ var (
 	lbsn    = "bar.com"
 	besn    = "foo.com"
 	lbToken = "iamatoken"
+
+	// Resolver replaces 127.0.0.1 with fakeName in Next().
+	// Dialer replaces fakeName with 127.0.0.1 when dialing.
+	// This will test that custom dialer is passed from Dial to grpclb.
+	fakeName = "fake.Name"
 )
 
 type testWatcher struct {
@@ -81,6 +87,9 @@ func (w *testWatcher) Next() (updates []*naming.Update, err error) {
 			break
 		}
 		if u != nil {
+			// Resolver replaces 127.0.0.1 with fakeName in Next().
+			// Custom dialer will replace fakeName with 127.0.0.1 when dialing.
+			u.Addr = strings.Replace(u.Addr, "127.0.0.1", fakeName, 1)
 			updates = append(updates, u)
 		}
 	}
@@ -171,6 +180,13 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
 	return nil
 }
 
+// fakeNameDialer replaces fakeName with 127.0.0.1 when dialing.
+// This will test that custom dialer is passed from Dial to grpclb.
+func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
+	addr = strings.Replace(addr, fakeName, "127.0.0.1", 1)
+	return net.DialTimeout("tcp", addr, timeout)
+}
+
 type remoteBalancer struct {
 	sls       []*lbpb.ServerList
 	intervals []time.Duration
@@ -387,9 +403,9 @@ func TestGRPCLB(t *testing.T) {
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
+		grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
@@ -425,9 +441,9 @@ func TestDropRequest(t *testing.T) {
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
+		grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
@@ -476,9 +492,9 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
+		grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
@@ -528,9 +544,9 @@ func TestServerExpiration(t *testing.T) {
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
+		grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
@@ -589,7 +605,9 @@ func TestBalancerDisconnects(t *testing.T) {
 	resolver := &testNameResolver{
 		addrs: lbAddrs[:2],
 	}
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)),
+		grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
@@ -681,9 +699,10 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool
 
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
+	cc, err := grpc.DialContext(ctx, besn,
+		grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
+		grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}),
+		grpc.WithBlock(), grpc.WithDialer(fakeNameDialer))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}

From 88a73d35c975fb7bf79071dbcdb84472e7e6669c Mon Sep 17 00:00:00 2001
From: MakMukhi
Date: Thu, 11 May 2017 11:07:38 -0700
Subject: [PATCH 06/11] Adding dial options for PerRPCCredentials (#1225)

* Adding dial options for PerRPCCredentials

* Added tests for PerRPCCredentials

* Post-review updates

* post-review updates
---
 call.go                   |   3 +
 rpc_util.go               |  11 ++++
 stream.go                 |   3 +
 test/end2end_test.go      | 121 ++++++++++++++++++++++++++++++++++++++
 transport/http2_client.go |  52 +++++++++++++---
 transport/transport.go    |   3 +
 6 files changed, 184 insertions(+), 9 deletions(-)

diff --git a/call.go b/call.go
index 0eb5f5cf..08438652 100644
--- a/call.go
+++ b/call.go
@@ -219,6 +219,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 	if cc.dopts.cp != nil {
 		callHdr.SendCompress = cc.dopts.cp.Type()
 	}
+	if c.creds != nil {
+		callHdr.Creds = c.creds
+	}
 	gopts := BalancerGetOptions{
 		BlockingWait: !c.failFast,
diff --git a/rpc_util.go b/rpc_util.go
index 31a87325..6a32afdf 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -46,6 +46,7 @@ import (
 	"golang.org/x/net/context"
 
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/peer"
 	"google.golang.org/grpc/stats"
@@ -141,6 +142,7 @@ type callInfo struct {
 	trailerMD metadata.MD
 	peer      *peer.Peer
 	traceInfo traceInfo // in trace.go
+	creds     credentials.PerRPCCredentials
 }
 
 var defaultCallInfo = callInfo{failFast: true}
@@ -207,6 +209,15 @@ func FailFast(failFast bool) CallOption {
 	})
 }
 
+// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
+// for a call.
+func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
+	return beforeCall(func(c *callInfo) error {
+		c.creds = creds
+		return nil
+	})
+}
+
 // The format of the payload: compressed or not?
 type payloadFormat uint8
diff --git a/stream.go b/stream.go
index 25764330..ec534a01 100644
--- a/stream.go
+++ b/stream.go
@@ -132,6 +132,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	if cc.dopts.cp != nil {
 		callHdr.SendCompress = cc.dopts.cp.Type()
 	}
+	if c.creds != nil {
+		callHdr.Creds = c.creds
+	}
 	var trInfo traceInfo
 	if EnableTracing {
 		trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 01b3e4f7..ced25096 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -449,6 +449,7 @@ type test struct {
 	serverInitialConnWindowSize int32
 	clientInitialWindowSize     int32
 	clientInitialConnWindowSize int32
+	perRPCCreds                 credentials.PerRPCCredentials
 
 	// srv and srvAddr are set once startServer is called.
 	srv     *grpc.Server
@@ -621,6 +622,9 @@ func (te *test) clientConn() *grpc.ClientConn {
 	if te.clientInitialConnWindowSize > 0 {
 		opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize))
 	}
+	if te.perRPCCreds != nil {
+		opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
+	}
 	var err error
 	te.cc, err = grpc.Dial(te.srvAddr, opts...)
 	if err != nil {
@@ -3984,3 +3988,120 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) {
 		t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
 	}
 }
+
+var (
+	// test authdata
+	authdata = map[string]string{
+		"test-key":      "test-value",
+		"test-key2-bin": string([]byte{1, 2, 3}),
+	}
+)
+
+type testPerRPCCredentials struct{}
+
+func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+	return authdata, nil
+}
+
+func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
+	return false
+}
+
+func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return ctx, fmt.Errorf("didn't find metadata in context")
+	}
+	for k, vwant := range authdata {
+		vgot, ok := md[k]
+		if !ok {
+			return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
+		}
+		if vgot[0] != vwant {
+			return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
+		}
+	}
+	return ctx, nil
+}
+
+func TestPerRPCCredentialsViaDialOptions(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testPerRPCCredentialsViaDialOptions(t, e)
+	}
+}
+
+func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.tapHandle = authHandle
+	te.perRPCCreds = testPerRPCCredentials{}
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+		t.Fatalf("Test failed. Reason: %v", err)
+	}
+}
+
+func TestPerRPCCredentialsViaCallOptions(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testPerRPCCredentialsViaCallOptions(t, e)
+	}
+}
+
+func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.tapHandle = authHandle
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
+		t.Fatalf("Test failed. Reason: %v", err)
+	}
+}
+
+func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
+	}
+}
+
+func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.perRPCCreds = testPerRPCCredentials{}
+	// When credentials are provided via both dial options and call options,
+	// we apply both sets.
+	te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
+		md, ok := metadata.FromIncomingContext(ctx)
+		if !ok {
+			return ctx, fmt.Errorf("couldn't find metadata in context")
+		}
+		for k, vwant := range authdata {
+			vgot, ok := md[k]
+			if !ok {
+				return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
+			}
+			if len(vgot) != 2 {
+				return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
+			}
+			if vgot[0] != vwant || vgot[1] != vwant {
+				return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
+			}
+		}
+		return ctx, nil
+	}
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
+		t.Fatalf("Test failed. Reason: %v", err)
+	}
+}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 3e5ff731..7db73d38 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -101,6 +101,8 @@ type http2Client struct {
 	// The scheme used: https if TLS is on, http otherwise.
 	scheme string
 
+	isSecure bool
+
 	creds []credentials.PerRPCCredentials
 
 	// Boolean to keep track of reading activity on transport.
@@ -181,7 +183,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
 			conn.Close()
 		}
 	}(conn)
-	var authInfo credentials.AuthInfo
+	var (
+		isSecure bool
+		authInfo credentials.AuthInfo
+	)
 	if creds := opts.TransportCredentials; creds != nil {
 		scheme = "https"
 		conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
@@ -191,6 +196,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
 			temp := isTemporary(err)
 			return nil, connectionErrorf(temp, err, "transport: %v", err)
 		}
+		isSecure = true
 	}
 	kp := opts.KeepaliveParams
 	// Validate keepalive parameters.
@@ -230,6 +236,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
 		scheme:          scheme,
 		state:           reachable,
 		activeStreams:   make(map[uint32]*Stream),
+		isSecure:        isSecure,
 		creds:           opts.PerRPCCredentials,
 		maxStreams:      defaultMaxStreamsClient,
 		streamsQuota:    newQuotaPool(defaultMaxStreamsClient),
@@ -335,8 +342,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		pr.AuthInfo = t.authInfo
 	}
 	ctx = peer.NewContext(ctx, pr)
-	authData := make(map[string]string)
-	for _, c := range t.creds {
+	var (
+		authData = make(map[string]string)
+		audience string
+	)
+	// Create an audience string only if needed.
+	if len(t.creds) > 0 || callHdr.Creds != nil {
 		// Construct URI required to get auth request metadata.
 		var port string
 		if pos := strings.LastIndex(t.target, ":"); pos != -1 {
@@ -347,17 +358,39 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		}
 		pos := strings.LastIndex(callHdr.Method, "/")
 		if pos == -1 {
-			return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
+			pos = len(callHdr.Method)
 		}
-		audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
+		audience = "https://" + callHdr.Host + port + callHdr.Method[:pos]
+	}
+	for _, c := range t.creds {
 		data, err := c.GetRequestMetadata(ctx, audience)
 		if err != nil {
-			return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
+			return nil, streamErrorf(codes.Internal, "transport: %v", err)
 		}
 		for k, v := range data {
+			// Capital header names are illegal in HTTP/2.
+			k = strings.ToLower(k)
 			authData[k] = v
 		}
 	}
+	callAuthData := make(map[string]string)
+	// Check if credentials.PerRPCCredentials were provided via call options.
+	// Note: if these credentials are provided both via dial options and call
+	// options, then both sets of credentials will be applied.
+	if callCreds := callHdr.Creds; callCreds != nil {
+		if !t.isSecure && callCreds.RequireTransportSecurity() {
+			return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
+		}
+		data, err := callCreds.GetRequestMetadata(ctx, audience)
+		if err != nil {
+			return nil, streamErrorf(codes.Internal, "transport: %v", err)
+		}
+		for k, v := range data {
+			// Capital header names are illegal in HTTP/2.
+			k = strings.ToLower(k)
+			callAuthData[k] = v
+		}
+	}
 	t.mu.Lock()
 	if t.activeStreams == nil {
 		t.mu.Unlock()
@@ -435,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	}
 	for k, v := range authData {
-		// Capital header names are illegal in HTTP/2.
-		k = strings.ToLower(k)
-		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
+		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+	}
+	for k, v := range callAuthData {
+		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 	}
 	var (
 		hasMD bool
diff --git a/transport/transport.go b/transport/transport.go
index 4bd4dc44..2dff7c8b 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -469,6 +469,9 @@ type CallHdr struct {
 	// outbound message.
 	SendCompress string
 
+	// Creds specifies credentials.PerRPCCredentials for a call.
+	Creds credentials.PerRPCCredentials
+
 	// Flush indicates whether a new stream command should be sent
 	// to the peer without waiting for the first data. This is
 	// only a hint. The transport may modify the flush decision
+func (m *SimpleRequest) GetId() int32 {
+	if m != nil {
+		return m.Id
+	}
+	return 0
+}
+
 type SimpleResponse struct {
 	Id int32 `protobuf:"varint,3,opt,name=id" json:"id,omitempty"`
 }
@@ -54,6 +59,13 @@ func (m *SimpleResponse) String() string { return proto.CompactTextSt
 func (*SimpleResponse) ProtoMessage() {}
 func (*SimpleResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
 
+func (m *SimpleResponse) GetId() int32 {
+	if m != nil {
+		return m.Id
+	}
+	return 0
+}
+
 func init() {
 	proto.RegisterType((*SimpleRequest)(nil), "grpc.testing.SimpleRequest")
 	proto.RegisterType((*SimpleResponse)(nil), "grpc.testing.SimpleResponse")
@@ -77,6 +89,10 @@ type TestServiceClient interface {
 	// As one request could lead to multiple responses, this interface
 	// demonstrates the idea of full duplexing.
 	FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error)
+	// Client stream
+	ClientStreamCall(ctx context.Context, opts ...grpc.CallOption) (TestService_ClientStreamCallClient, error)
+	// Server stream
+	ServerStreamCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (TestService_ServerStreamCallClient, error)
 }
 
 type testServiceClient struct {
@@ -127,6 +143,72 @@ func (x *testServiceFullDuplexCallClient) Recv() (*SimpleResponse, error) {
 	return m, nil
 }
 
+func (c *testServiceClient) ClientStreamCall(ctx context.Context, opts ...grpc.CallOption) (TestService_ClientStreamCallClient, error) {
+	stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/ClientStreamCall", opts...)
+	if err != nil {
+		return nil, err
+	}
+	x := &testServiceClientStreamCallClient{stream}
+	return x, nil
+}
+
+type TestService_ClientStreamCallClient interface {
+	Send(*SimpleRequest) error
+	CloseAndRecv() (*SimpleResponse, error)
+	grpc.ClientStream
+}
+
+type testServiceClientStreamCallClient struct {
+	grpc.ClientStream
+}
+
+func (x *testServiceClientStreamCallClient) Send(m *SimpleRequest) error {
+	return x.ClientStream.SendMsg(m)
+}
+
+func (x *testServiceClientStreamCallClient) CloseAndRecv() (*SimpleResponse, error) {
+	if err := x.ClientStream.CloseSend(); err != nil {
+		return nil, err
+	}
+	m := new(SimpleResponse)
+	if err := x.ClientStream.RecvMsg(m); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
+func (c *testServiceClient) ServerStreamCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (TestService_ServerStreamCallClient, error) {
+	stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/ServerStreamCall", opts...)
+	if err != nil {
+		return nil, err
+	}
+	x := &testServiceServerStreamCallClient{stream}
+	if err := x.ClientStream.SendMsg(in); err != nil {
+		return nil, err
+	}
+	if err := x.ClientStream.CloseSend(); err != nil {
+		return nil, err
+	}
+	return x, nil
+}
+
+type TestService_ServerStreamCallClient interface {
+	Recv() (*SimpleResponse, error)
+	grpc.ClientStream
+}
+
+type testServiceServerStreamCallClient struct {
+	grpc.ClientStream
+}
+
+func (x *testServiceServerStreamCallClient) Recv() (*SimpleResponse, error) {
+	m := new(SimpleResponse)
+	if err := x.ClientStream.RecvMsg(m); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
 // Server API for TestService service
 
 type TestServiceServer interface {
@@ -137,6 +219,10 @@ type TestServiceServer interface {
 	// As one request could lead to multiple responses, this interface
 	// demonstrates the idea of full duplexing.
 	FullDuplexCall(TestService_FullDuplexCallServer) error
+	// Client stream
+	ClientStreamCall(TestService_ClientStreamCallServer) error
+	// Server stream
+	ServerStreamCall(*SimpleRequest, TestService_ServerStreamCallServer) error
 }
 
 func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) {
@@ -187,6 +273,53 @@ func (x *testServiceFullDuplexCallServer) Recv() (*SimpleRequest, error) {
 	return m, nil
 }
 
+func _TestService_ClientStreamCall_Handler(srv interface{}, stream grpc.ServerStream) error {
+	return srv.(TestServiceServer).ClientStreamCall(&testServiceClientStreamCallServer{stream})
+}
+
+type TestService_ClientStreamCallServer interface {
+	SendAndClose(*SimpleResponse) error
+	Recv() (*SimpleRequest, error)
+	grpc.ServerStream
+}
+
+type testServiceClientStreamCallServer struct {
+	grpc.ServerStream
+}
+
+func (x *testServiceClientStreamCallServer) SendAndClose(m *SimpleResponse) error {
+	return x.ServerStream.SendMsg(m)
+}
+
+func (x *testServiceClientStreamCallServer) Recv() (*SimpleRequest, error) {
+	m := new(SimpleRequest)
+	if err := x.ServerStream.RecvMsg(m); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
+func _TestService_ServerStreamCall_Handler(srv interface{}, stream grpc.ServerStream) error {
+	m := new(SimpleRequest)
+	if err := stream.RecvMsg(m); err != nil {
+		return err
+	}
+	return srv.(TestServiceServer).ServerStreamCall(m, &testServiceServerStreamCallServer{stream})
+}
+
+type TestService_ServerStreamCallServer interface {
+	Send(*SimpleResponse) error
+	grpc.ServerStream
+}
+
+type testServiceServerStreamCallServer struct {
+	grpc.ServerStream
+}
+
+func (x *testServiceServerStreamCallServer) Send(m *SimpleResponse) error {
+	return x.ServerStream.SendMsg(m)
+}
+
 var _TestService_serviceDesc = grpc.ServiceDesc{
 	ServiceName: "grpc.testing.TestService",
 	HandlerType: (*TestServiceServer)(nil),
@@ -203,6 +336,16 @@ var _TestService_serviceDesc = grpc.ServiceDesc{
 			ServerStreams: true,
 			ClientStreams: true,
 		},
+		{
+			StreamName:    "ClientStreamCall",
+			Handler:       _TestService_ClientStreamCall_Handler,
+			ClientStreams: true,
+		},
+		{
+			StreamName:    "ServerStreamCall",
+			Handler:       _TestService_ServerStreamCall_Handler,
+			ServerStreams: true,
+		},
 	},
 	Metadata: "test.proto",
 }
 
 func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
 
 var fileDescriptor0 = []byte{
-	// 167 bytes of a gzipped FileDescriptorProto
+	// 196 bytes of a gzipped FileDescriptorProto
 	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
 	0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48, 0xd6, 0x03, 0x09, 0x64,
 	0xe6, 0xa5, 0x2b, 0xc9, 0x73, 0xf1, 0x06, 0x67, 0xe6, 0x16, 0xe4, 0xa4, 0x06, 0xa5, 0x16, 0x96,
 	0xa6, 0x16, 0x97, 0x08, 0xf1, 0x71, 0x31, 0x65, 0xa6, 0x48, 0x30, 0x29, 0x30, 0x6a, 0xb0, 0x06,
 	0x31, 0x65, 0xa6, 0x28, 0x29, 0x70, 0xf1, 0xc1, 0x14, 0x14, 0x17, 0xe4, 0xe7, 0x15, 0xa7, 0x42,
-	0x55, 0x30, 0xc3, 0x54, 0x18, 0x2d, 0x63, 0xe4, 0xe2, 0x0e, 0x49, 0x2d, 0x2e, 0x09, 0x4e, 0x2d,
+	0x55, 0x30, 0xc3, 0x54, 0x18, 0x9d, 0x60, 0xe2, 0xe2, 0x0e, 0x49, 0x2d, 0x2e, 0x09, 0x4e, 0x2d,
 	0x2a, 0xcb, 0x4c, 0x4e, 0x15, 0x72, 0xe3, 0xe2, 0x0c, 0xcd, 0x4b, 0x2c, 0xaa, 0x74, 0x4e, 0xcc,
 	0xc9, 0x11, 0x92, 0xd6, 0x43, 0xb6, 0x4e, 0x0f, 0xc5, 0x2e, 0x29, 0x19, 0xec, 0x92, 0x50, 0x7b,
 	0xfc, 0xb9, 0xf8, 0xdc, 0x4a, 0x73, 0x72, 0x5c, 0x4a, 0x0b, 0x72, 0x52, 0x2b, 0x28, 0x34, 0x4c,
-	0x83, 0xd1, 0x80, 0x31, 0x89, 0x0d, 0x1c, 0x00, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x8d,
-	0x82, 0x5b, 0xdd, 0x0e, 0x01, 0x00, 0x00,
+	0x83, 0xd1, 0x80, 0x51, 0xc8, 0x9f, 0x4b, 0xc0, 0x39, 0x27, 0x33, 0x35, 0xaf, 0x24, 0xb8, 0xa4,
+	0x28, 0x35, 0x31, 0x97, 0x62, 0x23, 0x41, 0x06, 0x82, 0x3c, 0x9d, 0x5a, 0x44, 0x15, 0x03, 0x0d,
+	0x18, 0x93, 0xd8, 0xc0, 0x51, 0x64, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x61, 0x49, 0x59, 0xe6,
+	0xb0, 0x01, 0x00, 0x00,
 }
diff --git a/stats/grpc_testing/test.proto b/stats/grpc_testing/test.proto
index 54e6f744..bea8c4c7 100644
--- a/stats/grpc_testing/test.proto
+++ b/stats/grpc_testing/test.proto
@@ -20,4 +20,10 @@ service TestService {
 	// As one request could lead to multiple responses, this interface
 	// demonstrates the idea of full duplexing.
 	rpc FullDuplexCall(stream SimpleRequest) returns (stream SimpleResponse);
+
+	// Client stream
+	rpc ClientStreamCall(stream SimpleRequest) returns (SimpleResponse);
+
+	// Server stream
+	rpc ServerStreamCall(SimpleRequest) returns (stream SimpleResponse);
 }
diff --git a/stats/stats_test.go b/stats/stats_test.go
index 35d60a45..467d6a58 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -120,6 +120,51 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
 	}
 }
 
+func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error {
+	md, ok := metadata.FromContext(stream.Context())
+	if ok {
+		if err := stream.SendHeader(md); err != nil {
+			return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
+		}
+		stream.SetTrailer(testTrailerMetadata)
+	}
+	for {
+		in, err := stream.Recv()
+		if err == io.EOF {
+			// read done.
+			return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)})
+		}
+		if err != nil {
+			return err
+		}
+
+		if in.Id == errorID {
+			return fmt.Errorf("got error id: %v", in.Id)
+		}
+	}
+}
+
+func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error {
+	md, ok := metadata.FromContext(stream.Context())
+	if ok {
+		if err := stream.SendHeader(md); err != nil {
+			return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
+		}
+		stream.SetTrailer(testTrailerMetadata)
+	}
+
+	if in.Id == errorID {
+		return fmt.Errorf("got error id: %v", in.Id)
+	}
+
+	for i := 0; i < 5; i++ {
+		if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // test is an end-to-end test. It should be created with the newTest
 // func, modified as needed, and then started with its startServer method.
 // It should be cleaned up with the tearDown method.
@@ -218,12 +263,21 @@ func (te *test) clientConn() *grpc.ClientConn {
 	return te.cc
 }
 
+type rpcType int
+
+const (
+	unaryRPC rpcType = iota
+	clientStreamRPC
+	serverStreamRPC
+	fullDuplexStreamRPC
+)
+
 type rpcConfig struct {
 	count    int  // Number of requests and responses for streaming RPCs.
 	success  bool // Whether the RPC should succeed or return error.
 	failfast bool
-	streaming  bool // Whether the rpc should be a streaming RPC.
-	noLastRecv bool // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
+	callType   rpcType // Type of RPC.
+	noLastRecv bool    // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
} func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { @@ -289,6 +343,64 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest return reqs, resps, nil } +func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) { + var ( + reqs []*testpb.SimpleRequest + resp *testpb.SimpleResponse + err error + ) + tc := testpb.NewTestServiceClient(te.clientConn()) + stream, err := tc.ClientStreamCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) + if err != nil { + return reqs, resp, err + } + var startID int32 + if !c.success { + startID = errorID + } + for i := 0; i < c.count; i++ { + req := &testpb.SimpleRequest{ + Id: int32(i) + startID, + } + reqs = append(reqs, req) + if err = stream.Send(req); err != nil { + return reqs, resp, err + } + } + resp, err = stream.CloseAndRecv() + return reqs, resp, err +} + +func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { + var ( + req *testpb.SimpleRequest + resps []*testpb.SimpleResponse + err error + ) + + tc := testpb.NewTestServiceClient(te.clientConn()) + + var startID int32 + if !c.success { + startID = errorID + } + req = &testpb.SimpleRequest{Id: startID} + stream, err := tc.ServerStreamCall(metadata.NewContext(context.Background(), testMetadata), req, grpc.FailFast(c.failfast)) + if err != nil { + return req, resps, err + } + for { + var resp *testpb.SimpleResponse + resp, err := stream.Recv() + if err == io.EOF { + return req, resps, nil + } else if err != nil { + return req, resps, err + } + resps = append(resps, resp) + } +} + type expectedData struct { method string serverAddr string @@ -672,16 +784,35 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f defer te.tearDown() var ( - reqs []*testpb.SimpleRequest - resps []*testpb.SimpleResponse - err error + reqs []*testpb.SimpleRequest + resps []*testpb.SimpleResponse + err error + method string + + req *testpb.SimpleRequest + resp *testpb.SimpleResponse + e error ) - if !cc.streaming { - req, resp, e := te.doUnaryCall(cc) + + switch cc.callType { + case unaryRPC: + method = "/grpc.testing.TestService/UnaryCall" + req, resp, e = te.doUnaryCall(cc) reqs = []*testpb.SimpleRequest{req} resps = []*testpb.SimpleResponse{resp} err = e - } else { + case clientStreamRPC: + method = "/grpc.testing.TestService/ClientStreamCall" + reqs, resp, e = te.doClientStreamCall(cc) + resps = []*testpb.SimpleResponse{resp} + err = e + case serverStreamRPC: + method = "/grpc.testing.TestService/ServerStreamCall" + req, resps, e = te.doServerStreamCall(cc) + reqs = []*testpb.SimpleRequest{req} + err = e + case fullDuplexStreamRPC: + method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) } if cc.success != (err == nil) { @@ -713,22 +844,18 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f expect := &expectedData{ serverAddr: te.srvAddr, compression: tc.compress, + method: method, requests: reqs, responses: resps, err: err, } - if !cc.streaming { - expect.method = "/grpc.testing.TestService/UnaryCall" - } else { - expect.method = "/grpc.testing.TestService/FullDuplexCall" - } checkConnStats(t, h.gotConn) checkServerStats(t, h.gotRPC, expect, checkFuncs) } func TestServerStatsUnaryRPC(t *testing.T) { - testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true}, []func(t *testing.T, d 
*gotData, e *expectedData){ + testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ checkInHeader, checkBegin, checkInPayload, @@ -740,7 +867,7 @@ func TestServerStatsUnaryRPC(t *testing.T) { } func TestServerStatsUnaryRPCError(t *testing.T) { - testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false}, []func(t *testing.T, d *gotData, e *expectedData){ + testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ checkInHeader, checkBegin, checkInPayload, @@ -750,7 +877,73 @@ func TestServerStatsUnaryRPCError(t *testing.T) { }) } -func TestServerStatsStreamingRPC(t *testing.T) { +func TestServerStatsClientStreamRPC(t *testing.T) { + count := 5 + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + } + ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInPayload, + } + for i := 0; i < count; i++ { + checkFuncs = append(checkFuncs, ioPayFuncs...) + } + checkFuncs = append(checkFuncs, + checkOutPayload, + checkOutTrailer, + checkEnd, + ) + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs) +} + +func TestServerStatsClientStreamRPCError(t *testing.T) { + count := 1 + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + checkInPayload, + checkOutTrailer, + checkEnd, + }) +} + +func TestServerStatsServerStreamRPC(t *testing.T) { + count := 5 + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + } + ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkOutPayload, + } + for i := 0; i < count; i++ { + checkFuncs = append(checkFuncs, ioPayFuncs...) 
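+		// expect one checkOutPayload for each of the count responses streamed back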
+ } + checkFuncs = append(checkFuncs, + checkOutTrailer, + checkEnd, + ) + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs) +} + +func TestServerStatsServerStreamRPCError(t *testing.T) { + count := 5 + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutTrailer, + checkEnd, + }) +} + +func TestServerStatsFullDuplexRPC(t *testing.T) { count := 5 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ checkInHeader, @@ -768,12 +961,12 @@ func TestServerStatsStreamingRPC(t *testing.T) { checkOutTrailer, checkEnd, ) - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, streaming: true}, checkFuncs) + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs) } -func TestServerStatsStreamingRPCError(t *testing.T) { +func TestServerStatsFullDuplexRPCError(t *testing.T) { count := 5 - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, streaming: true}, []func(t *testing.T, d *gotData, e *expectedData){ + testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ checkInHeader, checkBegin, checkOutHeader, @@ -880,16 +1073,34 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map defer te.tearDown() var ( - reqs []*testpb.SimpleRequest - resps []*testpb.SimpleResponse - err error + reqs []*testpb.SimpleRequest + resps []*testpb.SimpleResponse + method string + err error + + req *testpb.SimpleRequest + resp *testpb.SimpleResponse + e error ) - if !cc.streaming { - req, resp, e := te.doUnaryCall(cc) + switch cc.callType { + case unaryRPC: + method = "/grpc.testing.TestService/UnaryCall" + req, resp, e = te.doUnaryCall(cc) reqs = []*testpb.SimpleRequest{req} resps = []*testpb.SimpleResponse{resp} err = e - } else { + case clientStreamRPC: + method = "/grpc.testing.TestService/ClientStreamCall" + reqs, resp, e = te.doClientStreamCall(cc) + resps = []*testpb.SimpleResponse{resp} + err = e + case serverStreamRPC: + method = "/grpc.testing.TestService/ServerStreamCall" + req, resps, e = te.doServerStreamCall(cc) + reqs = []*testpb.SimpleRequest{req} + err = e + case fullDuplexStreamRPC: + method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) } if cc.success != (err == nil) { @@ -925,23 +1136,19 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map expect := &expectedData{ serverAddr: te.srvAddr, compression: tc.compress, + method: method, requests: reqs, responses: resps, failfast: cc.failfast, err: err, } - if !cc.streaming { - expect.method = "/grpc.testing.TestService/UnaryCall" - } else { - expect.method = "/grpc.testing.TestService/FullDuplexCall" - } checkConnStats(t, h.gotConn) checkClientStats(t, h.gotRPC, expect, checkFuncs) } func TestClientStatsUnaryRPC(t *testing.T) { - testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false}, map[int]*checkFuncWithCount{ + testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ begin: {checkBegin, 
1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, @@ -953,7 +1160,7 @@ func TestClientStatsUnaryRPC(t *testing.T) { } func TestClientStatsUnaryRPCError(t *testing.T) { - testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false}, map[int]*checkFuncWithCount{ + testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, @@ -963,23 +1170,59 @@ func TestClientStatsUnaryRPCError(t *testing.T) { }) } -func TestClientStatsStreamingRPC(t *testing.T) { +func TestClientStatsClientStreamRPC(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, streaming: true}, map[int]*checkFuncWithCount{ + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, count}, inHeader: {checkInHeader, 1}, - inPayload: {checkInPayload, count}, + outPayload: {checkOutPayload, count}, inTrailer: {checkInTrailer, 1}, + inPayload: {checkInPayload, 1}, end: {checkEnd, 1}, }) } -// If the user doesn't call the last recv() on clientSteam. -func TestClientStatsStreamingRPCNotCallingLastRecv(t *testing.T) { +func TestClientStatsClientStreamRPCError(t *testing.T) { count := 1 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, streaming: true, noLastRecv: true}, map[int]*checkFuncWithCount{ + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + inHeader: {checkInHeader, 1}, + outPayload: {checkOutPayload, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }) +} + +func TestClientStatsServerStreamRPC(t *testing.T) { + count := 5 + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, count}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }) +} + +func TestClientStatsServerStreamRPCError(t *testing.T) { + count := 5 + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }) +} + +func TestClientStatsFullDuplexRPC(t *testing.T) { + count := 5 + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, count}, @@ -990,9 +1233,9 @@ func TestClientStatsStreamingRPCNotCallingLastRecv(t *testing.T) { }) } -func TestClientStatsStreamingRPCError(t *testing.T) { +func TestClientStatsFullDuplexRPCError(t *testing.T) { count := 5 - testClientStats(t, 
&testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, streaming: true}, map[int]*checkFuncWithCount{ + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, @@ -1001,3 +1244,17 @@ func TestClientStatsStreamingRPCError(t *testing.T) { end: {checkEnd, 1}, }) } + +// If the user doesn't call the last recv() on clientStream. +func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) { + count := 1 + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, count}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, count}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }) +} From 780308da60e78498d55642114b807c33550af7c1 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 12 May 2017 14:05:32 -0700 Subject: [PATCH 08/11] add logs to grpclb on send and recv (#1235) --- grpclb.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/grpclb.go b/grpclb.go index b9d70738..0e4d269b 100644 --- a/grpclb.go +++ b/grpclb.go @@ -152,6 +152,7 @@ type balancer struct { func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { updates, err := w.Next() if err != nil { + grpclog.Printf("grpclb: failed to get next addr update from watcher: %v", err) return err } b.mu.Lock() @@ -306,6 +307,7 @@ func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Dura ClientStats: &stats, }, }); err != nil { + grpclog.Printf("grpclb: failed to send load report: %v", err) return } } @@ -316,7 +318,7 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b defer cancel() stream, err := lbc.BalanceLoad(ctx) if err != nil { - grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) + grpclog.Printf("grpclb: failed to perform RPC to the remote balancer %v", err) return } b.mu.Lock() @@ -333,17 +335,19 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b }, } if err := stream.Send(initReq); err != nil { + grpclog.Printf("grpclb: failed to send init request: %v", err) // TODO: backoff on retry? return true } reply, err := stream.Recv() if err != nil { + grpclog.Printf("grpclb: failed to recv init response: %v", err) // TODO: backoff on retry? return true } initResp := reply.GetInitialResponse() if initResp == nil { - grpclog.Println("Failed to receive the initial response from the remote balancer.") + grpclog.Println("grpclb: reply from remote balancer did not include initial response.") return } // TODO: Support delegation. 
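The failures logged above all funnel into a single retry decision: callRemoteBalancer reports retry=true for transient stream errors and gives up (retry=false) when the balancer's reply is malformed. A minimal, self-contained sketch of that contract and of the "grpclb:"-prefixed logging convention follows; balancerStream, handshake, and flakyStream are illustrative stand-ins, not the real grpclb types:

package main

import (
	"errors"
	"log"
)

// balancerStream models only the two calls the init handshake needs.
type balancerStream interface {
	Send(msg string) error
	Recv() (string, error)
}

// handshake mirrors the instrumented flow: transient stream errors are
// logged with a "grpclb:" prefix and reported as retryable, while a reply
// without an initial response gives up instead of retrying.
func handshake(s balancerStream) (retry bool) {
	if err := s.Send("initial request"); err != nil {
		log.Printf("grpclb: failed to send init request: %v", err)
		return true
	}
	reply, err := s.Recv()
	if err != nil {
		log.Printf("grpclb: failed to recv init response: %v", err)
		return true
	}
	if reply == "" { // stands in for a reply missing its initial response
		log.Println("grpclb: reply from remote balancer did not include initial response.")
		return false
	}
	return false
}

// flakyStream simulates a balancer that accepts the request and then drops
// the connection before answering.
type flakyStream struct{}

func (flakyStream) Send(string) error     { return nil }
func (flakyStream) Recv() (string, error) { return "", errors.New("connection reset") }

func main() {
	log.Printf("retry: %v", handshake(flakyStream{}))
}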
@@ -364,6 +368,7 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b for { reply, err := stream.Recv() if err != nil { + grpclog.Printf("grpclb: failed to recv server list: %v", err) break } b.mu.Lock() @@ -397,6 +402,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error { w, err := b.r.Resolve(target) if err != nil { b.mu.Unlock() + grpclog.Printf("grpclb: failed to resolve address: %v, err: %v", target, err) return err } b.w = w @@ -406,7 +412,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error { go func() { for { if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil { - grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) + grpclog.Printf("grpclb: the naming watcher stops working due to %v.\n", err) close(balancerAddrsCh) return } @@ -497,7 +503,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error { if creds := config.DialCreds; creds != nil { if rb.name != "" { if err := creds.OverrideServerName(rb.name); err != nil { - grpclog.Printf("Failed to override the server name in the credentials: %v", err) + grpclog.Printf("grpclb: failed to override the server name in the credentials: %v", err) continue } } @@ -512,7 +518,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error { ccError = make(chan struct{}) cc, err = Dial(rb.addr, dopts...) if err != nil { - grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err) + grpclog.Printf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) close(ccError) continue } From 1c69e4cae0f5180ce7d8b472bf0a55d2654fe31b Mon Sep 17 00:00:00 2001 From: Mehrdad Afshari Date: Fri, 12 May 2017 14:06:11 -0700 Subject: [PATCH 09/11] Eagerly set a pointer to nil to help GC (#1232) --- transport/transport.go | 1 + 1 file changed, 1 insertion(+) diff --git a/transport/transport.go b/transport/transport.go index 2dff7c8b..cccbaf5e 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -105,6 +105,7 @@ func (b *recvBuffer) load() { if len(b.backlog) > 0 { select { case b.c <- b.backlog[0]: + b.backlog[0] = nil b.backlog = b.backlog[1:] default: } From 135247d85c4519e31f5e1b0459658147ab1a6781 Mon Sep 17 00:00:00 2001 From: Eric Drechsel Date: Mon, 15 May 2017 12:41:55 -0700 Subject: [PATCH 10/11] fix server panic trying to send on stream as client disconnects #1111 (#1115) --- transport/handler_server.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/transport/handler_server.go b/transport/handler_server.go index 24f306ba..31b0570e 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -179,11 +179,18 @@ func (a strAddr) String() string { return string(a) } // do runs fn in the ServeHTTP goroutine. func (ht *serverHandlerTransport) do(fn func()) error { + // Avoid a panic writing to closed channel. Imperfect but maybe good enough. 
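+	// Two-step check: the outer select is a non-blocking poll that fails fast
+	// when the transport is already closed; only then does the inner select
+	// block, racing the send against a concurrent close. A close that lands
+	// between the two steps can still reach the send case, which is why this
+	// is "imperfect" rather than a complete fix.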
select { - case ht.writes <- fn: - return nil case <-ht.closedCh: return ErrConnClosing + default: + select { + case ht.writes <- fn: + return nil + case <-ht.closedCh: + return ErrConnClosing + } + } } From aacd01c2197fc8234db4c9b169114e7f1d739914 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 15 May 2017 12:43:49 -0700 Subject: [PATCH 11/11] call listen with "localhost:port" instead of ":port" in tests (#1237) --- proxy_test.go | 4 ++-- test/end2end_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/proxy_test.go b/proxy_test.go index 846b396b..bd007066 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -133,7 +133,7 @@ func (p *proxyServer) stop() { } func TestHTTPConnect(t *testing.T) { - plis, err := net.Listen("tcp", ":0") + plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } @@ -141,7 +141,7 @@ func TestHTTPConnect(t *testing.T) { go p.run() defer p.stop() - blis, err := net.Listen("tcp", ":0") + blis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } diff --git a/test/end2end_test.go b/test/end2end_test.go index ced25096..6bc6661a 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3412,7 +3412,7 @@ func (c *serverDispatchCred) getRawConn() net.Conn { } func TestServerCredsDispatch(t *testing.T) { - lis, err := net.Listen("tcp", ":0") + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen: %v", err) } @@ -3453,7 +3453,7 @@ func TestFlowControlLogicalRace(t *testing.T) { requestCount = 1000 } - lis, err := net.Listen("tcp", ":0") + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen: %v", err) } @@ -3763,9 +3763,9 @@ func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer // Start starts the server and creates a client connected to it. func (ss *stubServer) Start() error { - lis, err := net.Listen("tcp", ":0") + lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return fmt.Errorf(`net.Listen("tcp", ":0") = %v`, err) + return fmt.Errorf(`net.Listen("tcp", "localhost:0") = %v`, err) } ss.cleanups = append(ss.cleanups, func() { lis.Close() })