From 1ca9df53a7c8333301d197cacb9ecfff4efbcc9d Mon Sep 17 00:00:00 2001
From: Jean de Klerk <deklerk@google.com>
Date: Fri, 12 Oct 2018 15:22:27 -0700
Subject: [PATCH] internal: clean up and unflake state transitions test (#2366)

internal: clean up and unflake state transitions test

Switches the state transitions test to use notifications from a custom load
balancer instead of waiting on connectivity state updates, which lag behind
the balancer.

Also adds more coverage around state transitions and a framework that makes
it easy to add more tests of this kind.
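
The core of the new approach, roughly: a balancer that wraps pick_first and
forwards every SubConn state change to a channel, plus a helper that asserts
the exact sequence of transitions. A minimal sketch of that pattern (the
package and helper names here are illustrative; the real implementation is
in the diff below):

    package statetest // illustrative; the real test lives in package grpc

    import (
    	"testing"
    	"time"

    	"google.golang.org/grpc/balancer"
    	"google.golang.org/grpc/connectivity"
    )

    // recordingBalancer wraps another balancer (pick_first in the real test)
    // and forwards every SubConn state change to a channel the test drains.
    type recordingBalancer struct {
    	notifier chan<- connectivity.State
    	balancer.Balancer
    }

    func (b *recordingBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
    	b.notifier <- s // record the transition for the test to assert on
    	b.Balancer.HandleSubConnStateChange(sc, s)
    }

    // assertTransitions fails the test if the observed sequence differs from
    // want, or if a transition does not arrive in time.
    func assertTransitions(t *testing.T, notifications <-chan connectivity.State, want []connectivity.State) {
    	timeout := time.After(5 * time.Second)
    	for i, w := range want {
    		select {
    		case got := <-notifications:
    			if got != w {
    				t.Fatalf("transition %d: got %v, want %v (full flow %v)", i, got, w, want)
    			}
    		case <-timeout:
    			t.Fatalf("timed out waiting for transition %d (%v)", i, w)
    		}
    	}
    }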

Fixes #2348
---
 clientconn_state_transition_test.go | 428 ++++++++++++++++++++++++++++
 clientconn_test.go                  | 147 ----------
 2 files changed, 428 insertions(+), 147 deletions(-)
 create mode 100644 clientconn_state_transition_test.go

diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go
new file mode 100644
index 00000000..e16d606f
--- /dev/null
+++ b/clientconn_state_transition_test.go
@@ -0,0 +1,428 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"golang.org/x/net/http2"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/internal/leakcheck"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+)
+
+const stateRecordingBalancerName = "state_recording_balancer"
+
+var testBalancer = &stateRecordingBalancer{}
+
+func init() {
+	balancer.Register(testBalancer)
+}
+
+func TestStateTransitions_SingleAddress(t *testing.T) {
+	for _, test := range []struct {
+		name   string
+		want   []connectivity.State
+		server func(net.Listener)
+	}{
+		// When the server sends its server preface, the client enters READY.
+		{
+			name: "ClientEntersReadyOnPrefaceReceipt",
+			want: []connectivity.State{
+				connectivity.Connecting,
+				connectivity.Ready,
+			},
+			server: func(lis net.Listener) {
+				conn, err := lis.Accept()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+
+				framer := http2.NewFramer(conn, conn)
+				if err := framer.WriteSettings(http2.Setting{}); err != nil {
+					t.Errorf("Error while writing settings frame. %v", err)
+					return
+				}
+			},
+		},
+		// When the connection is closed, the client enters TRANSIENT FAILURE.
+		{
+			name: "ClientEntersTransientFailureOnClose",
+			want: []connectivity.State{
+				connectivity.Connecting,
+				connectivity.TransientFailure,
+			},
+			server: func(lis net.Listener) {
+				conn, err := lis.Accept()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+
+				conn.Close()
+			},
+		},
+	} {
+		t.Logf("Test %s", test.name)
+		testStateTransitionSingleAddress(t, test.want, test.server)
+	}
+}
+
+func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener)) {
+	defer leakcheck.Check(t)
+
+	stateNotifications := make(chan connectivity.State, len(want))
+	testBalancer.ResetNotifier(stateNotifications)
+	defer close(stateNotifications)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis.Close()
+
+	// Launch the server.
+	go server(lis)
+
+	client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	timeout := time.After(5 * time.Second)
+
+	for i := 0; i < len(want); i++ {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
+		case seen := <-stateNotifications:
+			if seen != want[i] {
+				t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
+			}
+		}
+	}
+}
+
+// When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING.
+func TestStateTransition_ReadyToTransientFailure(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	want := []connectivity.State{
+		connectivity.Connecting,
+		connectivity.Ready,
+		connectivity.TransientFailure,
+		connectivity.Connecting,
+	}
+
+	stateNotifications := make(chan connectivity.State, len(want))
+	testBalancer.ResetNotifier(stateNotifications)
+	defer close(stateNotifications)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis.Close()
+
+	sawReady := make(chan struct{})
+
+	// Launch the server.
+	go func() {
+		conn, err := lis.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		framer := http2.NewFramer(conn, conn)
+		if err := framer.WriteSettings(http2.Setting{}); err != nil {
+			t.Errorf("Error while writing settings frame. %v", err)
+			return
+		}
+
+		// Prevents race between onPrefaceReceipt and onClose.
+		<-sawReady
+
+		conn.Close()
+	}()
+
+	client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	timeout := time.After(5 * time.Second)
+
+	for i := 0; i < len(want); i++ {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
+		case seen := <-stateNotifications:
+			if seen == connectivity.Ready {
+				close(sawReady)
+			}
+			if seen != want[i] {
+				t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
+			}
+		}
+	}
+}
+
+// When the first connection is closed, the client stays in CONNECTING until it tries the second
+// address, which succeeds, and then it enters READY.
+func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	want := []connectivity.State{
+		connectivity.Connecting,
+		connectivity.Ready,
+	}
+
+	stateNotifications := make(chan connectivity.State, len(want))
+	testBalancer.ResetNotifier(stateNotifications)
+	defer close(stateNotifications)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	lis1, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis1.Close()
+
+	lis2, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis2.Close()
+
+	server1Done := make(chan struct{})
+	server2Done := make(chan struct{})
+
+	// Launch server 1.
+	go func() {
+		conn, err := lis1.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		conn.Close()
+		close(server1Done)
+	}()
+	// Launch server 2.
+	go func() {
+		conn, err := lis2.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		framer := http2.NewFramer(conn, conn)
+		if err := framer.WriteSettings(http2.Setting{}); err != nil {
+			t.Errorf("Error while writing settings frame. %v", err)
+			return
+		}
+		close(server2Done)
+	}()
+
+	rb := manual.NewBuilderWithScheme("whatever")
+	rb.InitialAddrs([]resolver.Address{
+		{Addr: lis1.Addr().String()},
+		{Addr: lis2.Addr().String()},
+	})
+	client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	timeout := time.After(5 * time.Second)
+
+	for i := 0; i < len(want); i++ {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
+		case seen := <-stateNotifications:
+			if seen != want[i] {
+				t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
+			}
+		}
+	}
+	select {
+	case <-timeout:
+		t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
+	case <-server1Done:
+	}
+	select {
+	case <-timeout:
+		t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 2")
+	case <-server2Done:
+	}
+}
+
+// When there are multiple addresses, and we enter READY on one of them, a later closure should cause
+// the client to enter TRANSIENT FAILURE before it re-enters CONNECTING.
+func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	want := []connectivity.State{
+		connectivity.Connecting,
+		connectivity.Ready,
+		connectivity.TransientFailure,
+		connectivity.Connecting,
+	}
+
+	stateNotifications := make(chan connectivity.State, len(want))
+	testBalancer.ResetNotifier(stateNotifications)
+	defer close(stateNotifications)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	lis1, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis1.Close()
+
+	// Never actually gets used; we just want it to be alive so that the resolver has two addresses to target.
+	lis2, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error while listening. Err: %v", err)
+	}
+	defer lis2.Close()
+
+	server1Done := make(chan struct{})
+	sawReady := make(chan struct{})
+
+	// Launch server 1.
+	go func() {
+		conn, err := lis1.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		framer := http2.NewFramer(conn, conn)
+		if err := framer.WriteSettings(http2.Setting{}); err != nil {
+			t.Errorf("Error while writing settings frame. %v", err)
+			return
+		}
+
+		<-sawReady
+
+		conn.Close()
+
+		_, err = lis1.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		close(server1Done)
+	}()
+
+	rb := manual.NewBuilderWithScheme("whatever")
+	rb.InitialAddrs([]resolver.Address{
+		{Addr: lis1.Addr().String()},
+		{Addr: lis2.Addr().String()},
+	})
+	client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	timeout := time.After(5 * time.Second)
+
+	for i := 0; i < len(want); i++ {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
+		case seen := <-stateNotifications:
+			if seen == connectivity.Ready {
+				close(sawReady)
+			}
+			if seen != want[i] {
+				t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
+			}
+		}
+	}
+	select {
+	case <-timeout:
+		t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
+	case <-server1Done:
+	}
+}
+
+type stateRecordingBalancer struct {
+	mu       sync.Mutex
+	notifier chan<- connectivity.State
+
+	balancer.Balancer
+}
+
+func (b *stateRecordingBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
+	b.mu.Lock()
+	b.notifier <- s
+	b.mu.Unlock()
+
+	b.Balancer.HandleSubConnStateChange(sc, s)
+}
+
+func (b *stateRecordingBalancer) ResetNotifier(r chan<- connectivity.State) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	b.notifier = r
+}
+
+func (b *stateRecordingBalancer) Close() {
+	b.mu.Lock()
+	u := b.Balancer
+	b.mu.Unlock()
+	u.Close()
+}
+
+func (b *stateRecordingBalancer) Name() string {
+	return stateRecordingBalancerName
+}
+
+func (b *stateRecordingBalancer) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
+	b.mu.Lock()
+	b.Balancer = balancer.Get(PickFirstBalancerName).Build(cc, opts)
+	b.mu.Unlock()
+	return b
+}
diff --git a/clientconn_test.go b/clientconn_test.go
index 8ee4595b..c856897f 100644
--- a/clientconn_test.go
+++ b/clientconn_test.go
@@ -836,150 +836,3 @@ func TestBackoffCancel(t *testing.T) {
 	cc.Close()
 	// Should not leak. May need -count 5000 to exercise.
 }
-
-// CONNECTING -> READY -> TRANSIENT FAILURE -> CONNECTING -> TRANSIENT FAILURE -> CONNECTING -> TRANSIENT FAILURE
-//
-// Note: csmgr (which drives GetState and WaitForStateChange) lags behind reality a bit because state updates go
-// through the balancer. So, we are somewhat overaggressive in using WaitForStateChange in this test in order to force
-// it to keep up.
-//
-// TODO(deklerk) Rewrite this test with a custom balancer that records state transitions. This will mean we can get
-// rid of all this synchronization and realtime state-checking, and instead just check the state transitions at
-// once after all the activity has happened.
-func TestDialCloseStateTransition(t *testing.T) {
-	defer leakcheck.Check(t)
-
-	lis, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		t.Fatalf("Error while listening. Err: %v", err)
-	}
-	defer lis.Close()
-	testFinished := make(chan struct{})
-	backoffCaseReady := make(chan struct{})
-	killFirstConnection := make(chan struct{})
-	killSecondConnection := make(chan struct{})
-
-	// Launch the server.
-	go func() {
-		// Establish a successful connection so that we enter READY. We need to get
-		// to READY so that we can get a client back for us to introspect later (as
-		// opposed to just CONNECTING).
-		conn, err := lis.Accept()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		defer conn.Close()
-
-		framer := http2.NewFramer(conn, conn)
-		if err := framer.WriteSettings(http2.Setting{}); err != nil {
-			t.Errorf("Error while writing settings frame. %v", err)
-			return
-		}
-
-		select {
-		case <-testFinished:
-			return
-		case <-killFirstConnection:
-		}
-
-		// Close the conn to cause onShutdown, causing us to enter TRANSIENT FAILURE. Note that we are not in
-		// WaitForHandshake at this point because the preface was sent successfully.
-		conn.Close()
-
-		// We have to re-accept and re-close the connection because the first re-connect after a successful handshake
-		// has no backoff. So, we need to get to the second re-connect after the successful handshake for our infinite
-		// backoff to happen.
-		conn, err = lis.Accept()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-		err = conn.Close()
-		if err != nil {
-			t.Error(err)
-		}
-
-		// The client should now be headed towards backoff.
-		close(backoffCaseReady)
-
-		// Re-connect (without server preface).
-		conn, err = lis.Accept()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-
-		// Close the conn to cause onShutdown, causing us to enter TRANSIENT FAILURE. Note that we are in
-		// WaitForHandshake at this point because the preface has not been sent yet.
-		select {
-		case <-testFinished:
-			return
-		case <-killSecondConnection:
-		}
-		conn.Close()
-	}()
-
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
-	client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock(), withBackoff(backoffForever{}))
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer client.Close()
-
-	// It should start in READY because the server sends the server preface.
-	if got, want := client.GetState(), connectivity.Ready; got != want {
-		t.Fatalf("expected addrconn state to be %v, was %v", want, got)
-	}
-
-	// Once the connection is killed, it should go:
-	// READY -> TRANSIENT FAILURE (no backoff) -> CONNECTING -> TRANSIENT FAILURE (infinite backoff)
-	// The first TRANSIENT FAILURE is triggered by closing a channel. Then, we wait for the server to let us know
-	// when the client has progressed past the first failure (which does not get backoff, because handshake was
-	// successful).
-	close(killFirstConnection)
-	if !client.WaitForStateChange(ctx, connectivity.Ready) {
-		t.Fatal("expected WaitForStateChange to change state, but it timed out")
-	}
-	if !client.WaitForStateChange(ctx, connectivity.TransientFailure) {
-		t.Fatal("expected WaitForStateChange to change state, but it timed out")
-	}
-	<-backoffCaseReady
-	if !client.WaitForStateChange(ctx, connectivity.Connecting) {
-		t.Fatal("expected WaitForStateChange to change state, but it timed out")
-	}
-	if got, want := client.GetState(), connectivity.TransientFailure; got != want {
-		t.Fatalf("expected addrconn state to be %v, was %v", want, got)
-	}
-
-	// Stop backing off, allowing a re-connect. Note: this races with the client actually getting to the backoff,
-	// so continually reset backoff until we notice the state change.
-	for i := 0; i < 100; i++ {
-		client.ResetConnectBackoff()
-		cctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
-		defer cancel()
-		if client.WaitForStateChange(cctx, connectivity.TransientFailure) {
-			break
-		}
-	}
-	if got, want := client.GetState(), connectivity.Connecting; got != want {
-		t.Fatalf("expected addrconn state to be %v, was %v", want, got)
-	}
-
-	select {
-	case <-testFinished:
-	case killSecondConnection <- struct{}{}:
-	}
-
-	// The connection should be killed shortly by the above goroutine, and here we watch for the first new connectivity
-	// state and make sure it's TRANSIENT FAILURE. This is racy, but fairly accurate - expect it to catch failures
-	// 90% of the time or so.
-	if !client.WaitForStateChange(ctx, connectivity.Connecting) {
-		t.Fatal("expected WaitForStateChange to change state, but it timed out")
-	}
-	if got, want := client.GetState(), connectivity.TransientFailure; got != want {
-		t.Fatalf("expected addrconn state to be %v, was %v", want, got)
-	}
-}