diff --git a/clientconn.go b/clientconn.go
index f21a5685..e1fd0898 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -312,8 +312,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
 	cc := &ClientConn{
 		target: target,
+		csMgr:  &connectivityStateManager{},
 		conns:  make(map[Address]*addrConn),
 	}
+	cc.csEvltr = &connectivityStateEvaluator{csMgr: cc.csMgr}
 	cc.ctx, cc.cancel = context.WithCancel(context.Background())
 
 	for _, opt := range opts {
@@ -476,6 +478,100 @@ func (s ConnectivityState) String() string {
 	}
 }
 
+// connectivityStateEvaluator gets updated by addrConns when their
+// states transition, based on which it evaluates the state of
+// ClientConn.
+// Note: This code will eventually sit in the balancer in the new design.
+type connectivityStateEvaluator struct {
+	csMgr               *connectivityStateManager
+	mu                  sync.Mutex
+	numReady            uint64 // Number of addrConns in ready state.
+	numConnecting       uint64 // Number of addrConns in connecting state.
+	numTransientFailure uint64 // Number of addrConns in transientFailure.
+}
+
+// recordTransition records state change happening in every addrConn and based on
+// that it evaluates what state the ClientConn is in.
+// It can only transition between Ready, Connecting and TransientFailure. Other states,
+// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection
+// before any addrConn is created ClientConn is in idle state. In the end when ClientConn
+// closes it is in Shutdown state.
+// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
+func (cse *connectivityStateEvaluator) recordTransition(oldState, newState ConnectivityState) {
+	cse.mu.Lock()
+	defer cse.mu.Unlock()
+
+	// Update counters.
+	for idx, state := range []ConnectivityState{oldState, newState} {
+		updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new (the -1 wraps to MaxUint64, so += decrements).
+		switch state {
+		case Ready:
+			cse.numReady += updateVal
+		case Connecting:
+			cse.numConnecting += updateVal
+		case TransientFailure:
+			cse.numTransientFailure += updateVal
+		}
+	}
+
+	// Evaluate.
+	if cse.numReady > 0 {
+		cse.csMgr.updateState(Ready)
+		return
+	}
+	if cse.numConnecting > 0 {
+		cse.csMgr.updateState(Connecting)
+		return
+	}
+	cse.csMgr.updateState(TransientFailure)
+}
+
+// connectivityStateManager keeps the ConnectivityState of ClientConn.
+// This struct will eventually be exported so the balancers can access it.
+type connectivityStateManager struct {
+	mu         sync.Mutex
+	state      ConnectivityState
+	notifyChan chan struct{}
+}
+
+// updateState updates the ConnectivityState of ClientConn.
+// If there's a change it notifies goroutines waiting on state change to
+// happen.
+func (csm *connectivityStateManager) updateState(state ConnectivityState) {
+	csm.mu.Lock()
+	defer csm.mu.Unlock()
+	if csm.state == Shutdown {
+		return
+	}
+	if csm.state == state {
+		return
+	}
+	csm.state = state
+	if csm.notifyChan != nil {
+		// There are other goroutines waiting on this channel.
+		close(csm.notifyChan)
+		csm.notifyChan = nil
+	}
+}
+
+// getState returns the currently recorded ConnectivityState, read under csm.mu.
+func (csm *connectivityStateManager) getState() ConnectivityState {
+	csm.mu.Lock()
+	defer csm.mu.Unlock()
+	return csm.state
+}
+
+// getNotifyChan returns the channel that updateState closes on the next state
+// change, creating it first if no channel is currently pending.
+func (csm *connectivityStateManager) getNotifyChan() <-chan struct{} {
+	csm.mu.Lock()
+	defer csm.mu.Unlock()
+	if csm.notifyChan == nil {
+		csm.notifyChan = make(chan struct{})
+	}
+	return csm.notifyChan
+}
+
 // ClientConn represents a client connection to an RPC server.
 type ClientConn struct {
 	ctx    context.Context
@@ -484,6 +580,8 @@ type ClientConn struct {
 	target    string
 	authority string
 	dopts     dialOptions
+	csMgr     *connectivityStateManager
+	csEvltr   *connectivityStateEvaluator // This will eventually be part of balancer.
 
 	mu    sync.RWMutex
 	sc    ServiceConfig
@@ -492,6 +590,26 @@ type ClientConn struct {
 	mkp keepalive.ClientParameters
 }
 
+// WaitForStateChange waits until the ConnectivityState of ClientConn changes from sourceState or
+// ctx expires. A true value is returned in former case and false in latter.
+func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) bool {
+	ch := cc.csMgr.getNotifyChan()
+	if cc.csMgr.getState() != sourceState {
+		return true
+	}
+	select {
+	case <-ctx.Done():
+		return false
+	case <-ch:
+		return true
+	}
+}
+
+// GetState returns the ConnectivityState of ClientConn.
+func (cc *ClientConn) GetState() ConnectivityState {
+	return cc.csMgr.getState()
+}
+
 // lbWatcher watches the Notify channel of the balancer in cc and manages
 // connections accordingly. If doneChan is not nil, it is closed after the
 // first successfull connection is made.
@@ -522,14 +640,18 @@ func (cc *ClientConn) lbWatcher(doneChan chan struct{}) {
 		}
 		cc.mu.Unlock()
 		for _, a := range add {
+			var err error
 			if doneChan != nil {
-				err := cc.resetAddrConn(a, true, nil)
+				err = cc.resetAddrConn(a, true, nil)
 				if err == nil {
 					close(doneChan)
 					doneChan = nil
 				}
 			} else {
-				cc.resetAddrConn(a, false, nil)
+				err = cc.resetAddrConn(a, false, nil)
+			}
+			if err != nil {
+				grpclog.Warningf("Error creating connection to %v. Err: %v", a, err)
 			}
 		}
 		for _, c := range del {
@@ -570,7 +692,7 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error)
 		dopts: cc.dopts,
 	}
 	ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
