diff --git a/call.go b/call.go index 67604f77..ac2d99d6 100644 --- a/call.go +++ b/call.go @@ -132,20 +132,14 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Last: true, Delay: false, } - var ( - lastErr error // record the error that happened - put func() - ) + var put func() for { var ( err error t transport.ClientTransport stream *transport.Stream ) - // TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs. - if lastErr != nil && c.failFast { - return toRPCErr(lastErr) - } + // TODO(zhaoq): Need a formal spec of retry strategy for non-failFast rpcs. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, @@ -155,11 +149,19 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } t, put, err = cc.getTransport(ctx) if err != nil { - if lastErr != nil { - // This was a retry; return the error from the last attempt. - return toRPCErr(lastErr) + if err == ErrClientConnClosing { + return toRPCErr(err) } - return toRPCErr(err) + if _, ok := err.(transport.StreamError); ok { + return toRPCErr(err) + } + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + } + // All the remaining cases are treated as retryable. + continue } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) @@ -168,28 +170,31 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if err != nil { put() if _, ok := err.(transport.ConnectionError); ok { - lastErr = err + if c.failFast { + return toRPCErr(err) + } continue } - if lastErr != nil { - return toRPCErr(lastErr) - } return toRPCErr(err) } // Receive the response - lastErr = recvResponse(cc.dopts, t, &c, stream, reply) - if _, ok := lastErr.(transport.ConnectionError); ok { + err = recvResponse(cc.dopts, t, &c, stream, reply) + if err != nil { put() - continue + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + continue + } + t.CloseStream(stream, err) + return toRPCErr(err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } - t.CloseStream(stream, lastErr) + t.CloseStream(stream, nil) put() - if lastErr != nil { - return toRPCErr(lastErr) - } return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) } } diff --git a/clientconn.go b/clientconn.go index dc3fc689..6e6cb38b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -61,14 +61,15 @@ var ( ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") // ErrClientConnClosing indicates that the operation is illegal because // the ClientConn is closing. - ErrClientConnClosing = errors.New("grpc: the client connection is closing") + ErrClientConnClosing = Errorf(codes.FailedPrecondition, "grpc: the client connection is closing") // ErrClientConnTimeout indicates that the connection could not be // established or re-established within the specified timeout. ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") - // ErrNetworkIP indicates that the connection is down due to some network I/O error. - ErrNetworkIO = errors.New("grpc: failed with network I/O error") - // ErrConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. - ErrConnDrain = errors.New("grpc: the connection is drained") + + // errNetworkIP indicates that the connection is down due to some network I/O error. + errNetworkIO = errors.New("grpc: failed with network I/O error") + // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. + errConnDrain = errors.New("grpc: the connection is drained") errConnClosing = errors.New("grpc: the addrConn is closing") // minimum time to give a connection to complete minConnectTimeout = 20 * time.Second @@ -301,7 +302,7 @@ func (s ConnectivityState) String() string { } } -// ClientConn represents a client connection to an RPC service. +// ClientConn represents a client connection to an RPC server. type ClientConn struct { target string watcher naming.Watcher @@ -348,7 +349,7 @@ func (cc *ClientConn) watchAddrUpdates() error { continue } cc.mu.RUnlock() - ac.tearDown(ErrConnDrain) + ac.tearDown(errConnDrain) default: grpclog.Println("Unknown update.Op ", update.Op) } @@ -528,15 +529,6 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti } func (ac *addrConn) resetTransport(closeTransport bool) error { - /* - ac.cc.mu.Lock() - if ac.cc.conns == nil { - ac.cc.mu.Unlock() - return ErrClientConnClosing - } - ac.cc.conns[ac.addr] = ac - ac.cc.mu.Unlock() - */ var retries int start := time.Now() for { @@ -548,7 +540,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { return errConnClosing } if ac.down != nil { - ac.down(ErrNetworkIO) + ac.down(errNetworkIO) ac.down = nil } ac.state = Connecting @@ -732,7 +724,7 @@ func (ac *addrConn) tearDown(err error) { ac.ready = nil } if ac.transport != nil { - if err == ErrConnDrain { + if err == errConnDrain { ac.transport.GracefulClose() } else { ac.transport.Close() diff --git a/test/end2end_test.go b/test/end2end_test.go index 78e9eac3..70f85325 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -159,7 +159,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* return nil, fmt.Errorf("Unknown server name %q", serverName) } } - // Simulate some service delay. time.Sleep(time.Second) @@ -1020,14 +1019,13 @@ func testRPCTimeout(t *testing.T, e env) { } for i := -1; i <= 10; i++ { ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) - reply, err := tc.UnaryCall(ctx, req) - if grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf(`TestService/UnaryCallv(_, _) = %v, %v; want , error code: %d`, reply, err, codes.DeadlineExceeded) + if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %d", err, codes.DeadlineExceeded) } } } -func TestCancelX(t *testing.T) { +func TestCancel(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { testCancel(t, e) @@ -1058,9 +1056,8 @@ func testCancel(t *testing.T, e env) { } ctx, cancel := context.WithCancel(context.Background()) time.AfterFunc(1*time.Millisecond, cancel) - reply, err := tc.UnaryCall(ctx, req) - if grpc.Code(err) != codes.Canceled { - t.Fatalf(`TestService/UnaryCall(_, _) = %v, %v; want , error code: %d`, reply, err, codes.Canceled) + if r, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Canceled { + t.Fatalf("TestService/UnaryCall(_, _) = %v, %v; want _, error code: %d", r, err, codes.Canceled) } awaitNewConnLogOutput() }