From e42a66c81bb5b8c28696b43bfcf789fc74ceba25 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 1 Nov 2016 16:57:54 -0700 Subject: [PATCH] add client stats --- call.go | 64 ++++- stats/stats_test.go | 545 ++++++++++++++++++++++++++++++++------ stream.go | 67 ++++- transport/http2_client.go | 40 +++ transport/transport.go | 5 + 5 files changed, 620 insertions(+), 101 deletions(-) diff --git a/call.go b/call.go index f7fbc7e7..8aabbec6 100644 --- a/call.go +++ b/call.go @@ -42,6 +42,7 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" "google.golang.org/grpc/codes" + "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -63,14 +64,24 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s return } p := &parser{r: stream} + var incomingPayloadStats *stats.IncomingPayloadStats + if stats.On() { + incomingPayloadStats = &stats.IncomingPayloadStats{ + IsClient: true, + } + } for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, nil); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, incomingPayloadStats); err != nil { if err == io.EOF { break } return } } + if err == io.EOF && stream.StatusCode() == codes.OK && incomingPayloadStats != nil { + // TODO in the current implementation, incomingTrailerStats is handled before incomingPayloadStats. Fix the order if necessary. + stats.Handle(stream.Context(), incomingPayloadStats) + } c.trailerMD = stream.Trailer() return nil } @@ -89,15 +100,27 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } } }() - var cbuf *bytes.Buffer + var ( + cbuf *bytes.Buffer + outgoingPayloadStats *stats.OutgoingPayloadStats + ) if compressor != nil { cbuf = new(bytes.Buffer) } - outBuf, err := encode(codec, args, compressor, cbuf, nil) + if stats.On() { + outgoingPayloadStats = &stats.OutgoingPayloadStats{ + IsClient: true, + } + } + outBuf, err := encode(codec, args, compressor, cbuf, outgoingPayloadStats) if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) + if outgoingPayloadStats != nil { + outgoingPayloadStats.SentTime = time.Now() + stats.Handle(stream.Context(), outgoingPayloadStats) + } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following // recvResponse to get the final status. @@ -118,7 +141,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return invoke(ctx, method, args, reply, cc, opts...) } -func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { +func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { @@ -140,25 +163,38 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false) // TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set. defer func() { - if err != nil { - c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + if e != nil { + c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true) c.traceInfo.tr.SetError() } }() } + var ( + err error + t transport.ClientTransport + stream *transport.Stream + // Record the put handler from Balancer.Get(...). It is called once the + // RPC has completed or failed. + put func() + ) + defer func() { + if e != nil && stats.On() { + errorStats := &stats.ErrorStats{ + IsClient: true, + Error: e, + } + if stream != nil { + stats.Handle(stream.Context(), errorStats) + } else { + stats.Handle(ctx, errorStats) + } + } + }() topts := &transport.Options{ Last: true, Delay: false, } for { - var ( - err error - t transport.ClientTransport - stream *transport.Stream - // Record the put handler from Balancer.Get(...). It is called once the - // RPC has completed or failed. - put func() - ) // TODO(zhaoq): Need a formal spec of fail-fast. callHdr := &transport.CallHdr{ Host: cc.authority, diff --git a/stats/stats_test.go b/stats/stats_test.go index 3ab4e426..db07ed6e 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -275,21 +275,33 @@ func (te *test) doFullDuplexCallRoundtrip(count int, success bool) ([]*testpb.Si } type expectedData struct { - method string - localAddr string - encryption string - expectedInIdx int - incoming []*testpb.SimpleRequest - expectedOutIdx int - outgoing []*testpb.SimpleResponse - err error + method string + serverAddr string + encryption string + reqIdx int + requests []*testpb.SimpleRequest + respIdx int + responses []*testpb.SimpleResponse + err error } type gotData struct { - ctx context.Context - s stats.Stats + ctx context.Context + client bool + s stats.Stats } +const ( + inits int = iota + inpay + inheader + intrailer + outpay + outheader + outtrailer + errors +) + func checkIncomingHeaderStats(t *testing.T, d *gotData, e *expectedData) { var ( ok bool @@ -301,22 +313,21 @@ func checkIncomingHeaderStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st.IsClient = true, want false") - } - if st.Method != e.method { - t.Fatalf("st.Method = %s, want %v", st.Method, e.method) - } - if st.LocalAddr.String() != e.localAddr { - t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.localAddr) - } - if st.Encryption != e.encryption { - t.Fatalf("st.Encryption = %v, want %v", st.Encryption, e.encryption) - } // TODO check real length, not just > 0. if st.WireLength <= 0 { t.Fatalf("st.Lenght = 0, want > 0") } + if !d.client { + if st.Method != e.method { + t.Fatalf("st.Method = %s, want %v", st.Method, e.method) + } + if st.LocalAddr.String() != e.serverAddr { + t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) + } + if st.Encryption != e.encryption { + t.Fatalf("st.Encryption = %v, want %v", st.Encryption, e.encryption) + } + } } func checkIncomingPayloadStats(t *testing.T, d *gotData, e *expectedData) { @@ -330,22 +341,36 @@ func checkIncomingPayloadStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st IsClient = true, want false") - } - b, err := proto.Marshal(e.incoming[e.expectedInIdx]) - if err != nil { - t.Fatalf("failed to marshal message: %v", err) - } - if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.incoming[e.expectedInIdx]) { - t.Fatalf("st.Payload = %T, want %T", st.Payload, e.incoming[e.expectedInIdx]) - } - e.expectedInIdx++ - if string(st.Data) != string(b) { - t.Fatalf("st.Data = %v, want %v", st.Data, b) - } - if st.Length != len(b) { - t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + if d.client { + b, err := proto.Marshal(e.responses[e.respIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) + } + e.respIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + } + } else { + b, err := proto.Marshal(e.requests[e.reqIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) + } + e.reqIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + } } // TODO check WireLength and ReceivedTime. if st.ReceivedTime.IsZero() { @@ -364,9 +389,6 @@ func checkIncomingTrailerStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st.IsClient = true, want false") - } // TODO check real length, not just > 0. if st.WireLength <= 0 { t.Fatalf("st.Lenght = 0, want > 0") @@ -384,13 +406,21 @@ func checkOutgoingHeaderStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st IsClient = true, want false") - } // TODO check real length, not just > 0. if st.WireLength <= 0 { t.Fatalf("st.Lenght = 0, want > 0") } + if d.client { + if st.Method != e.method { + t.Fatalf("st.Method = %s, want %v", st.Method, e.method) + } + if st.RemoteAddr.String() != e.serverAddr { + t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) + } + if st.Encryption != e.encryption { + t.Fatalf("st.Encryption = %v, want %v", st.Encryption, e.encryption) + } + } } func checkOutgoingPayloadStats(t *testing.T, d *gotData, e *expectedData) { @@ -404,22 +434,36 @@ func checkOutgoingPayloadStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st IsClient = true, want false") - } - b, err := proto.Marshal(e.outgoing[e.expectedOutIdx]) - if err != nil { - t.Fatalf("failed to marshal message: %v", err) - } - if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.outgoing[e.expectedOutIdx]) { - t.Fatalf("st.Payload = %T, want %T", st.Payload, e.outgoing[e.expectedOutIdx]) - } - e.expectedOutIdx++ - if string(st.Data) != string(b) { - t.Fatalf("st.Data = %v, want %v", st.Data, b) - } - if st.Length != len(b) { - t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + if d.client { + b, err := proto.Marshal(e.requests[e.reqIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) + } + e.reqIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + } + } else { + b, err := proto.Marshal(e.responses[e.respIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) + } + e.respIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) + } } // TODO check WireLength and ReceivedTime. if st.SentTime.IsZero() { @@ -458,9 +502,6 @@ func checkErrorStats(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if st.IsClient { - t.Fatalf("st IsClient = true, want false") - } if grpc.Code(st.Error) != grpc.Code(e.err) || grpc.ErrorDesc(st.Error) != grpc.ErrorDesc(e.err) { t.Fatalf("st.Error = %v, want %v", st.Error, e.err) } @@ -474,7 +515,9 @@ func TestServerStatsUnaryRPC(t *testing.T) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { mu.Lock() defer mu.Unlock() - got = append(got, &gotData{ctx, s}) + if !s.ClientStats() { + got = append(got, &gotData{ctx, false, s}) + } }) te := newTest(t, "") @@ -488,10 +531,10 @@ func TestServerStatsUnaryRPC(t *testing.T) { te.srv.GracefulStop() // Wait for the server to stop. expect := &expectedData{ - method: "/grpc.testing.TestService/UnaryCall", - localAddr: te.srvAddr, - incoming: []*testpb.SimpleRequest{req}, - outgoing: []*testpb.SimpleResponse{resp}, + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ @@ -523,7 +566,9 @@ func TestServerStatsUnaryRPCError(t *testing.T) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { mu.Lock() defer mu.Unlock() - got = append(got, &gotData{ctx, s}) + if !s.ClientStats() { + got = append(got, &gotData{ctx, false, s}) + } }) te := newTest(t, "") @@ -537,11 +582,11 @@ func TestServerStatsUnaryRPCError(t *testing.T) { te.srv.GracefulStop() // Wait for the server to stop. expect := &expectedData{ - method: "/grpc.testing.TestService/UnaryCall", - localAddr: te.srvAddr, - incoming: []*testpb.SimpleRequest{req}, - outgoing: []*testpb.SimpleResponse{resp}, - err: err, + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + err: err, } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ @@ -573,7 +618,9 @@ func TestServerStatsStreamingRPC(t *testing.T) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { mu.Lock() defer mu.Unlock() - got = append(got, &gotData{ctx, s}) + if !s.ClientStats() { + got = append(got, &gotData{ctx, false, s}) + } }) te := newTest(t, "gzip") @@ -589,10 +636,10 @@ func TestServerStatsStreamingRPC(t *testing.T) { expect := &expectedData{ method: "/grpc.testing.TestService/FullDuplexCall", - localAddr: te.srvAddr, + serverAddr: te.srvAddr, encryption: "gzip", - incoming: reqs, - outgoing: resps, + requests: reqs, + responses: resps, } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ @@ -629,7 +676,9 @@ func TestServerStatsStreamingRPCError(t *testing.T) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { mu.Lock() defer mu.Unlock() - got = append(got, &gotData{ctx, s}) + if !s.ClientStats() { + got = append(got, &gotData{ctx, false, s}) + } }) te := newTest(t, "gzip") @@ -645,10 +694,10 @@ func TestServerStatsStreamingRPCError(t *testing.T) { expect := &expectedData{ method: "/grpc.testing.TestService/FullDuplexCall", - localAddr: te.srvAddr, + serverAddr: te.srvAddr, encryption: "gzip", - incoming: reqs, - outgoing: resps, + requests: reqs, + responses: resps, err: err, } @@ -672,3 +721,335 @@ func TestServerStatsStreamingRPCError(t *testing.T) { stats.Stop() } + +type checkFuncWithCount struct { + f func(t *testing.T, d *gotData, e *expectedData) + c int // expected count +} + +func TestClientStatsUnaryRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { + mu.Lock() + defer mu.Unlock() + if s.ClientStats() { + got = append(got, &gotData{ctx, true, s}) + } + }) + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + req, resp, err := te.doUnaryCall(true) + if err != nil { + t.Fatalf(err.Error()) + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + } + + checkFuncs := map[int]*checkFuncWithCount{ + outheader: &checkFuncWithCount{checkOutgoingHeaderStats, 1}, + outpay: &checkFuncWithCount{checkOutgoingPayloadStats, 1}, + inheader: &checkFuncWithCount{checkIncomingHeaderStats, 1}, + inpay: &checkFuncWithCount{checkIncomingPayloadStats, 1}, + intrailer: &checkFuncWithCount{checkIncomingTrailerStats, 1}, + } + + var expectLen int + for _, v := range checkFuncs { + expectLen += v.c + } + if len(got) != expectLen { + t.Fatalf("got %v stats, want %v stats", len(got), expectLen) + } + + for _, s := range got { + mu.Lock() + switch s.s.(type) { + case *stats.OutgoingHeaderStats: + if checkFuncs[outheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outheader].f(t, s, expect) + checkFuncs[outheader].c-- + case *stats.OutgoingPayloadStats: + if checkFuncs[outpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outpay].f(t, s, expect) + checkFuncs[outpay].c-- + case *stats.IncomingHeaderStats: + if checkFuncs[inheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inheader].f(t, s, expect) + checkFuncs[inheader].c-- + case *stats.IncomingPayloadStats: + if checkFuncs[inpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inpay].f(t, s, expect) + checkFuncs[inpay].c-- + case *stats.IncomingTrailerStats: + if checkFuncs[intrailer].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[intrailer].f(t, s, expect) + checkFuncs[intrailer].c-- + default: + t.Fatalf("unexpected stats: %T", s) + } + mu.Unlock() + } + + stats.Stop() +} + +func TestClientStatsUnaryRPCError(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { + mu.Lock() + defer mu.Unlock() + if s.ClientStats() { + got = append(got, &gotData{ctx, true, s}) + } + }) + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + req, resp, err := te.doUnaryCall(false) + if err == nil { + t.Fatalf("got error ; want ") + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + err: err, + } + + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkOutgoingHeaderStats, + checkOutgoingPayloadStats, + checkIncomingHeaderStats, + checkIncomingTrailerStats, + checkErrorStats, + } + + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + for i, f := range checkFuncs { + mu.Lock() + f(t, got[i], expect) + mu.Unlock() + } + + stats.Stop() +} + +func TestClientStatsStreamingRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { + mu.Lock() + defer mu.Unlock() + if s.ClientStats() { + got = append(got, &gotData{ctx, true, s}) + } + }) + + te := newTest(t, "gzip") + te.startServer(&testServer{}) + defer te.tearDown() + + count := 5 + reqs, resps, err := te.doFullDuplexCallRoundtrip(count, true) + if err == nil { + t.Fatalf(err.Error()) + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/FullDuplexCall", + serverAddr: te.srvAddr, + encryption: "gzip", + requests: reqs, + responses: resps, + } + + checkFuncs := map[int]*checkFuncWithCount{ + outheader: &checkFuncWithCount{checkOutgoingHeaderStats, 1}, + outpay: &checkFuncWithCount{checkOutgoingPayloadStats, count}, + inheader: &checkFuncWithCount{checkIncomingHeaderStats, 1}, + inpay: &checkFuncWithCount{checkIncomingPayloadStats, count}, + intrailer: &checkFuncWithCount{checkIncomingTrailerStats, 1}, + } + + var expectLen int + for _, v := range checkFuncs { + expectLen += v.c + } + if len(got) != expectLen { + t.Fatalf("got %v stats, want %v stats", len(got), expectLen) + } + + for _, s := range got { + mu.Lock() + switch s.s.(type) { + case *stats.OutgoingHeaderStats: + if checkFuncs[outheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outheader].f(t, s, expect) + checkFuncs[outheader].c-- + case *stats.OutgoingPayloadStats: + if checkFuncs[outpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outpay].f(t, s, expect) + checkFuncs[outpay].c-- + case *stats.IncomingHeaderStats: + if checkFuncs[inheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inheader].f(t, s, expect) + checkFuncs[inheader].c-- + case *stats.IncomingPayloadStats: + if checkFuncs[inpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inpay].f(t, s, expect) + checkFuncs[inpay].c-- + case *stats.IncomingTrailerStats: + if checkFuncs[intrailer].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[intrailer].f(t, s, expect) + checkFuncs[intrailer].c-- + default: + t.Fatalf("unexpected stats: %T", s) + } + mu.Unlock() + } + + stats.Stop() +} + +func TestClientStatsStreamingRPCError(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.Stats) { + mu.Lock() + defer mu.Unlock() + if s.ClientStats() { + got = append(got, &gotData{ctx, true, s}) + } + }) + + te := newTest(t, "gzip") + te.startServer(&testServer{}) + defer te.tearDown() + + count := 5 + reqs, resps, err := te.doFullDuplexCallRoundtrip(count, false) + if err == nil { + t.Fatalf("got error ; want ") + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/FullDuplexCall", + serverAddr: te.srvAddr, + encryption: "gzip", + requests: reqs, + responses: resps, + err: err, + } + + checkFuncs := map[int]*checkFuncWithCount{ + outheader: &checkFuncWithCount{checkOutgoingHeaderStats, 1}, + outpay: &checkFuncWithCount{checkOutgoingPayloadStats, 1}, + inheader: &checkFuncWithCount{checkIncomingHeaderStats, 1}, + intrailer: &checkFuncWithCount{checkIncomingTrailerStats, 1}, + errors: &checkFuncWithCount{checkErrorStats, 1}, + } + + var expectLen int + for _, v := range checkFuncs { + expectLen += v.c + } + if len(got) != expectLen { + t.Fatalf("got %v stats, want %v stats", len(got), expectLen) + } + + for _, s := range got { + mu.Lock() + switch s.s.(type) { + case *stats.OutgoingHeaderStats: + if checkFuncs[outheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outheader].f(t, s, expect) + checkFuncs[outheader].c-- + case *stats.OutgoingPayloadStats: + if checkFuncs[outpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outpay].f(t, s, expect) + checkFuncs[outpay].c-- + case *stats.IncomingHeaderStats: + if checkFuncs[inheader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inheader].f(t, s, expect) + checkFuncs[inheader].c-- + case *stats.IncomingPayloadStats: + if checkFuncs[inpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inpay].f(t, s, expect) + checkFuncs[inpay].c-- + case *stats.IncomingTrailerStats: + if checkFuncs[intrailer].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[intrailer].f(t, s, expect) + checkFuncs[intrailer].c-- + case *stats.ErrorStats: + if checkFuncs[errors].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[errors].f(t, s, expect) + checkFuncs[errors].c-- + default: + t.Fatalf("unexpected stats: %T", s) + } + mu.Unlock() + } + + stats.Stop() +} diff --git a/stream.go b/stream.go index ffcc1ca5..b1fc9634 100644 --- a/stream.go +++ b/stream.go @@ -98,7 +98,16 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. -func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { + defer func() { + if err != nil && stats.On() { + errorStats := &stats.ErrorStats{ + IsClient: true, + Error: err, + } + stats.Handle(ctx, errorStats) + } + }() if cc.dopts.streamInt != nil { return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...) } @@ -253,7 +262,16 @@ func (cs *clientStream) Context() context.Context { return cs.s.Context() } -func (cs *clientStream) Header() (metadata.MD, error) { +func (cs *clientStream) Header() (_ metadata.MD, err error) { + defer func() { + if err != nil && stats.On() { + errorStats := &stats.ErrorStats{ + IsClient: true, + Error: err, + } + stats.Handle(cs.s.Context(), errorStats) + } + }() m, err := cs.s.Header() if err != nil { if _, ok := err.(transport.ConnectionError); !ok { @@ -275,6 +293,15 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } cs.mu.Unlock() } + defer func() { + if err != nil && stats.On() { + errorStats := &stats.ErrorStats{ + IsClient: true, + Error: err, + } + stats.Handle(cs.s.Context(), errorStats) + } + }() defer func() { if err != nil { cs.finish(err) @@ -297,7 +324,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } err = toRPCErr(err) }() - out, err := encode(cs.codec, m, cs.cp, cs.cbuf, nil) + var outgoingPayloadStats *stats.OutgoingPayloadStats + if stats.On() { + outgoingPayloadStats = &stats.OutgoingPayloadStats{ + IsClient: true, + } + } + out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outgoingPayloadStats) defer func() { if cs.cbuf != nil { cs.cbuf.Reset() @@ -306,11 +339,31 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } - return cs.t.Write(cs.s, out, &transport.Options{Last: false}) + err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) + if outgoingPayloadStats != nil { + outgoingPayloadStats.SentTime = time.Now() + stats.Handle(cs.s.Context(), outgoingPayloadStats) + } + return err } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil) + defer func() { + if err != nil && err != io.EOF && stats.On() { + errorStats := &stats.ErrorStats{ + IsClient: true, + Error: err, + } + stats.Handle(cs.s.Context(), errorStats) + } + }() + var incomingPayloadStats *stats.IncomingPayloadStats + if stats.On() { + incomingPayloadStats = &stats.IncomingPayloadStats{ + IsClient: true, + } + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, incomingPayloadStats) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -325,10 +378,14 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } cs.mu.Unlock() } + if incomingPayloadStats != nil { + stats.Handle(cs.s.Context(), incomingPayloadStats) + } if !cs.desc.ClientStreams || cs.desc.ServerStreams { return } // Special handling for client streaming rpc. + // This recv expects EOF or errors, so we don't collect incomingPayloadStats. err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil) cs.closeTransportStream(err) if err == nil { diff --git a/transport/http2_client.go b/transport/http2_client.go index 2b0f6801..c663bdbe 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -51,6 +51,7 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" ) // http2Client implements the ClientTransport interface with HTTP2. @@ -413,6 +414,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } } first := true + bufLen := t.hBuf.Len() // Sends the headers in a single batch even when they span multiple frames. for !endHeaders { size := t.hBuf.Len() @@ -447,6 +449,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, connectionErrorf(true, err, "transport: %v", err) } } + if stats.On() { + outgoingHeaderStats := &stats.OutgoingHeaderStats{ + IsClient: true, + WireLength: bufLen, + Method: callHdr.Method, + RemoteAddr: t.RemoteAddr(), + LocalAddr: t.LocalAddr(), + Encryption: callHdr.SendCompress, + } + stats.Handle(s.Context(), outgoingHeaderStats) + } t.writableChan <- 0 return s, nil } @@ -874,6 +887,24 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } endStream := frame.StreamEnded() + var isHeader bool + defer func() { + if stats.On() { + if isHeader { + incomingHeaderStats := &stats.IncomingHeaderStats{ + IsClient: true, + WireLength: int(frame.Header().Length), + } + stats.Handle(s.ctx, incomingHeaderStats) + } else { + incomingTrailerStats := &stats.IncomingTrailerStats{ + IsClient: true, + WireLength: int(frame.Header().Length), + } + stats.Handle(s.ctx, incomingTrailerStats) + } + } + }() s.mu.Lock() if !endStream { @@ -885,6 +916,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } close(s.headerChan) s.headerDone = true + isHeader = true } if !endStream || s.state == streamDone { s.mu.Unlock() @@ -1070,3 +1102,11 @@ func (t *http2Client) notifyError(err error) { } t.mu.Unlock() } + +func (t *http2Client) LocalAddr() net.Addr { + return t.conn.LocalAddr() +} + +func (t *http2Client) RemoteAddr() net.Addr { + return t.conn.RemoteAddr() +} diff --git a/transport/transport.go b/transport/transport.go index 7dd02a02..a7824965 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -465,6 +465,11 @@ type ClientTransport interface { // receives the draining signal from the server (e.g., GOAWAY frame in // HTTP/2). GoAway() <-chan struct{} + + // LocalAddr returns the local network address. + LocalAddr() net.Addr + // RemoteAddr returns the remote network address. + RemoteAddr() net.Addr } // ServerTransport is the common interface for all gRPC server-side transport