-	ac.stateCV = sync.NewCond(&ac.mu)
+	ac.csEvltr = cc.csEvltr
 	if EnableTracing {
 		ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
 	}
@@ -727,6 +849,7 @@ func (cc *ClientConn) Close() error {
 	}
 	conns := cc.conns
 	cc.conns = nil
+	cc.csMgr.updateState(Shutdown)
 	cc.mu.Unlock()
 	if cc.dopts.balancer != nil {
 		cc.dopts.balancer.Close()
@@ -747,10 +870,11 @@ type addrConn struct {
 	dopts  dialOptions
 	events trace.EventLog
 
-	mu      sync.Mutex
-	state   ConnectivityState
-	stateCV *sync.Cond
-	down    func(error) // the handler called when a connection is down.
+	csEvltr *connectivityStateEvaluator
+
+	mu    sync.Mutex
+	state ConnectivityState
+	down  func(error) // the handler called when a connection is down.
 	// ready is closed and becomes nil when a new transport is up or failed
 	// due to timeout.
 	ready     chan struct{}
@@ -790,42 +914,6 @@ func (ac *addrConn) errorf(format string, a ...interface{}) {
 	}
 }
 
-// getState returns the connectivity state of the Conn
-func (ac *addrConn) getState() ConnectivityState {
-	ac.mu.Lock()
-	defer ac.mu.Unlock()
-	return ac.state
-}
-
-// waitForStateChange blocks until the state changes to something other than the sourceState.
-func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
-	ac.mu.Lock()
-	defer ac.mu.Unlock()
-	if sourceState != ac.state {
-		return ac.state, nil
-	}
-	done := make(chan struct{})
-	var err error
-	go func() {
-		select {
-		case <-ctx.Done():
-			ac.mu.Lock()
-			err = ctx.Err()
-			ac.stateCV.Broadcast()
-			ac.mu.Unlock()
-		case <-done:
-		}
-	}()
-	defer close(done)
-	for sourceState == ac.state {
-		ac.stateCV.Wait()
-		if err != nil {
-			return ac.state, err
-		}
-	}
-	return ac.state, nil
-}
-
 // resetTransport recreates a transport to the address for ac.
 // For the old transport:
 // - if drain is true, it will be gracefully closed.
@@ -841,8 +929,9 @@ func (ac *addrConn) resetTransport(drain bool) error {
 		ac.down(downErrorf(false, true, "%v", errNetworkIO))
 		ac.down = nil
 	}
+	oldState := ac.state
 	ac.state = Connecting
-	ac.stateCV.Broadcast()
+	ac.csEvltr.recordTransition(oldState, ac.state)
 	t := ac.transport
 	ac.transport = nil
 	ac.mu.Unlock()
@@ -892,8 +981,9 @@ func (ac *addrConn) resetTransport(drain bool) error {
 				return errConnClosing
 			}
 			ac.errorf("transient failure: %v", err)
+			oldState = ac.state
 			ac.state = TransientFailure
