diff --git a/stats/stats.go b/stats/stats.go index 4a5327a4..1317f139 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -42,6 +42,7 @@ import ( "time" "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" ) // RPCStats contains stats information about RPCs. @@ -194,19 +195,25 @@ func Handle(ctx context.Context, s RPCStats) { handler(ctx, s) } -// RegisterHandler registers the user handler function and starts the stats collection. +// RegisterHandler registers the user handler function. +// If another handler was registered before, this new handler will overwrite the old one. // This handler function will be called to process the stats. func RegisterHandler(f func(context.Context, RPCStats)) { handler = f - start() } -// start starts the stats collection. -func start() { +// Start starts the stats collection. +// Stats will only be started if handler is not nil. +func Start() { + if handler == nil { + grpclog.Println("handler is nil when starting stats. Stats is not started") + return + } atomic.StoreInt32(on, 1) } // Stop stops the collection of any further stats. +// Stop won't unregister handler. func Stop() { atomic.StoreInt32(on, 0) } diff --git a/stats/stats_test.go b/stats/stats_test.go index ddf021e1..07a1e261 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -51,13 +51,21 @@ import ( func TestStartStop(t *testing.T) { stats.RegisterHandler(nil) - defer stats.Stop() // Stop stats in the case of the first Fatalf. + stats.Start() + if stats.On() != false { + t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false") + } + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {}) + if stats.On() != false { + t.Fatalf("after stats.RegisterHandler(), stats.On() = true, want false") + } + stats.Start() if stats.On() != true { - t.Fatalf("after start.RegisterCallBack(_), stats.On() = false, want true") + t.Fatalf("after stats.Start(_), stats.On() = false, want true") } stats.Stop() if stats.On() != false { - t.Fatalf("after start.Stop(), stats.On() = false, want true") + t.Fatalf("after stats.Stop(), stats.On() = true, want false") } } @@ -519,6 +527,7 @@ func TestServerStatsUnaryRPC(t *testing.T) { got = append(got, &gotData{ctx, false, s}) } }) + stats.Start() te := newTest(t, "") te.startServer(&testServer{}) @@ -570,6 +579,7 @@ func TestServerStatsUnaryRPCError(t *testing.T) { got = append(got, &gotData{ctx, false, s}) } }) + stats.Start() te := newTest(t, "") te.startServer(&testServer{}) @@ -622,6 +632,7 @@ func TestServerStatsStreamingRPC(t *testing.T) { got = append(got, &gotData{ctx, false, s}) } }) + stats.Start() te := newTest(t, "gzip") te.startServer(&testServer{}) @@ -680,6 +691,7 @@ func TestServerStatsStreamingRPCError(t *testing.T) { got = append(got, &gotData{ctx, false, s}) } }) + stats.Start() te := newTest(t, "gzip") te.startServer(&testServer{}) @@ -739,6 +751,7 @@ func TestClientStatsUnaryRPC(t *testing.T) { got = append(got, &gotData{ctx, true, s}) } }) + stats.Start() te := newTest(t, "") te.startServer(&testServer{}) @@ -827,6 +840,7 @@ func TestClientStatsUnaryRPCError(t *testing.T) { got = append(got, &gotData{ctx, true, s}) } }) + stats.Start() te := newTest(t, "") te.startServer(&testServer{}) @@ -879,6 +893,7 @@ func TestClientStatsStreamingRPC(t *testing.T) { got = append(got, &gotData{ctx, true, s}) } }) + stats.Start() te := newTest(t, "gzip") te.startServer(&testServer{}) @@ -969,6 +984,7 @@ func TestClientStatsStreamingRPCError(t *testing.T) { got = append(got, &gotData{ctx, true, s}) } }) + stats.Start() te := newTest(t, "gzip") te.startServer(&testServer{})