diff --git a/internal/balancer/gracefulswitch/gracefulswitch.go b/internal/balancer/gracefulswitch/gracefulswitch.go index 7ba8f4d1..08666f62 100644 --- a/internal/balancer/gracefulswitch/gracefulswitch.go +++ b/internal/balancer/gracefulswitch/gracefulswitch.go @@ -193,6 +193,8 @@ func (gsb *Balancer) ExitIdle() { ei.ExitIdle() return } + gsb.mu.Lock() + defer gsb.mu.Unlock() for sc := range balToUpdate.subconns { sc.Connect() } diff --git a/internal/balancer/gracefulswitch/gracefulswitch_test.go b/internal/balancer/gracefulswitch/gracefulswitch_test.go index 02018f06..265e1f78 100644 --- a/internal/balancer/gracefulswitch/gracefulswitch_test.go +++ b/internal/balancer/gracefulswitch/gracefulswitch_test.go @@ -826,6 +826,40 @@ func (s) TestInlineCallbackInBuild(t *testing.T) { } } +// TestExitIdle tests the ExitIdle operation on the Graceful Switch Balancer for +// both possible codepaths, one where the child implements ExitIdler interface +// and one where the child doesn't implement ExitIdler interface. +func (s) TestExitIdle(t *testing.T) { + _, gsb := setup(t) + // switch to a balancer that implements ExitIdle{} (will populate current). + gsb.SwitchTo(mockBalancerBuilder1{}) + currBal := gsb.balancerCurrent.Balancer.(*mockBalancer) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // exitIdle on the Graceful Switch Balancer should get forwarded to the + // current child as it implements exitIdle. + gsb.ExitIdle() + if err := currBal.waitForExitIdle(ctx); err != nil { + t.Fatal(err) + } + + // switch to a balancer that doesn't implement ExitIdle{} (will populate + // pending). + gsb.SwitchTo(verifyBalancerBuilder{}) + // call exitIdle concurrently with newSubConn to make sure there is not a + // data race. + done := make(chan struct{}) + go func() { + gsb.ExitIdle() + close(done) + }() + pendBal := gsb.balancerPending.Balancer.(*verifyBalancer) + for i := 0; i < 10; i++ { + pendBal.newSubConn([]resolver.Address{}, balancer.NewSubConnOptions{}) + } + <-done +} + const balancerName1 = "mock_balancer_1" const balancerName2 = "mock_balancer_2" const verifyBalName = "verifyNoSubConnUpdateAfterCloseBalancer" @@ -839,6 +873,7 @@ func (mockBalancerBuilder1) Build(cc balancer.ClientConn, opts balancer.BuildOpt scStateCh: testutils.NewChannel(), resolverErrCh: testutils.NewChannel(), closeCh: testutils.NewChannel(), + exitIdleCh: testutils.NewChannel(), cc: cc, } } @@ -863,6 +898,8 @@ type mockBalancer struct { resolverErrCh *testutils.Channel // closeCh is a channel used to signal the closing of this balancer. closeCh *testutils.Channel + // exitIdleCh is a channel used to signal the receipt of an ExitIdle call. + exitIdleCh *testutils.Channel // Hold onto ClientConn wrapper to communicate with it cc balancer.ClientConn } @@ -890,6 +927,10 @@ func (mb1 *mockBalancer) Close() { mb1.closeCh.Send(struct{}{}) } +func (mb1 *mockBalancer) ExitIdle() { + mb1.exitIdleCh.Send(struct{}{}) +} + // waitForClientConnUpdate verifies if the mockBalancer receives the // provided ClientConnState within a reasonable amount of time. func (mb1 *mockBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS balancer.ClientConnState) error { @@ -940,6 +981,15 @@ func (mb1 *mockBalancer) waitForClose(ctx context.Context) error { return nil } +// waitForExitIdle verifies that ExitIdle gets called on the mockBalancer before +// the context expires. +func (mb1 *mockBalancer) waitForExitIdle(ctx context.Context) error { + if _, err := mb1.exitIdleCh.Receive(ctx); err != nil { + return fmt.Errorf("error waiting for ExitIdle(): %v", err) + } + return nil +} + func (mb1 *mockBalancer) updateState(state balancer.State) { mb1.cc.UpdateState(state) }