diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go index ee379a0e..932fdf06 100644 --- a/grpclb/grpclb.go +++ b/grpclb/grpclb.go @@ -70,6 +70,7 @@ type addrInfo struct { type balancer struct { r naming.Resolver mu sync.Mutex + seq int // a sequence number to make sure addrCh does not get stale addresses. w naming.Watcher addrCh chan []grpc.Address rbs []remoteBalancerInfo @@ -84,12 +85,12 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo if err != nil { return err } - var bAddr remoteBalancerInfo b.mu.Lock() defer b.mu.Unlock() if b.done { return grpc.ErrClientConnClosing } + var bAddr remoteBalancerInfo if len(b.rbs) > 0 { bAddr = b.rbs[0] } @@ -102,7 +103,7 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo case naming.Add: var exist bool for _, v := range b.rbs { - // TODO: Is the same addr with different different server name a different balancer? + // TODO: Is the same addr with different server name a different balancer? if addr == v.addr { exist = true break @@ -139,7 +140,7 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo return nil } -func (b *balancer) processServerList(l *lbpb.ServerList) { +func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { servers := l.GetServers() var ( sl []addrInfo @@ -159,7 +160,7 @@ func (b *balancer) processServerList(l *lbpb.ServerList) { } b.mu.Lock() defer b.mu.Unlock() - if b.done { + if b.done || seq < b.seq { return } if len(sl) > 0 { @@ -172,12 +173,6 @@ func (b *balancer) processServerList(l *lbpb.ServerList) { } func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) { - b.mu.Lock() - if b.done { - b.mu.Unlock() - return - } - b.mu.Unlock() ctx, cancel := context.WithCancel(context.Background()) defer cancel() stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false)) @@ -185,6 +180,14 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) return } + b.mu.Lock() + if b.done { + b.mu.Unlock() + return + } + b.seq++ + seq := b.seq + b.mu.Unlock() initReq := &lbpb.LoadBalanceRequest{ LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ InitialRequest: new(lbpb.InitialLoadBalanceRequest), @@ -217,7 +220,7 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) break } if serverList := reply.GetServerList(); serverList != nil { - b.processServerList(serverList) + b.processServerList(serverList, seq) } } return true @@ -307,6 +310,9 @@ func (b *balancer) down(addr grpc.Address, err error) { func (b *balancer) Up(addr grpc.Address) func(error) { b.mu.Lock() defer b.mu.Unlock() + if b.done { + return nil + } var cnt int for _, a := range b.addrs { if a.addr == addr { diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index 8e142605..fabb9fec 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -1,6 +1,6 @@ /* * - * Copyright 2014, Google Inc. + * Copyright 2016, Google Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -57,12 +57,15 @@ type testWatcher struct { } func (w *testWatcher) Next() (updates []*naming.Update, err error) { - n := <-w.side - if n == 0 { + n, ok := <-w.side + if !ok { return nil, fmt.Errorf("w.side is closed") } for i := 0; i < n; i++ { - u := <-w.update + u, ok := <-w.update + if !ok { + break + } if u != nil { updates = append(updates, u) } @@ -158,11 +161,11 @@ func stopBackends(servers []*grpc.Server) { } } -func TestGrpcLB(t *testing.T) { +func TestGRPCLB(t *testing.T) { // Start a backend. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatalf("fadjf") + t.Fatalf("Failed to listen %v", err) } backends := startBackends(beLis) defer stopBackends(backends)