diff --git a/clientconn.go b/clientconn.go index 42221882..040206a1 100644 --- a/clientconn.go +++ b/clientconn.go @@ -35,6 +35,7 @@ package grpc import ( "errors" + "fmt" "net" "strings" "sync" @@ -182,6 +183,23 @@ const ( Shutdown ) +func (s ConnectivityState) String() string { + switch s { + case Idle: + return "IDLE" + case Connecting: + return "CONNECTING" + case Ready: + return "READY" + case TransientFailure: + return "TRANSIENT_FAILURE" + case Shutdown: + return "SHUTDOWN" + default: + panic(fmt.Sprintf("unknown connectivity state: %d", s)) + } +} + // ClientConn represents a client connection to an RPC service. type ClientConn struct { target string @@ -212,27 +230,31 @@ func (cc *ClientConn) State() ConnectivityState { // WaitForStateChange returns true when the state changes to something other than the // sourceState and false if timeout fires. func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool { + start := time.Now() cc.mu.Lock() defer cc.mu.Unlock() if sourceState != cc.state { return true } - // Shutdown state is a sink -- once it is entered, no furhter state change could happen. - if sourceState == Shutdown { - return false - } done := make(chan struct{}) + expired := timeout <= time.Since(start) go func() { select { - case <-time.After(timeout): + case <-time.After(timeout-time.Since(start)): + cc.mu.Lock() + expired = true cc.stateCV.Broadcast() + cc.mu.Unlock() case <-done: } }() + defer close(done) for sourceState == cc.state { cc.stateCV.Wait() + if expired { + return false + } } - close(done) return true } @@ -242,6 +264,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { for { cc.mu.Lock() cc.state = Connecting + cc.stateCV.Broadcast() t := cc.transport ts := cc.transportSeq // Avoid wait() picking up a dying transport unnecessarily. @@ -280,6 +303,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { if err != nil { cc.mu.Lock() cc.state = TransientFailure + cc.stateCV.Broadcast() cc.mu.Unlock() sleepTime -= time.Since(connectTime) if sleepTime < 0 { @@ -304,6 +328,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { return ErrClientConnClosing } cc.state = Ready + cc.stateCV.Broadcast() cc.transport = newTransport cc.transportSeq = ts + 1 if cc.ready != nil { @@ -327,6 +352,7 @@ func (cc *ClientConn) transportMonitor() { case <-cc.transport.Error(): cc.mu.Lock() cc.state = TransientFailure + cc.stateCV.Broadcast() cc.mu.Unlock() if err := cc.resetTransport(true); err != nil { // The ClientConn is closing. @@ -381,6 +407,7 @@ func (cc *ClientConn) Close() error { return ErrClientConnClosing } cc.state = Shutdown + cc.stateCV.Broadcast() if cc.ready != nil { close(cc.ready) cc.ready = nil diff --git a/credentials/credentials.go b/credentials/credentials.go index c1a331e8..0c2b24c0 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -51,7 +51,7 @@ import ( var ( // alpnProtoStr are the specified application level protocols for gRPC. - alpnProtoStr = []string{"h2"} + alpnProtoStr = []string{"h2", "h2-14", "h2-15", "h2-16"} ) // Credentials defines the common interface all supported credentials must diff --git a/test/end2end_test.go b/test/end2end_test.go index 9dc667c0..6462c080 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -354,6 +354,15 @@ func TestTimeoutOnDeadServer(t *testing.T) { func testTimeoutOnDeadServer(t *testing.T, e env) { s, cc := setUp(nil, math.MaxUint32, "", e) tc := testpb.NewTestServiceClient(cc) + if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok) + } + if ok := cc.WaitForStateChange(time.Second, grpc.Connecting); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok) + } + if cc.State() != grpc.Ready { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready) + } s.Stop() // Set -1 as the timeout to make sure if transportMonitor gets error // notification in time the failure path of the 1st invoke of @@ -362,6 +371,13 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(%v, _) = _, error %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) } + if ok := cc.WaitForStateChange(time.Second, grpc.Ready); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok) + } + state := cc.State() + if state != grpc.Connecting && state != grpc.TransientFailure { + t.Fatalf("cc.State() = %s, want %s or %s", state, grpc.Connecting, grpc.TransientFailure) + } cc.Close() } @@ -461,8 +477,20 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) { func testEmptyUnaryWithUserAgent(t *testing.T, e env) { s, cc := setUp(nil, math.MaxUint32, testAppUA, e) + // Wait until cc is connected. + if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok) + } + if ok := cc.WaitForStateChange(10 * time.Second, grpc.Connecting); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok) + } + if cc.State() != grpc.Ready { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready) + } + if ok := cc.WaitForStateChange(time.Second, grpc.Ready); ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want false", grpc.Ready, ok) + } tc := testpb.NewTestServiceClient(cc) - defer tearDown(s, cc) var header metadata.MD reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Header(&header)) if err != nil || !proto.Equal(&testpb.Empty{}, reply) { @@ -471,6 +499,13 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) { if v, ok := header["ua"]; !ok || v != testAppUA { t.Fatalf("header[\"ua\"] = %q, %t, want %q, true", v, ok, testAppUA) } + tearDown(s, cc) + if ok := cc.WaitForStateChange(5 * time.Second, grpc.Ready); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Ready, ok) + } + if cc.State() != grpc.Shutdown { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Shutdown) + } } func TestFailedEmptyUnary(t *testing.T) {