diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index 2d5a8222..e9f29090 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -33,6 +33,8 @@ import ( xdsinternal "google.golang.org/grpc/xds/internal" xdsbalancer "google.golang.org/grpc/xds/internal/balancer" xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/fakexds" ) const ( @@ -45,67 +47,6 @@ type testClientConn struct { balancer.ClientConn } -// testXDSClient is fake implementation of the xdsClientInterface. It contains -// a bunch of channels to signal different events to the test. -type testXDSClient struct { - // watchCb is the watch API callback registered by the cdsBalancer. Used to - // pass different CDS updates to the balancer, from the test. - watchCb func(xdsclient.CDSUpdate, error) - // clusterCh is a channel used to signal the cluster name for which the - // watch API call was invoked on this client. - clusterCh chan string - // cancelWatchCh is a channel used to signal the cancellation of the - // registered watch API. - cancelWatchCh chan struct{} -} - -func newTestXDSClient() *testXDSClient { - return &testXDSClient{ - clusterCh: make(chan string, 1), - cancelWatchCh: make(chan struct{}, 1), - } -} -func (tc *testXDSClient) WatchCluster(clusterName string, callback func(xdsclient.CDSUpdate, error)) func() { - tc.watchCb = callback - tc.clusterCh <- clusterName - return tc.cancelWatch -} - -func (tc *testXDSClient) Close() {} - -func (tc *testXDSClient) cancelWatch() { - tc.cancelWatchCh <- struct{}{} -} - -// waitForWatch verifies if the testXDSClient receives a CDS watch API with the -// provided clusterName within a reasonable amount of time. -func (tc *testXDSClient) waitForWatch(wantCluster string) error { - timer := time.NewTimer(defaultTestTimeout) - select { - case <-timer.C: - return errors.New("Timeout when expecting CDS watch call") - case gotCluster := <-tc.clusterCh: - timer.Stop() - if gotCluster != wantCluster { - return fmt.Errorf("WatchCluster called with clusterName: %s, want %s", gotCluster, wantCluster) - } - return nil - } -} - -// waitForCancelWatch verifies if the CDS watch API is cancelled within a -// reasonable amount of time. -func (tc *testXDSClient) waitForCancelWatch() error { - timer := time.NewTimer(defaultTestTimeout) - select { - case <-timer.C: - return errors.New("Timeout when expecting CDS watch call to be cancelled") - case <-tc.cancelWatchCh: - timer.Stop() - return nil - } -} - // cdsWatchInfo wraps the update and the error sent in a CDS watch callback. type cdsWatchInfo struct { update xdsclient.CDSUpdate @@ -114,8 +55,8 @@ type cdsWatchInfo struct { // invokeWatchCb invokes the CDS watch callback registered by the cdsBalancer // and waits for appropriate state to be pushed to the provided edsBalancer. -func (tc *testXDSClient) invokeWatchCb(cdsW cdsWatchInfo, wantCCS balancer.ClientConnState, edsB *testEDSBalancer) error { - tc.watchCb(cdsW.update, cdsW.err) +func invokeWatchCbAndWait(xdsC *fakexds.Client, cdsW cdsWatchInfo, wantCCS balancer.ClientConnState, edsB *testEDSBalancer) error { + xdsC.InvokeWatchClusterCallback(cdsW.update, cdsW.err) if cdsW.err != nil { return edsB.waitForResolverError(cdsW.err) } @@ -283,16 +224,20 @@ func setup() (*cdsBalancer, *testEDSBalancer, func()) { // setupWithWatch does everything that setup does, and also pushes a ClientConn // update to the cdsBalancer and waits for a CDS watch call to be registered. -func setupWithWatch(t *testing.T) (*testXDSClient, *cdsBalancer, *testEDSBalancer, func()) { +func setupWithWatch(t *testing.T) (*fakexds.Client, *cdsBalancer, *testEDSBalancer, func()) { t.Helper() - xdsC := newTestXDSClient() + xdsC := fakexds.NewClient() cdsB, edsB, cancel := setup() if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } - if err := xdsC.waitForWatch(clusterName); err != nil { - t.Fatal(err) + gotCluster, err := xdsC.WaitForWatchCluster() + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != clusterName { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, clusterName) } return xdsC, cdsB, edsB, cancel } @@ -301,7 +246,7 @@ func setupWithWatch(t *testing.T) (*testXDSClient, *cdsBalancer, *testEDSBalance // cdsBalancer with different inputs and verifies that the CDS watch API on the // provided xdsClient is invoked appropriately. func TestUpdateClientConnState(t *testing.T) { - xdsC := newTestXDSClient() + xdsC := fakexds.NewClient() tests := []struct { name string @@ -361,8 +306,12 @@ func TestUpdateClientConnState(t *testing.T) { // When we wanted an error and got it, we should return early. return } - if err := xdsC.waitForWatch(test.wantCluster); err != nil { - t.Fatal(err) + gotCluster, err := xdsC.WaitForWatchCluster() + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != test.wantCluster { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, test.wantCluster) } }) } @@ -375,7 +324,7 @@ func TestUpdateClientConnStateAfterClose(t *testing.T) { defer cancel() cdsB.Close() - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, newTestXDSClient())); err != errBalancerClosed { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, fakexds.NewClient())); err != errBalancerClosed { t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed) } } @@ -393,8 +342,8 @@ func TestUpdateClientConnStateWithSameState(t *testing.T) { if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } - if err := xdsC.waitForWatch(clusterName); err == nil { - t.Fatal("Waiting for WatchCluster() should have timed out, but returned with nil error") + if _, err := xdsC.WaitForWatchCluster(); err != testutils.ErrRecvTimeout { + t.Fatalf("waiting for WatchCluster() should have timed out, but returned error: %v", err) } } @@ -432,7 +381,7 @@ func TestHandleClusterUpdate(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if err := xdsC.invokeWatchCb(cdsWatchInfo{test.cdsUpdate, test.updateErr}, test.wantCCS, edsB); err != nil { + if err := invokeWatchCbAndWait(xdsC, cdsWatchInfo{test.cdsUpdate, test.updateErr}, test.wantCCS, edsB); err != nil { t.Fatal(err) } }) @@ -451,13 +400,13 @@ func TestResolverError(t *testing.T) { cdsUpdate := xdsclient.CDSUpdate{ServiceName: serviceName} wantCCS := edsCCS(serviceName, false, xdsC) - if err := xdsC.invokeWatchCb(cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { + if err := invokeWatchCbAndWait(xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } rErr := errors.New("cdsBalancer resolver error") cdsB.ResolverError(rErr) - if err := xdsC.waitForCancelWatch(); err != nil { + if err := xdsC.WaitForCancelClusterWatch(); err != nil { t.Fatal(err) } if err := edsB.waitForResolverError(rErr); err != nil { @@ -476,7 +425,7 @@ func TestUpdateSubConnState(t *testing.T) { cdsUpdate := xdsclient.CDSUpdate{ServiceName: serviceName} wantCCS := edsCCS(serviceName, false, xdsC) - if err := xdsC.invokeWatchCb(cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { + if err := invokeWatchCbAndWait(xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -496,12 +445,12 @@ func TestClose(t *testing.T) { cdsUpdate := xdsclient.CDSUpdate{ServiceName: serviceName} wantCCS := edsCCS(serviceName, false, xdsC) - if err := xdsC.invokeWatchCb(cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { + if err := invokeWatchCbAndWait(xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } cdsB.Close() - if err := xdsC.waitForCancelWatch(); err != nil { + if err := xdsC.WaitForCancelClusterWatch(); err != nil { t.Fatal(err) } if err := edsB.waitForClose(); err != nil { diff --git a/xds/internal/testutils/fakexds/client.go b/xds/internal/testutils/fakexds/client.go index 8775dcdb..d9c66a57 100644 --- a/xds/internal/testutils/fakexds/client.go +++ b/xds/internal/testutils/fakexds/client.go @@ -33,14 +33,17 @@ import ( type Client struct { name string suWatchCh *testutils.Channel + cdsWatchCh *testutils.Channel edsWatchCh *testutils.Channel suCancelCh *testutils.Channel + cdsCancelCh *testutils.Channel edsCancelCh *testutils.Channel loadReportCh *testutils.Channel closeCh *testutils.Channel mu sync.Mutex serviceCb func(xdsclient.ServiceUpdate, error) + cdsCb func(xdsclient.CDSUpdate, error) edsCb func(*xdsclient.EDSUpdate, error) } @@ -60,6 +63,9 @@ func (xdsC *Client) WatchService(target string, callback func(xdsclient.ServiceU // within a reasonable timeout, and returns the serviceName being watched. func (xdsC *Client) WaitForWatchService() (string, error) { val, err := xdsC.suWatchCh.Receive() + if err != nil { + return "", err + } return val.(string), err } @@ -71,6 +77,43 @@ func (xdsC *Client) InvokeWatchServiceCallback(cluster string, err error) { xdsC.serviceCb(xdsclient.ServiceUpdate{Cluster: cluster}, err) } +// WatchCluster registers a CDS watch. +func (xdsC *Client) WatchCluster(clusterName string, callback func(xdsclient.CDSUpdate, error)) func() { + xdsC.mu.Lock() + defer xdsC.mu.Unlock() + + xdsC.cdsCb = callback + xdsC.cdsWatchCh.Send(clusterName) + return func() { + xdsC.cdsCancelCh.Send(nil) + } +} + +// WaitForWatchCluster waits for WatchCluster to be invoked on this client +// within a reasonable timeout, and returns the clusterName being watched. +func (xdsC *Client) WaitForWatchCluster() (string, error) { + val, err := xdsC.cdsWatchCh.Receive() + if err != nil { + return "", err + } + return val.(string), err +} + +// InvokeWatchClusterCallback invokes the registered cdsWatch callback. +func (xdsC *Client) InvokeWatchClusterCallback(update xdsclient.CDSUpdate, err error) { + xdsC.mu.Lock() + defer xdsC.mu.Unlock() + + xdsC.cdsCb(update, err) +} + +// WaitForCancelClusterWatch waits for a CDS watch to be cancelled within a +// reasonable timeout, and returns testutils.ErrRecvTimeout otherwise. +func (xdsC *Client) WaitForCancelClusterWatch() error { + _, err := xdsC.cdsCancelCh.Receive() + return err +} + // WatchEDS registers an EDS watch for provided clusterName. func (xdsC *Client) WatchEDS(clusterName string, callback func(*xdsclient.EDSUpdate, error)) (cancel func()) { xdsC.mu.Lock() @@ -87,6 +130,9 @@ func (xdsC *Client) WatchEDS(clusterName string, callback func(*xdsclient.EDSUpd // reasonable timeout, and returns the clusterName being watched. func (xdsC *Client) WaitForWatchEDS() (string, error) { val, err := xdsC.edsWatchCh.Receive() + if err != nil { + return "", err + } return val.(string), err } @@ -148,8 +194,10 @@ func NewClientWithName(name string) *Client { return &Client{ name: name, suWatchCh: testutils.NewChannel(), + cdsWatchCh: testutils.NewChannel(), edsWatchCh: testutils.NewChannel(), suCancelCh: testutils.NewChannel(), + cdsCancelCh: testutils.NewChannel(), edsCancelCh: testutils.NewChannel(), loadReportCh: testutils.NewChannel(), closeCh: testutils.NewChannel(),