clientconn: do not automatically reconnect addrConns; go idle instead (#4613)

This commit is contained in:
Doug Fawley
2021-08-10 13:22:34 -07:00
committed by GitHub
parent 01bababd83
commit 997ce619eb
9 changed files with 254 additions and 178 deletions

View File

@ -353,8 +353,9 @@ var ErrBadResolverState = errors.New("bad resolver state")
// //
// It's not thread safe. // It's not thread safe.
type ConnectivityStateEvaluator struct { type ConnectivityStateEvaluator struct {
numReady uint64 // Number of addrConns in ready state. numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state. numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transient failure state.
} }
// RecordTransition records state change happening in subConn and based on that // RecordTransition records state change happening in subConn and based on that
@ -362,9 +363,10 @@ type ConnectivityStateEvaluator struct {
// //
// - If at least one SubConn in Ready, the aggregated state is Ready; // - If at least one SubConn in Ready, the aggregated state is Ready;
// - Else if at least one SubConn in Connecting, the aggregated state is Connecting; // - Else if at least one SubConn in Connecting, the aggregated state is Connecting;
// - Else the aggregated state is TransientFailure. // - Else if at least one SubConn is TransientFailure, the aggregated state is Transient Failure;
// - Else the aggregated state is Idle
// //
// Idle and Shutdown are not considered. // Shutdown is not considered.
func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State { func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State {
// Update counters. // Update counters.
for idx, state := range []connectivity.State{oldState, newState} { for idx, state := range []connectivity.State{oldState, newState} {
@ -374,6 +376,8 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne
cse.numReady += updateVal cse.numReady += updateVal
case connectivity.Connecting: case connectivity.Connecting:
cse.numConnecting += updateVal cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
} }
} }
@ -384,5 +388,8 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne
if cse.numConnecting > 0 { if cse.numConnecting > 0 {
return connectivity.Connecting return connectivity.Connecting
} }
return connectivity.TransientFailure if cse.numTransientFailure > 0 {
return connectivity.TransientFailure
}
return connectivity.Idle
} }

View File

