diff --git a/balancer/grpclb/grpclb_picker.go b/balancer/grpclb/grpclb_picker.go
index efef9baa..d77af401 100644
--- a/balancer/grpclb/grpclb_picker.go
+++ b/balancer/grpclb/grpclb_picker.go
@@ -29,7 +29,10 @@ import (
 	"google.golang.org/grpc/status"
 )
 
+// rpcStats is same as lbmpb.ClientStats, except that numCallsDropped is a map
+// instead of a slice.
 type rpcStats struct {
+	// Only access the following fields atomically.
 	numCallsStarted                        int64
 	numCallsFinished                       int64
 	numCallsFinishedWithClientFailedToSend int64
diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go
index ee78c8ac..7e6ca76e 100644
--- a/balancer/grpclb/grpclb_test.go
+++ b/balancer/grpclb/grpclb_test.go
@@ -26,6 +26,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -121,35 +122,19 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
 	return net.DialTimeout("tcp", addr, timeout)
 }
 
-// rpcStatsForTest is same as lbpb.ClientStats, except that numCallsDropped is a map
-// instead of a slice of pointers.
+// merge merges the new client stats into current stats.
 //
-// TODO: this struct was already defined in grpclb_picker.go. Try to merge these
-// two after moving grpclb to its own package (this package).
-type rpcStatsForTest struct {
-	numCallsStarted                        int64
-	numCallsFinished                       int64
-	numCallsFinishedWithClientFailedToSend int64
-	numCallsFinishedKnownReceived          int64
-
-	// map load_balance_token -> num_calls_dropped
-	numCallsDropped map[string]int64
-}
-
-func newRPCStatsForTest() *rpcStatsForTest {
-	return &rpcStatsForTest{
-		numCallsDropped: make(map[string]int64),
-	}
-}
-
-func (stats *rpcStatsForTest) merge(new *lbpb.ClientStats) {
-	stats.numCallsStarted += new.NumCallsStarted
-	stats.numCallsFinished += new.NumCallsFinished
-	stats.numCallsFinishedWithClientFailedToSend += new.NumCallsFinishedWithClientFailedToSend
-	stats.numCallsFinishedKnownReceived += new.NumCallsFinishedKnownReceived
+// It's a test-only method. rpcStats is defined in grpclb_picker.
+func (s *rpcStats) merge(new *lbpb.ClientStats) {
+	atomic.AddInt64(&s.numCallsStarted, new.NumCallsStarted)
+	atomic.AddInt64(&s.numCallsFinished, new.NumCallsFinished)
+	atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, new.NumCallsFinishedWithClientFailedToSend)
+	atomic.AddInt64(&s.numCallsFinishedKnownReceived, new.NumCallsFinishedKnownReceived)
+	s.mu.Lock()
 	for _, perToken := range new.CallsFinishedWithDrop {
-		stats.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
+		s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
 	}
+	s.mu.Unlock()
 }
 
 func mapsEqual(a, b map[string]int64) bool {
@@ -164,20 +149,31 @@ func mapsEqual(a, b map[string]int64) bool {
 	return true
 }
 
-func (stats *rpcStatsForTest) equal(new *rpcStatsForTest) bool {
-	if stats.numCallsStarted != new.numCallsStarted {
+func atomicEqual(a, b *int64) bool {
+	return atomic.LoadInt64(a) == atomic.LoadInt64(b)
+}
+
+// equal compares two rpcStats.
+//
+// It's a test-only method. rpcStats is defined in grpclb_picker.
+func (s *rpcStats) equal(new *rpcStats) bool {
+	if !atomicEqual(&s.numCallsStarted, &new.numCallsStarted) {
 		return false
 	}
-	if stats.numCallsFinished != new.numCallsFinished {
+	if !atomicEqual(&s.numCallsFinished, &new.numCallsFinished) {
 		return false
 	}
-	if stats.numCallsFinishedWithClientFailedToSend != new.numCallsFinishedWithClientFailedToSend {
+	if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &new.numCallsFinishedWithClientFailedToSend) {
 		return false
 	}
-	if stats.numCallsFinishedKnownReceived != new.numCallsFinishedKnownReceived {
+	if !atomicEqual(&s.numCallsFinishedKnownReceived, &new.numCallsFinishedKnownReceived) {
 		return false
 	}
-	if !mapsEqual(stats.numCallsDropped, new.numCallsDropped) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	new.mu.Lock()
+	defer new.mu.Unlock()
+	if !mapsEqual(s.numCallsDropped, new.numCallsDropped) {
 		return false
 	}
 	return true
@@ -187,15 +183,14 @@ type remoteBalancer struct {
 	sls       chan *lbpb.ServerList
 	statsDura time.Duration
 	done      chan struct{}
-	mu        sync.Mutex
-	stats     *rpcStatsForTest
+	stats     *rpcStats
 }
 
 func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
 	return &remoteBalancer{
 		sls:   make(chan *lbpb.ServerList, 1),
 		done:  make(chan struct{}),
-		stats: newRPCStatsForTest(),
+		stats: newRPCStats(),
 	}
 }
 
@@ -235,9 +230,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe
 			if req, err = stream.Recv(); err != nil {
 				return
 			}
-			b.mu.Lock()
 			b.stats.merge(req.GetClientStats())
-			b.mu.Unlock()
 		}
 	}()
 	for v := range b.sls {
@@ -752,14 +745,14 @@ func (failPreRPCCred) RequireTransportSecurity() bool {
 	return false
 }
 
-func checkStats(stats, expected *rpcStatsForTest) error {
+func checkStats(stats, expected *rpcStats) error {
 	if !stats.equal(expected) {
 		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
 	}
 	return nil
 }
 
-func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStatsForTest {
+func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
 	defer leakcheck.Check(t)
 
 	r, cleanup := manual.GenerateAndRegisterManualResolver()
@@ -800,9 +793,7 @@ func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rp
 
 	runRPCs(cc)
 	time.Sleep(1 * time.Second)
-	tss.ls.mu.Lock()
 	stats := tss.ls.stats
-	tss.ls.mu.Unlock()
 	return stats
 }
 
@@ -825,7 +816,7 @@ func TestGRPCLBStatsUnarySuccess(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:               int64(countRPC),
 		numCallsFinished:              int64(countRPC),
 		numCallsFinishedKnownReceived: int64(countRPC),
@@ -852,7 +843,7 @@ func TestGRPCLBStatsUnaryDrop(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:                        int64(countRPC + c),
 		numCallsFinished:                       int64(countRPC + c),
 		numCallsFinishedWithClientFailedToSend: int64(c - 1),
@@ -875,7 +866,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:                        int64(countRPC),
 		numCallsFinished:                       int64(countRPC),
 		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
@@ -912,7 +903,7 @@ func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:               int64(countRPC),
 		numCallsFinished:              int64(countRPC),
 		numCallsFinishedKnownReceived: int64(countRPC),
@@ -939,7 +930,7 @@ func TestGRPCLBStatsStreamingDrop(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:                        int64(countRPC + c),
 		numCallsFinished:                       int64(countRPC + c),
 		numCallsFinishedWithClientFailedToSend: int64(c - 1),
@@ -968,7 +959,7 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
 		}
 	})
 
-	if err := checkStats(stats, &rpcStatsForTest{
+	if err := checkStats(stats, &rpcStats{
 		numCallsStarted:                        int64(countRPC),
 		numCallsFinished:                       int64(countRPC),
 		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),