diff --git a/test/end2end_test.go b/test/end2end_test.go index c76c58b1..e30830a6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -450,6 +450,8 @@ func (te *test) startServer(ts testpb.TestServiceServer) { sopts = append(sopts, grpc.Creds(creds)) case "clientAlwaysFailCred": sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) + case "clientTimeoutCreds": + sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) } s := grpc.NewServer(sopts...) te.srv = s @@ -501,6 +503,8 @@ func (te *test) clientConn() *grpc.ClientConn { opts = append(opts, grpc.WithTransportCredentials(creds)) case "clientAlwaysFailCred": opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{})) + case "clientTimeoutCreds": + opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) default: opts = append(opts, grpc.WithInsecure()) } @@ -2321,7 +2325,7 @@ func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) } } @@ -2333,7 +2337,7 @@ func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) } } @@ -2345,7 +2349,39 @@ func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) + } +} + +type clientTimeoutCreds struct { + timeoutReturned bool +} + +func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if !c.timeoutReturned { + c.timeoutReturned = true + return nil, nil, context.DeadlineExceeded + } + return rawConn, nil, nil +} +func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + +func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { + te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false}) + te.userAgent = testAppUA + te.startServer(&testServer{security: "clientTimeoutCreds"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // This unary call should succeed, because ClientHandshake will succeed for the second time. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) } }