From d15f1a4aa159e36c58e43e21e3feea418ccbc468 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 16 Apr 2020 16:30:23 -0700 Subject: [PATCH] test: Move creds related to tests to creds_test.go (#3542) --- test/creds_test.go | 410 ++++++++++++++++++++++++++++++++++++- test/end2end_test.go | 466 +------------------------------------------ 2 files changed, 408 insertions(+), 468 deletions(-) diff --git a/test/creds_test.go b/test/creds_test.go index 275165c8..65d99c75 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -18,15 +18,24 @@ package test -// TODO(https://github.com/grpc/grpc-go/issues/2330): move all creds related -// tests to this file. - import ( "context" + "errors" + "fmt" + "net" + "strings" "testing" + "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/testdata" ) @@ -125,3 +134,398 @@ func (s) TestCredsBundlePerRPCCredentials(t *testing.T) { t.Fatalf("Test failed. Reason: %v", err) } } + +type clientTimeoutCreds struct { + credentials.TransportCredentials + timeoutReturned bool +} + +func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if !c.timeoutReturned { + c.timeoutReturned = true + return nil, nil, context.DeadlineExceeded + } + return rawConn, nil, nil +} + +func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + +func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { + return nil +} + +func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { + te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty", balancer: "v1"}) + te.userAgent = testAppUA + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + cc := te.clientConn(grpc.WithTransportCredentials(&clientTimeoutCreds{})) + tc := testpb.NewTestServiceClient(cc) + // This unary call should succeed, because ClientHandshake will succeed for the second time. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) + } +} + +type methodTestCreds struct{} + +func (m *methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri, _ := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m *methodTestCreds) RequireTransportSecurity() bool { return false } + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + te := newTest(t, env{name: "context-request-info", network: "tcp", balancer: "v1"}) + te.userAgent = testAppUA + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + cc := te.clientConn(grpc.WithPerRPCCredentials(&methodTestCreds{})) + tc := testpb.NewTestServiceClient(cc) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +} + +const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" + +type clientAlwaysFailCred struct { + credentials.TransportCredentials +} + +func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) +} +func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { + return nil +} + +func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, te.srvAddr, opts...) + if err != nil { + t.Fatalf("Dial(_) = %v, want %v", err, nil) + } + defer cc.Close() + + tc := testpb.NewTestServiceClient(cc) + for i := 0; i < 1000; i++ { + // This loop runs for at most 1 second. The first several RPCs will fail + // with Unavailable because the connection hasn't started. When the + // first connection failed with creds error, the next RPC should also + // fail with the expected error. + if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + return + } + time.Sleep(time.Millisecond) + } + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) +} + +func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} + dctx, dcancel := context.WithTimeout(context.Background(), 10*time.Second) + defer dcancel() + cc, err := grpc.DialContext(dctx, te.srvAddr, opts...) + if err != nil { + t.Fatalf("Dial(_) = %v, want %v", err, nil) + } + defer cc.Close() + + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + return + } + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) +} + +var ( + // test authdata + authdata = map[string]string{ + "test-key": "test-value", + "test-key2-bin": string([]byte{1, 2, 3}), + } +) + +type testPerRPCCredentials struct{} + +func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return authdata, nil +} + +func (cr testPerRPCCredentials) RequireTransportSecurity() bool { + return false +} + +func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, fmt.Errorf("didn't find metadata in context") + } + for k, vwant := range authdata { + vgot, ok := md[k] + if !ok { + return ctx, fmt.Errorf("didn't find authdata key %v in context", k) + } + if vgot[0] != vwant { + return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant) + } + } + return ctx, nil +} + +func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { + for _, e := range listTestEnv() { + testPerRPCCredentialsViaDialOptions(t, e) + } +} + +func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { + te := newTest(t, e) + te.tapHandle = authHandle + te.perRPCCreds = testPerRPCCredentials{} + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} + +func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) { + for _, e := range listTestEnv() { + testPerRPCCredentialsViaCallOptions(t, e) + } +} + +func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { + te := newTest(t, e) + te.tapHandle = authHandle + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} + +func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { + for _, e := range listTestEnv() { + testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e) + } +} + +func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { + te := newTest(t, e) + te.perRPCCreds = testPerRPCCredentials{} + // When credentials are provided via both dial options and call options, + // we apply both sets. + te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, fmt.Errorf("couldn't find metadata in context") + } + for k, vwant := range authdata { + vgot, ok := md[k] + if !ok { + return ctx, fmt.Errorf("couldn't find metadata for key %v", k) + } + if len(vgot) != 2 { + return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot)) + } + if vgot[0] != vwant || vgot[1] != vwant { + return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant) + } + } + return ctx, nil + } + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} + +const testAuthority = "test.auth.ori.ty" + +type authorityCheckCreds struct { + credentials.TransportCredentials + got string +} + +func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.got = authority + return rawConn, nil, nil +} +func (c *authorityCheckCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *authorityCheckCreds) Clone() credentials.TransportCredentials { + return c +} + +// This test makes sure that the authority client handshake gets is the endpoint +// in dial target, not the resolved ip address. +func (s) TestCredsHandshakeAuthority(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + cred := &authorityCheckCreds{} + s := grpc.NewServer() + go s.Serve(lis) + defer s.Stop() + + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + t.Fatalf("ClientConn is not ready after 100 ms") + } + } + + if cred.got != testAuthority { + t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) + } +} + +// This test makes sure that the authority client handshake gets is the endpoint +// of the ServerName of the address when it is set. +func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { + const testServerName = "test.server.name" + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + cred := &authorityCheckCreds{} + s := grpc.NewServer() + go s.Serve(lis) + defer s.Stop() + + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + t.Fatalf("ClientConn is not ready after 100 ms") + } + } + + if cred.got != testServerName { + t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) + } +} + +type serverDispatchCred struct { + rawConnCh chan net.Conn +} + +func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + select { + case c.rawConnCh <- rawConn: + default: + } + return nil, nil, credentials.ErrConnDispatched +} +func (c *serverDispatchCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *serverDispatchCred) Clone() credentials.TransportCredentials { + return nil +} +func (c *serverDispatchCred) OverrideServerName(s string) error { + return nil +} +func (c *serverDispatchCred) getRawConn() net.Conn { + return <-c.rawConnCh +} + +func (s) TestServerCredsDispatch(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + cred := &serverDispatchCred{ + rawConnCh: make(chan net.Conn, 1), + } + s := grpc.NewServer(grpc.Creds(cred)) + go s.Serve(lis) + defer s.Stop() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + + rawConn := cred.getRawConn() + // Give grpc a chance to see the error and potentially close the connection. + // And check that connection is not closed after that. + time.Sleep(100 * time.Millisecond) + // Check rawConn is not closed. + if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil { + t.Errorf("Read() = %v, %v; want n>0, ", n, err) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 19ba39c9..a32e9341 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -619,15 +619,12 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, if err != nil { te.t.Fatalf("Failed to listen: %v", err) } - switch te.e.security { - case "tls": + if te.e.security == "tls" { creds, err := credentials.NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key")) if err != nil { te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) - case "clientTimeoutCreds": - sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) } sopts = append(sopts, te.customServerOptions...) s := grpc.NewServer(sopts...) @@ -803,8 +800,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) - case "clientTimeoutCreds": - opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) case "empty": // Don't add any transport creds option. default: @@ -4720,250 +4715,6 @@ func testClientResourceExhaustedCancelFullDuplex(t *testing.T, e env) { } } -type clientTimeoutCreds struct { - timeoutReturned bool -} - -func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - if !c.timeoutReturned { - c.timeoutReturned = true - return nil, nil, context.DeadlineExceeded - } - return rawConn, nil, nil -} -func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { - return nil -} -func (c *clientTimeoutCreds) OverrideServerName(s string) error { - return nil -} - -func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { - te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: "v1"}) - te.userAgent = testAppUA - te.startServer(&testServer{security: te.e.security}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - // This unary call should succeed, because ClientHandshake will succeed for the second time. - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) - } -} - -type serverDispatchCred struct { - rawConnCh chan net.Conn -} - -func newServerDispatchCred() *serverDispatchCred { - return &serverDispatchCred{ - rawConnCh: make(chan net.Conn, 1), - } -} -func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - select { - case c.rawConnCh <- rawConn: - default: - } - return nil, nil, credentials.ErrConnDispatched -} -func (c *serverDispatchCred) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c *serverDispatchCred) Clone() credentials.TransportCredentials { - return nil -} -func (c *serverDispatchCred) OverrideServerName(s string) error { - return nil -} -func (c *serverDispatchCred) getRawConn() net.Conn { - return <-c.rawConnCh -} - -func (s) TestServerCredsDispatch(t *testing.T) { - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - } - cred := newServerDispatchCred() - s := grpc.NewServer(grpc.Creds(cred)) - go s.Serve(lis) - defer s.Stop() - - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred)) - if err != nil { - t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) - } - defer cc.Close() - - rawConn := cred.getRawConn() - // Give grpc a chance to see the error and potentially close the connection. - // And check that connection is not closed after that. - time.Sleep(100 * time.Millisecond) - // Check rawConn is not closed. - if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil { - t.Errorf("Read() = %v, %v; want n>0, ", n, err) - } -} - -type authorityCheckCreds struct { - got string -} - -func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - c.got = authority - return rawConn, nil, nil -} -func (c *authorityCheckCreds) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c *authorityCheckCreds) Clone() credentials.TransportCredentials { - return c -} -func (c *authorityCheckCreds) OverrideServerName(s string) error { - return nil -} - -// This test makes sure that the authority client handshake gets is the endpoint -// in dial target, not the resolved ip address. -func (s) TestCredsHandshakeAuthority(t *testing.T) { - const testAuthority = "test.auth.ori.ty" - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - } - cred := &authorityCheckCreds{} - s := grpc.NewServer() - go s.Serve(lis) - defer s.Stop() - - r, rcleanup := manual.GenerateAndRegisterManualResolver() - defer rcleanup() - - cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) - if err != nil { - t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) - } - defer cc.Close() - r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - for { - s := cc.GetState() - if s == connectivity.Ready { - break - } - if !cc.WaitForStateChange(ctx, s) { - // ctx got timeout or canceled. - t.Fatalf("ClientConn is not ready after 100 ms") - } - } - - if cred.got != testAuthority { - t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) - } -} - -// This test makes sure that the authority client handshake gets is the endpoint -// of the ServerName of the address when it is set. -func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { - const testAuthority = "test.auth.ori.ty" - const testServerName = "test.server.name" - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - } - cred := &authorityCheckCreds{} - s := grpc.NewServer() - go s.Serve(lis) - defer s.Stop() - - r, rcleanup := manual.GenerateAndRegisterManualResolver() - defer rcleanup() - - cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) - if err != nil { - t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) - } - defer cc.Close() - r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - for { - s := cc.GetState() - if s == connectivity.Ready { - break - } - if !cc.WaitForStateChange(ctx, s) { - // ctx got timeout or canceled. - t.Fatalf("ClientConn is not ready after 100 ms") - } - } - - if cred.got != testServerName { - t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) - } -} - -type clientFailCreds struct{} - -func (c *clientFailCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c *clientFailCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return nil, nil, fmt.Errorf("client handshake fails with fatal error") -} -func (c *clientFailCreds) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c *clientFailCreds) Clone() credentials.TransportCredentials { - return c -} -func (c *clientFailCreds) OverrideServerName(s string) error { - return nil -} - -// This test makes sure that failfast RPCs fail if client handshake fails with -// fatal errors. -func (s) TestFailfastRPCFailOnFatalHandshakeError(t *testing.T) { - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - } - defer lis.Close() - - cc, err := grpc.Dial("passthrough:///"+lis.Addr().String(), grpc.WithTransportCredentials(&clientFailCreds{})) - if err != nil { - t.Fatalf("grpc.Dial(_) = %v", err) - } - defer cc.Close() - - tc := testpb.NewTestServiceClient(cc) - // This unary call should fail, but not timeout. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(false)); status.Code(err) != codes.Unavailable { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) - } -} - func (s) TestFlowControlLogicalRace(t *testing.T) { // Test for a regression of https://github.com/grpc/grpc-go/issues/632, // and other flow control bugs. @@ -5633,120 +5384,6 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) { } } -var ( - // test authdata - authdata = map[string]string{ - "test-key": "test-value", - "test-key2-bin": string([]byte{1, 2, 3}), - } -) - -type testPerRPCCredentials struct{} - -func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return authdata, nil -} - -func (cr testPerRPCCredentials) RequireTransportSecurity() bool { - return false -} - -func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return ctx, fmt.Errorf("didn't find metadata in context") - } - for k, vwant := range authdata { - vgot, ok := md[k] - if !ok { - return ctx, fmt.Errorf("didn't find authdata key %v in context", k) - } - if vgot[0] != vwant { - return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant) - } - } - return ctx, nil -} - -func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { - for _, e := range listTestEnv() { - testPerRPCCredentialsViaDialOptions(t, e) - } -} - -func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { - te := newTest(t, e) - te.tapHandle = authHandle - te.perRPCCreds = testPerRPCCredentials{} - te.startServer(&testServer{security: e.security}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - t.Fatalf("Test failed. Reason: %v", err) - } -} - -func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) { - for _, e := range listTestEnv() { - testPerRPCCredentialsViaCallOptions(t, e) - } -} - -func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { - te := newTest(t, e) - te.tapHandle = authHandle - te.startServer(&testServer{security: e.security}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { - t.Fatalf("Test failed. Reason: %v", err) - } -} - -func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { - for _, e := range listTestEnv() { - testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e) - } -} - -func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { - te := newTest(t, e) - te.perRPCCreds = testPerRPCCredentials{} - // When credentials are provided via both dial options and call options, - // we apply both sets. - te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return ctx, fmt.Errorf("couldn't find metadata in context") - } - for k, vwant := range authdata { - vgot, ok := md[k] - if !ok { - return ctx, fmt.Errorf("couldn't find metadata for key %v", k) - } - if len(vgot) != 2 { - return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot)) - } - if vgot[0] != vwant || vgot[1] != vwant { - return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant) - } - } - return ctx, nil - } - te.startServer(&testServer{security: e.security}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { - t.Fatalf("Test failed. Reason: %v", err) - } -} - func (s) TestWaitForReadyConnection(t *testing.T) { for _, e := range listTestEnv() { testWaitForReadyConnection(t, e) @@ -6637,79 +6274,6 @@ func testClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T, e wg.Wait() } -const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" - -var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg) - -type clientAlwaysFailCred struct{} - -func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return nil, nil, errClientAlwaysFailCred -} -func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { - return nil -} -func (c clientAlwaysFailCred) OverrideServerName(s string) error { - return nil -} - -func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"}) - te.startServer(&testServer{security: te.e.security}) - defer te.tearDown() - - opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, te.srvAddr, opts...) - if err != nil { - t.Fatalf("Dial(_) = %v, want %v", err, nil) - } - defer cc.Close() - - tc := testpb.NewTestServiceClient(cc) - for i := 0; i < 1000; i++ { - // This loop runs for at most 1 second. The first several RPCs will fail - // with Unavailable because the connection hasn't started. When the - // first connection failed with creds error, the next RPC should also - // fail with the expected error. - if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - return - } - time.Sleep(time.Millisecond) - } - te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) -} - -func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"}) - te.startServer(&testServer{security: te.e.security}) - defer te.tearDown() - - opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} - dctx, dcancel := context.WithTimeout(context.Background(), 10*time.Second) - defer dcancel() - cc, err := grpc.DialContext(dctx, te.srvAddr, opts...) - if err != nil { - t.Fatalf("Dial(_) = %v, want %v", err, nil) - } - defer cc.Close() - - tc := testpb.NewTestServiceClient(cc) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - return - } - te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) -} - func (s) TestRPCTimeout(t *testing.T) { for _, e := range listTestEnv() { testRPCTimeout(t, e) @@ -6747,7 +6311,6 @@ func testRPCTimeout(t *testing.T, e env) { } func (s) TestDisabledIOBuffers(t *testing.T) { - payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(60000)) if err != nil { t.Fatalf("Failed to create payload: %v", err) @@ -7439,33 +7002,6 @@ func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult { return g } -type methodTestCreds struct{} - -func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - ri, _ := credentials.RequestInfoFromContext(ctx) - return nil, status.Errorf(codes.Unknown, ri.Method) -} - -func (m methodTestCreds) RequireTransportSecurity() bool { - return false -} - -func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { - const wantMethod = "/grpc.testing.TestService/EmptyCall" - ss := &stubServer{} - if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{})); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { - t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) - } -} - func (s) TestClientCancellationPropagatesUnary(t *testing.T) { wg := &sync.WaitGroup{} called, done := make(chan struct{}), make(chan struct{})