diff --git a/clientconn.go b/clientconn.go index c374a595..e4e72368 100644 --- a/clientconn.go +++ b/clientconn.go @@ -966,6 +966,12 @@ func (ac *addrConn) resetTransport() error { newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) if err != nil { if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + ac.mu.Lock() + if ac.state != connectivity.Shutdown { + ac.state = connectivity.TransientFailure + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + } + ac.mu.Unlock() return err } grpclog.Warningf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, addr) diff --git a/test/end2end_test.go b/test/end2end_test.go index e2d1d2d6..e7a1f1bd 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4340,6 +4340,50 @@ func TestCredsHandshakeAuthority(t *testing.T) { } } +type clientFailCreds struct { + got string +} + +func (c *clientFailCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *clientFailCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, fmt.Errorf("client handshake fails with fatal error") +} +func (c *clientFailCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *clientFailCreds) Clone() credentials.TransportCredentials { + return c +} +func (c *clientFailCreds) OverrideServerName(s string) error { + return nil +} + +// This test makes sure that failfast RPCs fail if client handshake fails with +// fatal errors. +func TestFailfastRPCFailOnFatalHandshakeError(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer lis.Close() + + cc, err := grpc.Dial("passthrough:///"+lis.Addr().String(), grpc.WithTransportCredentials(&clientFailCreds{})) + if err != nil { + t.Fatalf("grpc.Dial(_) = %v", err) + } + defer cc.Close() + + tc := testpb.NewTestServiceClient(cc) + // This unary call should fail, but not timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(true)); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) + } +} + func TestFlowControlLogicalRace(t *testing.T) { // Test for a regression of https://github.com/grpc/grpc-go/issues/632, // and other flow control bugs.