diff --git a/picker_test.go b/picker_test.go index efe0c23f..b2030023 100644 --- a/picker_test.go +++ b/picker_test.go @@ -36,6 +36,7 @@ package grpc import ( "fmt" "math" + "sync" "testing" "time" @@ -45,11 +46,12 @@ import ( type testWatcher struct { // the channel to receives name resolution updates - update chan *naming.Update + update chan *naming.Update // the side channel to get to know how many updates in a batch - side chan int + side chan int // the channel to notifiy update injector that the update reading is done readDone chan int + wg *sync.WaitGroup } func (w *testWatcher) Next() (updates []*naming.Update, err error) { @@ -81,6 +83,7 @@ func (w *testWatcher) inject(updates []*naming.Update) { type testNameResolver struct { w *testWatcher addr string + wg *sync.WaitGroup } func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { @@ -88,6 +91,7 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { update: make(chan *naming.Update, 1), side: make(chan int, 1), readDone: make(chan int), + wg: r.wg, } r.w.side <- 1 r.w.update <- &naming.Update{ @@ -96,11 +100,14 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { } go func() { <-r.w.readDone + if r.w.wg != nil { + r.w.wg.Done() + } }() return r.w, nil } -func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*server, *testNameResolver) { +func startServers(t *testing.T, numServers, port int, maxStreams uint32, wg *sync.WaitGroup) ([]*server, *testNameResolver) { var servers []*server for i := 0; i < numServers; i++ { s := &server{readyChan: make(chan bool)} @@ -110,12 +117,15 @@ func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*ser } // Point to server1 addr := "127.0.0.1:" + servers[0].port - return servers, &testNameResolver{addr: addr} + return servers, &testNameResolver{ + addr: addr, + wg: wg, + } } func TestNameDiscovery(t *testing.T) { // Start 3 servers on 3 ports. - servers, r := startServers(t, 3, 0, math.MaxUint32) + servers, r := startServers(t, 3, 0, math.MaxUint32, nil) cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) @@ -156,7 +166,9 @@ func TestNameDiscovery(t *testing.T) { } func TestEmptyAddrs(t *testing.T) { - servers, r := startServers(t, 1, 0, math.MaxUint32) + var wg sync.WaitGroup + servers, r := startServers(t, 1, 0, math.MaxUint32, &wg) + wg.Add(1) cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) @@ -172,6 +184,8 @@ func TestEmptyAddrs(t *testing.T) { Op: naming.Delete, Addr: "127.0.0.1:" + servers[0].port, }) + // Wait until the first reading is done. + wg.Wait() r.w.inject(updates) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err == nil {