diff --git a/stats/stats.go b/stats/stats.go
index 4a5327a4..1317f139 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -42,6 +42,7 @@ import (
 	"time"
 
 	"golang.org/x/net/context"
+	"google.golang.org/grpc/grpclog"
 )
 
 // RPCStats contains stats information about RPCs.
@@ -194,19 +195,25 @@ func Handle(ctx context.Context, s RPCStats) {
 	handler(ctx, s)
 }
 
-// RegisterHandler registers the user handler function and starts the stats collection.
+// RegisterHandler registers the user handler function.
+// If another handler was registered before, this new handler will overwrite the old one.
 // This handler function will be called to process the stats.
 func RegisterHandler(f func(context.Context, RPCStats)) {
 	handler = f
-	start()
 }
 
-// start starts the stats collection.
-func start() {
+// Start starts the stats collection.
+// Stats will only be started if handler is not nil.
+func Start() {
+	if handler == nil {
+		grpclog.Println("handler is nil when starting stats. Stats is not started")
+		return
+	}
 	atomic.StoreInt32(on, 1)
 }
 
 // Stop stops the collection of any further stats.
+// Stop won't unregister handler.
 func Stop() {
 	atomic.StoreInt32(on, 0)
 }
diff --git a/stats/stats_test.go b/stats/stats_test.go
index ddf021e1..07a1e261 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -51,13 +51,21 @@ import (
 
 func TestStartStop(t *testing.T) {
 	stats.RegisterHandler(nil)
-	defer stats.Stop() // Stop stats in the case of the first Fatalf.
+	stats.Start()
+	if stats.On() != false {
+		t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false")
+	}
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {})
+	if stats.On() != false {
+		t.Fatalf("after stats.RegisterHandler(), stats.On() = true, want false")
+	}
+	stats.Start()
 	if stats.On() != true {
-		t.Fatalf("after start.RegisterCallBack(_), stats.On() = false, want true")
+		t.Fatalf("after stats.Start(_), stats.On() = false, want true")
 	}
 	stats.Stop()
 	if stats.On() != false {
-		t.Fatalf("after start.Stop(), stats.On() = false, want true")
+		t.Fatalf("after stats.Stop(), stats.On() = true, want false")
 	}
 }
 
@@ -519,6 +527,7 @@ func TestServerStatsUnaryRPC(t *testing.T) {
 			got = append(got, &gotData{ctx, false, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "")
 	te.startServer(&testServer{})
@@ -570,6 +579,7 @@ func TestServerStatsUnaryRPCError(t *testing.T) {
 			got = append(got, &gotData{ctx, false, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "")
 	te.startServer(&testServer{})
@@ -622,6 +632,7 @@ func TestServerStatsStreamingRPC(t *testing.T) {
 			got = append(got, &gotData{ctx, false, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "gzip")
 	te.startServer(&testServer{})
@@ -680,6 +691,7 @@ func TestServerStatsStreamingRPCError(t *testing.T) {
 			got = append(got, &gotData{ctx, false, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "gzip")
 	te.startServer(&testServer{})
@@ -739,6 +751,7 @@ func TestClientStatsUnaryRPC(t *testing.T) {
 			got = append(got, &gotData{ctx, true, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "")
 	te.startServer(&testServer{})
@@ -827,6 +840,7 @@ func TestClientStatsUnaryRPCError(t *testing.T) {
 			got = append(got, &gotData{ctx, true, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "")
 	te.startServer(&testServer{})
@@ -879,6 +893,7 @@ func TestClientStatsStreamingRPC(t *testing.T) {
 			got = append(got, &gotData{ctx, true, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "gzip")
 	te.startServer(&testServer{})
@@ -969,6 +984,7 @@ func TestClientStatsStreamingRPCError(t *testing.T) {
 			got = append(got, &gotData{ctx, true, s})
 		}
 	})
+	stats.Start()
 
 	te := newTest(t, "gzip")
 	te.startServer(&testServer{})