@ -239,17 +239,17 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
return return
} }
ac, err := cc.newAddrConn(addrs, opts) newAC, err := cc.newAddrConn(addrs, opts)
if err != nil { if err != nil {
channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err)
return return
} }
acbw.ac = ac acbw.ac = newAC
ac.mu.Lock() newAC.mu.Lock()
ac.acbw = acbw newAC.acbw = acbw
ac.mu.Unlock() newAC.mu.Unlock()
if acState != connectivity.Idle { if acState != connectivity.Idle {
ac.connect() go newAC.connect()
} }
} }
} }
@ -257,7 +257,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
func (acbw *acBalancerWrapper) Connect() { func (acbw *acBalancerWrapper) Connect() {
acbw.mu.Lock() acbw.mu.Lock()
defer acbw.mu.Unlock() defer acbw.mu.Unlock()
acbw.ac.connect() go acbw.ac.connect()
} }
func (acbw *acBalancerWrapper) getAddrConn() *addrConn { func (acbw *acBalancerWrapper) getAddrConn() *addrConn {

View File

@ -322,6 +322,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
// A blocking dial blocks until the clientConn is ready. // A blocking dial blocks until the clientConn is ready.
if cc.dopts.block { if cc.dopts.block {
for { for {
cc.Connect()
s := cc.GetState() s := cc.GetState()
if s == connectivity.Ready { if s == connectivity.Ready {
break break
@ -539,12 +540,31 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec
// //
// Experimental // Experimental
// //
// Notice: This API is EXPERIMENTAL and may be changed or removed in a // Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// later release. // release.
func (cc *ClientConn) GetState() connectivity.State { func (cc *ClientConn) GetState() connectivity.State {
return cc.csMgr.getState() return cc.csMgr.getState()
} }
// Connect causes all subchannels in the ClientConn to attempt to connect if
// the channel is idle. Does not wait for the connection attempts to begin
// before returning.
//
// Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (cc *ClientConn) Connect() {
if cc.GetState() == connectivity.Idle {
cc.mu.Lock()
for ac := range cc.conns {
// TODO: should this be a signal to the LB policy instead?
go ac.connect()
}
cc.mu.Unlock()
}
}
func (cc *ClientConn) scWatcher() { func (cc *ClientConn) scWatcher() {
for { for {
select { select {
@ -845,8 +865,7 @@ func (ac *addrConn) connect() error {
ac.updateConnectivityState(connectivity.Connecting, nil) ac.updateConnectivityState(connectivity.Connecting, nil)
ac.mu.Unlock() ac.mu.Unlock()
// Start a goroutine connecting to the server asynchronously. ac.resetTransport()
go ac.resetTransport()
return nil return nil
} }
@ -883,6 +902,10 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
// ac.state is Ready, try to find the connected address. // ac.state is Ready, try to find the connected address.
var curAddrFound bool var curAddrFound bool
for _, a := range addrs { for _, a := range addrs {
// a.ServerName takes precedent over ClientConn authority, if present.
if a.ServerName == "" {
a.ServerName = ac.cc.authority
}
if reflect.DeepEqual(ac.curAddr, a) { if reflect.DeepEqual(ac.curAddr, a) {
curAddrFound = true curAddrFound = true
break break
@ -1135,112 +1158,86 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
} }
func (ac *addrConn) resetTransport() { func (ac *addrConn) resetTransport() {
for i := 0; ; i++ { ac.mu.Lock()
if i > 0 { if ac.state == connectivity.Shutdown {
ac.cc.resolveNow(resolver.ResolveNowOptions{})
}
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return
}
addrs := ac.addrs
backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx)
// This will be the duration that dial gets to finish.
dialDuration := minConnectTimeout
if ac.dopts.minConnectTimeout != nil {
dialDuration = ac.dopts.minConnectTimeout()
}
if dialDuration < backoffFor {
// Give dial more time as we keep failing to connect.
dialDuration = backoffFor
}
// We can potentially spend all the time trying the first address, and
// if the server accepts the connection and then hangs, the following
// addresses will never be tried.
//
// The spec doesn't mention what should be done for multiple addresses.
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm
connectDeadline := time.Now().Add(dialDuration)
ac.updateConnectivityState(connectivity.Connecting, nil)
ac.transport = nil
ac.mu.Unlock() ac.mu.Unlock()
return
newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline)
if err != nil {
// After exhausting all addresses, the addrConn enters
// TRANSIENT_FAILURE.
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return
}
ac.updateConnectivityState(connectivity.TransientFailure, err)
// Backoff.
b := ac.resetBackoff
ac.mu.Unlock()
timer := time.NewTimer(backoffFor)
select {
case <-timer.C:
ac.mu.Lock()
ac.backoffIdx++
ac.mu.Unlock()
case <-b:
timer.Stop()
case <-ac.ctx.Done():
timer.Stop()
return
}
continue
}
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
newTr.Close(fmt.Errorf("reached connectivity state: SHUTDOWN"))
return
}
ac.curAddr = addr
ac.transport = newTr
ac.backoffIdx = 0
hctx, hcancel := context.WithCancel(ac.ctx)
ac.startHealthCheck(hctx)
ac.mu.Unlock()
// Block until the created transport is down. And when this happens,
// we restart from the top of the addr list.
<-reconnect.Done()
hcancel()
// restart connecting - the top of the loop will set state to
// CONNECTING. This is against the current connectivity semantics doc,
// however it allows for graceful behavior for RPCs not yet dispatched
// - unfortunate timing would otherwise lead to the RPC failing even
// though the TRANSIENT_FAILURE state (called for by the doc) would be
// instantaneous.
//
// Ideally we should transition to Idle here and block until there is
// RPC activity that leads to the balancer requesting a reconnect of
// the associated SubConn.
} }
addrs := ac.addrs
backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx)
// This will be the duration that dial gets to finish.
dialDuration := minConnectTimeout
if ac.dopts.minConnectTimeout != nil {
dialDuration = ac.dopts.minConnectTimeout()
}
if dialDuration < backoffFor {
// Give dial more time as we keep failing to connect.
dialDuration = backoffFor
}
// We can potentially spend all the time trying the first address, and
// if the server accepts the connection and then hangs, the following
// addresses will never be tried.
//
// The spec doesn't mention what should be done for multiple addresses.
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm
connectDeadline := time.Now().Add(dialDuration)
ac.updateConnectivityState(connectivity.Connecting, nil)
ac.mu.Unlock()
if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil {
ac.cc.resolveNow(resolver.ResolveNowOptions{})
// After exhausting all addresses, the addrConn enters
// TRANSIENT_FAILURE.
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return
}
ac.updateConnectivityState(connectivity.TransientFailure, err)
// Backoff.
b := ac.resetBackoff
ac.mu.Unlock()
timer := time.NewTimer(backoffFor)
select {
case <-timer.C:
ac.mu.Lock()
ac.backoffIdx++
ac.mu.Unlock()
case <-b:
timer.Stop()
case <-ac.ctx.Done():
timer.Stop()
return
}
ac.mu.Lock()
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, err)
}
ac.mu.Unlock()
return
}
// Success; reset backoff.
ac.mu.Lock()
ac.backoffIdx = 0
ac.mu.Unlock()
} }
// tryAllAddrs tries to creates a connection to the addresses, and stop when at the // tryAllAddrs tries to creates a connection to the addresses, and stop when at
// first successful one. It returns the transport, the address and a Event in // the first successful one. It returns an error if no address was successfully
// the successful case. The Event fires when the returned transport disconnects. // connected, or updates ac appropriately with the new transport.
func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) (transport.ClientTransport, resolver.Address, *grpcsync.Event, error) { func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error {
var firstConnErr error var firstConnErr error
for _, addr := range addrs { for _, addr := range addrs {
ac.mu.Lock() ac.mu.Lock()
if ac.state == connectivity.Shutdown { if ac.state == connectivity.Shutdown {
ac.mu.Unlock() ac.mu.Unlock()
return nil, resolver.Address{}, nil, errConnClosing return errConnClosing
} }
ac.cc.mu.RLock() ac.cc.mu.RLock()
@ -1255,9 +1252,9 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr)
newTr, reconnect, err := ac.createTransport(addr, copts, connectDeadline) err := ac.createTransport(addr, copts, connectDeadline)
if err == nil { if err == nil {
return newTr, addr, reconnect, nil return nil
} }
if firstConnErr == nil { if firstConnErr == nil {
firstConnErr = err firstConnErr = err
@ -1266,57 +1263,49 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
} }
// Couldn't connect to any address. // Couldn't connect to any address.
return nil, resolver.Address{}, nil, firstConnErr return firstConnErr
} }
// createTransport creates a connection to addr. It returns the transport and a // createTransport creates a connection to addr. It returns an error if the
// Event in the successful case. The Event fires when the returned transport // address was not successfully connected, or updates ac appropriately with the
// disconnects. // new transport.
func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) (transport.ClientTransport, *grpcsync.Event, error) { func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
prefaceReceived := make(chan struct{}) // TODO: Delete prefaceReceived and move the logic to wait for it into the
onCloseCalled := make(chan struct{}) // transport.
reconnect := grpcsync.NewEvent() prefaceReceived := grpcsync.NewEvent()
connClosed := grpcsync.NewEvent()
// addr.ServerName takes precedent over ClientConn authority, if present. // addr.ServerName takes precedent over ClientConn authority, if present.
if addr.ServerName == "" { if addr.ServerName == "" {
addr.ServerName = ac.cc.authority addr.ServerName = ac.cc.authority
} }
once := sync.Once{} hctx, hcancel := context.WithCancel(ac.ctx)
onGoAway := func(r transport.GoAwayReason) { hcStarted := false // protected by ac.mu
ac.mu.Lock()
ac.adjustParams(r)
once.Do(func() {
if ac.state == connectivity.Ready {
// Prevent this SubConn from being used for new RPCs by setting its
// state to Connecting.
//
// TODO: this should be Idle when grpc-go properly supports it.
ac.updateConnectivityState(connectivity.Connecting, nil)
}
})
ac.mu.Unlock()
reconnect.Fire()
}
onClose := func() { onClose := func() {
ac.mu.Lock() ac.mu.Lock()
once.Do(func() { defer ac.mu.Unlock()
if ac.state == connectivity.Ready { defer connClosed.Fire()
// Prevent this SubConn from being used for new RPCs by setting its if !hcStarted {
// state to Connecting. // We didn't start the health check or set the state to READY, so
// // no need to do anything else here.
// TODO: this should be Idle when grpc-go properly supports it. return
ac.updateConnectivityState(connectivity.Connecting, nil) }
} hcancel()
}) ac.transport = nil
ac.mu.Unlock() // Refresh the name resolver
close(onCloseCalled) ac.cc.resolveNow(resolver.ResolveNowOptions{})
reconnect.Fire() if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
} }
onPrefaceReceipt := func() { onGoAway := func(r transport.GoAwayReason) {
close(prefaceReceived) ac.mu.Lock()
ac.adjustParams(r)
ac.mu.Unlock()
onClose()
} }
connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
@ -1325,27 +1314,47 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
copts.ChannelzParentID = ac.channelzID copts.ChannelzParentID = ac.channelzID
} }
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onPrefaceReceipt, onGoAway, onClose) newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, func() { prefaceReceived.Fire() }, onGoAway, onClose)
if err != nil { if err != nil {
// newTr is either nil, or closed. // newTr is either nil, or closed.
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v. Reconnecting...", addr, err) channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v", addr, err)
return nil, nil, err return err
} }
select { select {
case <-time.After(time.Until(connectDeadline)): case <-time.After(time.Until(connectDeadline)):
// We didn't get the preface in time. // We didn't get the preface in time.
newTr.Close(fmt.Errorf("failed to receive server preface within timeout")) err := fmt.Errorf("failed to receive server preface within timeout")
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr) newTr.Close(err)
return nil, nil, errors.New("timed out waiting for server handshake") channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: %v", addr, err)
case <-prefaceReceived: return err
case <-prefaceReceived.Done():
// We got the preface - huzzah! things are good. // We got the preface - huzzah! things are good.
case <-onCloseCalled: ac.mu.Lock()
// The transport has already closed - noop. defer ac.mu.Unlock()
return nil, nil, errors.New("connection closed") defer prefaceReceived.Fire()
// TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix. if connClosed.HasFired() {
// onClose called first; go idle but do nothing else.
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
return nil
}
ac.curAddr = addr
ac.transport = newTr
hcStarted = true
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
return nil
case <-connClosed.Done():
// The transport has already closed. If we received the preface, too,
// this is not an error.
select {
case <-prefaceReceived.Done():
return nil
default:
return errors.New("connection closed before server preface received")
}
} }
return newTr, reconnect, nil
} }
// startHealthCheck starts the health checking stream (RPC) to watch the health // startHealthCheck starts the health checking stream (RPC) to watch the health

