From c72b08a7743fa6a04afd04d5892c527cf8059a2a Mon Sep 17 00:00:00 2001
From: Menghan Li <menghanl@google.com>
Date: Tue, 16 Aug 2016 14:16:42 -0700
Subject: [PATCH] Change errors returned by ac.wait()

---
 call.go       |  4 ++--
 clientconn.go | 54 +++++++++++++++++++++++++++------------------------
 stream.go     |  4 ++--
 3 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/call.go b/call.go
index 5fba11eb..fea07998 100644
--- a/call.go
+++ b/call.go
@@ -170,9 +170,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 			if _, ok := err.(*rpcError); ok {
 				return err
 			}
-			if err == errConnClosing {
+			if err == errConnClosing || err == errConnUnavailable {
 				if c.failFast {
-					return Errorf(codes.Unavailable, "%v", errConnClosing)
+					return Errorf(codes.Unavailable, "%v", err)
 				}
 				continue
 			}
diff --git a/clientconn.go b/clientconn.go
index d16ea201..9d2324d9 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -73,7 +73,9 @@ var (
 	errConnDrain = errors.New("grpc: the connection is drained")
 	// errConnClosing indicates that the connection is closing.
 	errConnClosing = errors.New("grpc: the connection is closing")
-	errNoAddr      = errors.New("grpc: there is no address available to dial")
+	// errConnUnavailable indicates that the connection is unavailable.
+	errConnUnavailable = errors.New("grpc: the connection is unavailable")
+	errNoAddr          = errors.New("grpc: there is no address available to dial")
 	// minimum time to give a connection to complete
 	minConnectTimeout = 20 * time.Second
 )
@@ -501,11 +503,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
 		}
 		return nil, nil, errConnClosing
 	}
-	// ac.wait should block on transient failure only if balancer is nil and RPC is non-failfast.
-	//  - If RPC is failfast, ac.wait should not block.
-	//  - If balancer is not nil, ac.wait should return errConnClosing on transient failure
-	//    so that non-failfast RPCs will try to get a new transport instead of waiting on ac.
-	t, err := ac.wait(ctx, cc.dopts.balancer == nil && opts.BlockingWait)
+	t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
 	if err != nil {
 		if put != nil {
 			put()
@@ -757,36 +755,42 @@ func (ac *addrConn) transportMonitor() {
 }
 
 // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
-// iv) transport is in TransientFailure and blocking is false.
-func (ac *addrConn) wait(ctx context.Context, blocking bool) (transport.ClientTransport, error) {
+// iv) transport is in TransientFailure and there's no balancer/failfast is true.
+func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
 	for {
 		ac.mu.Lock()
 		switch {
 		case ac.state == Shutdown:
-			err := ac.tearDownErr
+			if failfast || !hasBalancer {
+				// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
+				err := ac.tearDownErr
+				ac.mu.Unlock()
+				return nil, err
+			}
 			ac.mu.Unlock()
-			return nil, err
+			return nil, errConnClosing
 		case ac.state == Ready:
 			ct := ac.transport
 			ac.mu.Unlock()
 			return ct, nil
-		case ac.state == TransientFailure && !blocking:
-			ac.mu.Unlock()
-			return nil, errConnClosing
-		default:
-			ready := ac.ready
-			if ready == nil {
-				ready = make(chan struct{})
-				ac.ready = ready
-			}
-			ac.mu.Unlock()
-			select {
-			case <-ctx.Done():
-				return nil, toRPCErr(ctx.Err())
-			// Wait until the new transport is ready or failed.
-			case <-ready:
+		case ac.state == TransientFailure:
+			if failfast || hasBalancer {
+				ac.mu.Unlock()
+				return nil, errConnUnavailable
 			}
 		}
+		ready := ac.ready
+		if ready == nil {
+			ready = make(chan struct{})
+			ac.ready = ready
+		}
+		ac.mu.Unlock()
+		select {
+		case <-ctx.Done():
+			return nil, toRPCErr(ctx.Err())
+		// Wait until the new transport is ready or failed.
+		case <-ready:
+		}
 	}
 }
 
diff --git a/stream.go b/stream.go
index c1b07e89..51df3f01 100644
--- a/stream.go
+++ b/stream.go
@@ -146,9 +146,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 			if _, ok := err.(*rpcError); ok {
 				return nil, err
 			}
-			if err == errConnClosing {
+			if err == errConnClosing || err == errConnUnavailable {
 				if c.failFast {
-					return nil, Errorf(codes.Unavailable, "%v", errConnClosing)
+					return nil, Errorf(codes.Unavailable, "%v", err)
 				}
 				continue
 			}