-			ac.stateCV.Broadcast()
+			ac.csEvltr.recordTransition(oldState, ac.state)
 			if ac.ready != nil {
 				close(ac.ready)
 				ac.ready = nil
@@ -917,8 +1007,9 @@ func (ac *addrConn) resetTransport(drain bool) error {
 			newTransport.Close()
 			return errConnClosing
 		}
+		oldState = ac.state
 		ac.state = Ready
-		ac.stateCV.Broadcast()
+		ac.csEvltr.recordTransition(oldState, ac.state)
 		ac.transport = newTransport
 		if ac.ready != nil {
 			close(ac.ready)
@@ -993,8 +1084,9 @@ func (ac *addrConn) transportMonitor() {
 				ac.mu.Unlock()
 				return
 			}
+			oldState := ac.state
 			ac.state = TransientFailure
-			ac.stateCV.Broadcast()
+			ac.csEvltr.recordTransition(oldState, ac.state)
 			ac.mu.Unlock()
 			if err := ac.resetTransport(false); err != nil {
 				grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err)
@@ -1076,9 +1168,10 @@ func (ac *addrConn) tearDown(err error) {
 	if ac.state == Shutdown {
 		return
 	}
+	oldState := ac.state
 	ac.state = Shutdown
 	ac.tearDownErr = err
-	ac.stateCV.Broadcast()
+	ac.csEvltr.recordTransition(oldState, ac.state)
 	if ac.events != nil {
 		ac.events.Finish()
 		ac.events = nil
diff --git a/clientconn_test.go b/clientconn_test.go
index f72e496d..e351c3ba 100644
--- a/clientconn_test.go
+++ b/clientconn_test.go
@@ -19,6 +19,7 @@
 package grpc
 
 import (
+	"math"
 	"net"
 	"testing"
 	"time"
@@ -27,10 +28,61 @@ import (
 
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/keepalive"
+	"google.golang.org/grpc/naming"
 )
 
 const tlsDir = "testdata/"
 
+func assertState(wantState ConnectivityState, cc *ClientConn) (ConnectivityState, bool) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	var state ConnectivityState
+	for state = cc.GetState(); state != wantState && cc.WaitForStateChange(ctx, state); state = cc.GetState() {
+	}
+	return state, state == wantState
+}
+
+func TestConnectivityStates(t *testing.T) {
+	servers, resolver := startServers(t, 2, math.MaxUint32)
+	defer func() {
+		for i := 0; i < 2; i++ {
+			servers[i].stop()
+		}
+	}()
+
+	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure())
+	if err != nil {
+		t.Fatalf("Dial(\"foo.bar.com\", WithBalancer(_)) = _, %v, want _ ", err)
+	}
+	defer cc.Close()
+	wantState := Ready
+	if state, ok := assertState(wantState, cc); !ok {
+		t.Fatalf("assertState(%s) = %s, false, want %s, true", wantState, state, wantState)
+	}
+	// Send an update to delete the server connection (tearDown addrConn).
+	update := []*naming.Update{
+		{
+			Op:   naming.Delete,
+			Addr: "localhost:" + servers[0].port,
+		},
+	}
+	resolver.w.inject(update)
+	wantState = TransientFailure
+	if state, ok := assertState(wantState, cc); !ok {
+		t.Fatalf("assertState(%s) = %s, false, want %s, true", wantState, state, wantState)
+	}
+	update[0] = &naming.Update{
+		Op:   naming.Add,
+		Addr: "localhost:" + servers[1].port,
+	}
+	resolver.w.inject(update)
+	wantState = Ready
+	if state, ok := assertState(wantState, cc); !ok {
+		t.Fatalf("assertState(%s) = %s, false, want %s, true", wantState, state, wantState)
+	}
+
+}
+
 func TestDialTimeout(t *testing.T) {
 	conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure())
 	if err == nil {
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 2f45c9ca..f226bdcf 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -4778,6 +4778,39 @@ func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
 	}
 }
 
+func TestWaitForReadyConnection(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testWaitForReadyConnection(t, e)
+	}
+
+}
+
+func testWaitForReadyConnection(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.userAgent = testAppUA
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn() // Non-blocking dial.
+	tc := testpb.NewTestServiceClient(cc)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	state := cc.GetState()
+	// Wait for connection to be Ready.
+	for ; state != grpc.Ready && cc.WaitForStateChange(ctx, state); state = cc.GetState() {
+	}
+	if state != grpc.Ready {
+		t.Fatalf("Want connection state to be Ready, got %v", state)
+	}
+	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	// Make a fail-fast RPC.
+	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+		t.Fatalf("TestService/EmptyCall(_,_) = _, %v, want _, nil", err)
+	}
+}
+
 type errCodec struct {
 	noError bool
 }