diff --git a/clientconn.go b/clientconn.go index 2299ed99..8ddae3ad 100644 --- a/clientconn.go +++ b/clientconn.go @@ -536,10 +536,11 @@ func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivi // Caller needs to make sure len(addrs) > 0. func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { ac := &addrConn{ - cc: cc, - addrs: addrs, - dopts: cc.dopts, - czData: new(channelzData), + cc: cc, + addrs: addrs, + dopts: cc.dopts, + czData: new(channelzData), + resetBackoff: make(chan struct{}), } ac.ctx, ac.cancel = context.WithCancel(cc.ctx) // Track ac in cc. This needs to be done before any getTransport(...) is called. @@ -747,6 +748,24 @@ func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) { go r.resolveNow(o) } +// ResetConnectBackoff wakes up all subchannels in transient failure and causes +// them to attempt another connection immediately. It also resets the backoff +// times used for subsequent attempts regardless of the current state. +// +// In general, this function should not be used. Typical service or network +// outages result in a reasonable client reconnection strategy by default. +// However, if a previously unavailable network becomes available, this may be +// used to trigger an immediate reconnect. +// +// This API is EXPERIMENTAL. +func (cc *ClientConn) ResetConnectBackoff() { + cc.mu.Lock() + defer cc.mu.Unlock() + for ac := range cc.conns { + ac.resetConnectBackoff() + } +} + // Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() error { defer cc.cancel() @@ -815,6 +834,8 @@ type addrConn struct { // negotiations must complete. connectDeadline time.Time + resetBackoff chan struct{} + channelzID int64 // channelz unique identification number czData *channelzData } @@ -879,6 +900,7 @@ func (ac *addrConn) resetTransport() error { ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() var backoffDeadline, connectDeadline time.Time + var resetBackoff chan struct{} for connectRetryNum := 0; ; connectRetryNum++ { ac.mu.Lock() if ac.backoffDeadline.IsZero() { @@ -886,6 +908,7 @@ func (ac *addrConn) resetTransport() error { // or this is the first time this addrConn is trying to establish a // connection. backoffFor := ac.dopts.bs.Backoff(connectRetryNum) // time.Duration. + resetBackoff = ac.resetBackoff // This will be the duration that dial gets to finish. dialDuration := getMinConnectTimeout() if backoffFor > dialDuration { @@ -919,7 +942,7 @@ func (ac *addrConn) resetTransport() error { copy(addrsIter, ac.addrs) copts := ac.dopts.copts ac.mu.Unlock() - connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts) + connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts, resetBackoff) if err != nil { return err } @@ -931,7 +954,7 @@ func (ac *addrConn) resetTransport() error { // createTransport creates a connection to one of the backends in addrs. // It returns true if a connection was established. -func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions) (bool, error) { +func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions, resetBackoff chan struct{}) (bool, error) { for i := ridx; i < len(addrs); i++ { addr := addrs[i] target := transport.TargetInfo{ @@ -1031,6 +1054,8 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, timer := time.NewTimer(backoffDeadline.Sub(time.Now())) select { case <-timer.C: + case <-resetBackoff: + timer.Stop() case <-ac.ctx.Done(): timer.Stop() return false, ac.ctx.Err() @@ -1038,6 +1063,14 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, return false, nil } +func (ac *addrConn) resetConnectBackoff() { + ac.mu.Lock() + close(ac.resetBackoff) + ac.resetBackoff = make(chan struct{}) + ac.connectRetryNum = 0 + ac.mu.Unlock() +} + // Run in a goroutine to track the error in transport and create the // new transport if an error happens. It returns when the channel is closing. func (ac *addrConn) transportMonitor() { diff --git a/clientconn_test.go b/clientconn_test.go index b146a567..7fa43a56 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -19,6 +19,7 @@ package grpc import ( + "errors" "math" "net" "sync/atomic" @@ -722,3 +723,40 @@ func TestGetClientConnTarget(t *testing.T) { t.Fatalf("Target() = %s, want %s", cc.Target(), addr) } } + +type backoffForever struct{} + +func (b backoffForever) Backoff(int) time.Duration { return time.Duration(math.MaxInt64) } + +func TestResetConnectBackoff(t *testing.T) { + defer leakcheck.Check(t) + dials := make(chan struct{}) + dialer := func(string, time.Duration) (net.Conn, error) { + dials <- struct{}{} + return nil, errors.New("failed to fake dial") + } + cc, err := Dial("any", WithInsecure(), WithDialer(dialer), withBackoff(backoffForever{})) + if err != nil { + t.Fatalf("Dial() = _, %v; want _, nil", err) + } + defer cc.Close() + select { + case <-dials: + case <-time.NewTimer(10 * time.Second).C: + t.Fatal("Failed to call dial within 10s") + } + + select { + case <-dials: + t.Fatal("Dial called unexpectedly before resetting backoff") + case <-time.NewTimer(100 * time.Millisecond).C: + } + + cc.ResetConnectBackoff() + + select { + case <-dials: + case <-time.NewTimer(10 * time.Second).C: + t.Fatal("Failed to call dial within 10s after resetting backoff") + } +}