View File

@ -75,7 +75,7 @@ func (s) TestStateTransitions_SingleAddress(t *testing.T) {
}, },
}, },
{ {
desc: "When the connection is closed, the client enters TRANSIENT FAILURE.", desc: "When the connection is closed before the preface is sent, the client enters TRANSIENT FAILURE.",
want: []connectivity.State{ want: []connectivity.State{
connectivity.Connecting, connectivity.Connecting,
connectivity.TransientFailure, connectivity.TransientFailure,
@ -167,6 +167,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()
@ -193,11 +194,12 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
} }
} }
// When a READY connection is closed, the client enters CONNECTING. // When a READY connection is closed, the client enters IDLE then CONNECTING.
func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) {
want := []connectivity.State{ want := []connectivity.State{
connectivity.Connecting, connectivity.Connecting,
connectivity.Ready, connectivity.Ready,
connectivity.Idle,
connectivity.Connecting, connectivity.Connecting,
} }
@ -240,6 +242,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()
@ -359,6 +362,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
want := []connectivity.State{ want := []connectivity.State{
connectivity.Connecting, connectivity.Connecting,
connectivity.Ready, connectivity.Ready,
connectivity.Idle,
connectivity.Connecting, connectivity.Connecting,
} }
@ -415,6 +419,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()

View File

@ -217,7 +217,7 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) {
client.Close() client.Close()
t.Fatalf("Unexpected success (err=nil) while dialing") t.Fatalf("Unexpected success (err=nil) while dialing")
} }
expectedMsg := "server handshake" expectedMsg := "server preface"
if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) || !strings.Contains(err.Error(), expectedMsg) { if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) || !strings.Contains(err.Error(), expectedMsg) {
t.Fatalf("DialContext(_) = %v; want a message that includes both %q and %q", err, context.DeadlineExceeded.Error(), expectedMsg) t.Fatalf("DialContext(_) = %v; want a message that includes both %q and %q", err, context.DeadlineExceeded.Error(), expectedMsg)
} }
@ -289,6 +289,9 @@ func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Error while dialing. Err: %v", err) t.Fatalf("Error while dialing. Err: %v", err)
} }
go stayConnected(client)
// wait for connection to be accepted on the server. // wait for connection to be accepted on the server.
timer := time.NewTimer(time.Second * 10) timer := time.NewTimer(time.Second * 10)
select { select {
@ -311,9 +314,7 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
defer lis.Close() defer lis.Close()
done := make(chan struct{}) done := make(chan struct{})
go func() { // Launch the server. go func() { // Launch the server.
defer func() { defer close(done)
close(done)
}()
conn, err := lis.Accept() // Accept the connection only to close it immediately. conn, err := lis.Accept() // Accept the connection only to close it immediately.
if err != nil { if err != nil {
t.Errorf("Error while accepting. Err: %v", err) t.Errorf("Error while accepting. Err: %v", err)
@ -340,13 +341,13 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
prevAt = meow prevAt = meow
} }
}() }()
client, err := Dial(lis.Addr().String(), WithInsecure()) cc, err := Dial(lis.Addr().String(), WithInsecure())
if err != nil { if err != nil {
t.Fatalf("Error while dialing. Err: %v", err) t.Fatalf("Error while dialing. Err: %v", err)
} }
defer client.Close() defer cc.Close()
go stayConnected(cc)
<-done <-done
} }
func (s) TestWithTimeout(t *testing.T) { func (s) TestWithTimeout(t *testing.T) {
@ -831,6 +832,7 @@ func (s) TestResetConnectBackoff(t *testing.T) {
t.Fatalf("Dial() = _, %v; want _, nil", err) t.Fatalf("Dial() = _, %v; want _, nil", err)
} }
defer cc.Close() defer cc.Close()
go stayConnected(cc)
select { select {
case <-dials: case <-dials:
case <-time.NewTimer(10 * time.Second).C: case <-time.NewTimer(10 * time.Second).C:
@ -985,6 +987,7 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
timeout := time.After(5 * time.Second) timeout := time.After(5 * time.Second)
@ -1112,3 +1115,14 @@ func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T
t.Fatal("default service config failed to be applied after 1s") t.Fatal("default service config failed to be applied after 1s")
} }
} }
// stayConnected makes cc stay connected by repeatedly calling cc.Connect()
// until the state becomes Shutdown or until 10 seconds elapses.
func stayConnected(cc *ClientConn) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for state := cc.GetState(); state != connectivity.Shutdown && cc.WaitForStateChange(ctx, state); state = cc.GetState() {
cc.Connect()
}
}

View File

@ -107,10 +107,12 @@ func (b *pickfirstBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.S
} }
switch s.ConnectivityState { switch s.ConnectivityState {
case connectivity.Ready, connectivity.Idle: case connectivity.Ready:
b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{result: balancer.PickResult{SubConn: sc}}}) b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{result: balancer.PickResult{SubConn: sc}}})
case connectivity.Connecting: case connectivity.Connecting:
b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}}) b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}})
case connectivity.Idle:
b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &idlePicker{sc: sc}})
case connectivity.TransientFailure: case connectivity.TransientFailure:
b.cc.UpdateState(balancer.State{ b.cc.UpdateState(balancer.State{
ConnectivityState: s.ConnectivityState, ConnectivityState: s.ConnectivityState,
@ -131,6 +133,17 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
return p.result, p.err return p.result, p.err
} }
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
sc balancer.SubConn
}
func (i *idlePicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
i.sc.Connect()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
func init() { func init() {
balancer.Register(newPickfirstBuilder()) balancer.Register(newPickfirstBuilder())
} }

View File

@ -1689,8 +1689,22 @@ func (s) TestCZSubChannelPickedNewAddress(t *testing.T) {
} }
te.srvs[0].Stop() te.srvs[0].Stop()
te.srvs[1].Stop() te.srvs[1].Stop()
// Here, we just wait for all sockets to be up. In the future, if we implement // Here, we just wait for all sockets to be up. Make several rpc calls to
// IDLE, we may need to make several rpc calls to create the sockets. // create the sockets since we do not automatically reconnect.
done := make(chan struct{})
defer close(done)
go func() {
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
tc.EmptyCall(ctx, &testpb.Empty{})
cancel()
select {
case <-time.After(10 * time.Millisecond):
case <-done:
return
}
}
}()
if err := verifyResultWithDelay(func() (bool, error) { if err := verifyResultWithDelay(func() (bool, error) {
tcs, _ := channelz.GetTopChannels(0, 0) tcs, _ := channelz.GetTopChannels(0, 0)
if len(tcs) != 1 { if len(tcs) != 1 {

View File

@ -165,6 +165,7 @@ func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty"}) te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty"})
te.userAgent = testAppUA te.userAgent = testAppUA
te.nonBlockingDial = true
te.startServer(&testServer{security: te.e.security}) te.startServer(&testServer{security: te.e.security})
defer te.tearDown() defer te.tearDown()

View File

@ -7123,7 +7123,20 @@ func (s) TestGoAwayThenClose(t *testing.T) {
// Send GO_AWAY to connection 1. // Send GO_AWAY to connection 1.
go s1.GracefulStop() go s1.GracefulStop()
// Wait for connection 2 to be established. // Wait for the ClientConn to enter IDLE state.
state := cc.GetState()
for ; state != connectivity.Idle && cc.WaitForStateChange(ctx, state); state = cc.GetState() {
}
if state != connectivity.Idle {
t.Fatalf("timed out waiting for IDLE channel state; last state = %v", state)
}
// Initiate another RPC to create another connection.
if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
}
// Assert that connection 2 has been established.
<-conn2Established.Done() <-conn2Established.Done()
// Close connection 1. // Close connection 1.