From 3c400e7fcc8719ba3d133563f9c9d126f03a80c0 Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Sat, 10 Oct 2020 13:47:49 -0700 Subject: [PATCH] advancedtls: clean up test files and shared code (#3897) * advancedtls: clean up test files and shared code --- security/advancedtls/advancedtls.go | 73 +----- .../advancedtls_integration_test.go | 174 +++++--------- security/advancedtls/advancedtls_test.go | 219 +++++------------- .../internal/testutils/testutils.go | 100 ++++++++ security/advancedtls/pemfile_provider_test.go | 37 +-- 5 files changed, 234 insertions(+), 369 deletions(-) create mode 100644 security/advancedtls/internal/testutils/testutils.go diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 74564632..ea93d640 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -28,7 +28,6 @@ import ( "fmt" "net" "reflect" - "syscall" "time" "google.golang.org/grpc/credentials" @@ -374,7 +373,7 @@ func (c advancedTLSCreds) Info() credentials.ProtocolInfo { func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { // Use local cfg to avoid clobbering ServerName if using multiple endpoints. - cfg := cloneTLSConfig(c.config) + cfg := credinternal.CloneTLSConfig(c.config) // We return the full authority name to users if ServerName is empty without // stripping the trailing port. if cfg.ServerName == "" { @@ -404,11 +403,11 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string }, } info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState()) - return WrapSyscallConn(rawConn, conn), info, nil + return credinternal.WrapSyscallConn(rawConn, conn), info, nil } func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - cfg := cloneTLSConfig(c.config) + cfg := credinternal.CloneTLSConfig(c.config) cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn) conn := tls.Server(rawConn, cfg) if err := conn.Handshake(); err != nil { @@ -422,12 +421,12 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti }, } info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState()) - return WrapSyscallConn(rawConn, conn), info, nil + return credinternal.WrapSyscallConn(rawConn, conn), info, nil } func (c *advancedTLSCreds) Clone() credentials.TransportCredentials { return &advancedTLSCreds{ - config: cloneTLSConfig(c.config), + config: credinternal.CloneTLSConfig(c.config), verifyFunc: c.verifyFunc, getRootCAs: c.getRootCAs, isClient: c.isClient, @@ -530,7 +529,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) verifyFunc: o.VerifyPeer, vType: o.VType, } - tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) + tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil } @@ -548,64 +547,6 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) verifyFunc: o.VerifyPeer, vType: o.VType, } - tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) + tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil } - -// TODO(ZhenLian): The code below are duplicates with gRPC-Go under -// credentials/internal. Consider refactoring in the future. -const alpnProtoStrH2 = "h2" - -func appendH2ToNextProtos(ps []string) []string { - for _, p := range ps { - if p == alpnProtoStrH2 { - return ps - } - } - ret := make([]string, 0, len(ps)+1) - ret = append(ret, ps...) - return append(ret, alpnProtoStrH2) -} - -// We give syscall.Conn a new name here since syscall.Conn and net.Conn used -// below have the same names. -type sysConn = syscall.Conn - -// syscallConn keeps reference of rawConn to support syscall.Conn for channelz. -// SyscallConn() (the method in interface syscall.Conn) is explicitly -// implemented on this type, -// -// Interface syscall.Conn is implemented by most net.Conn implementations (e.g. -// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns -// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn -// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't -// help here). -type syscallConn struct { - net.Conn - // sysConn is a type alias of syscall.Conn. It's necessary because the name - // `Conn` collides with `net.Conn`. - sysConn -} - -// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that -// implements syscall.Conn. rawConn will be used to support syscall, and newConn -// will be used for read/write. -// -// This function returns newConn if rawConn doesn't implement syscall.Conn. -func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { - sysConn, ok := rawConn.(syscall.Conn) - if !ok { - return newConn - } - return &syscallConn{ - Conn: newConn, - sysConn: sysConn, - } -} - -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return cfg.Clone() -} diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index a95fa56a..3f4e7059 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -31,7 +31,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" pb "google.golang.org/grpc/examples/helloworld/helloworld" - "google.golang.org/grpc/security/advancedtls/testdata" + "google.golang.org/grpc/security/advancedtls/internal/testutils" ) var ( @@ -67,69 +67,6 @@ func (s *stageInfo) reset() { s.stage = 0 } -// certStore contains all the certificates used in the integration tests. -type certStore struct { - // clientPeer1 is the certificate sent by client to prove its identity. - // It is trusted by serverTrust1. - clientPeer1 tls.Certificate - // clientPeer2 is the certificate sent by client to prove its identity. - // It is trusted by serverTrust2. - clientPeer2 tls.Certificate - // serverPeer1 is the certificate sent by server to prove its identity. - // It is trusted by clientTrust1. - serverPeer1 tls.Certificate - // serverPeer2 is the certificate sent by server to prove its identity. - // It is trusted by clientTrust2. - serverPeer2 tls.Certificate - clientTrust1 *x509.CertPool - clientTrust2 *x509.CertPool - serverTrust1 *x509.CertPool - serverTrust2 *x509.CertPool -} - -// loadCerts function is used to load test certificates at the beginning of -// each integration test. -func (cs *certStore) loadCerts() error { - var err error - cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), - testdata.Path("client_key_1.pem")) - if err != nil { - return err - } - cs.clientPeer2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"), - testdata.Path("client_key_2.pem")) - if err != nil { - return err - } - cs.serverPeer1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), - testdata.Path("server_key_1.pem")) - if err != nil { - return err - } - cs.serverPeer2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), - testdata.Path("server_key_2.pem")) - if err != nil { - return err - } - cs.clientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem")) - if err != nil { - return err - } - cs.clientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem")) - if err != nil { - return err - } - cs.serverTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem")) - if err != nil { - return err - } - cs.serverTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem")) - if err != nil { - return err - } - return nil -} - type greeterServer struct { pb.UnimplementedGreeterServer } @@ -183,10 +120,9 @@ func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds cred // (could be change the client's trust certificate, or change custom // verification function, etc) func (s) TestEnd2End(t *testing.T) { - cs := &certStore{} - err := cs.loadCerts() - if err != nil { - t.Fatalf("failed to load certs: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) } stage := &stageInfo{} for _, test := range []struct { @@ -206,38 +142,38 @@ func (s) TestEnd2End(t *testing.T) { }{ // Test Scenarios: // At initialization(stage = 0), client will be initialized with cert - // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 is - // trusted by serverTrust1, and serverPeer1 by clientTrust1. - // At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2 - // is not trusted by serverTrust1, following rpc calls are expected to + // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. + // The mutual authentication works at the beginning, since ClientCert1 is + // trusted by ServerTrust1, and ServerCert1 by ClientTrust1. + // At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2 + // is not trusted by ServerTrust1, following rpc calls are expected to // fail, while the previous rpc calls are still good because those are // already authenticated. - // At stage 2, the server changes serverTrust1 to serverTrust2, and we - // should see it again accepts the connection, since clientPeer2 is trusted - // by serverTrust2. + // At stage 2, the server changes ServerTrust1 to ServerTrust2, and we + // should see it again accepts the connection, since ClientCert2 is trusted + // by ServerTrust2. { desc: "TestClientPeerCertReloadServerTrustCertReload", clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { switch stage.read() { case 0: - return &cs.clientPeer1, nil + return &cs.ClientCert1, nil default: - return &cs.clientPeer2, nil + return &cs.ClientCert2, nil } }, - clientRoot: cs.clientTrust1, + clientRoot: cs.ClientTrust1, clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, clientVType: CertVerification, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { switch stage.read() { case 0, 1: - return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil + return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil default: - return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil + return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil } }, serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { @@ -247,25 +183,25 @@ func (s) TestEnd2End(t *testing.T) { }, // Test Scenarios: // At initialization(stage = 0), client will be initialized with cert - // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 is - // trusted by serverTrust1, and serverPeer1 by clientTrust1. - // At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2 - // is not trusted by clientTrust1, following rpc calls are expected to + // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. + // The mutual authentication works at the beginning, since ClientCert1 is + // trusted by ServerTrust1, and ServerCert1 by ClientTrust1. + // At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2 + // is not trusted by ClientTrust1, following rpc calls are expected to // fail, while the previous rpc calls are still good because those are // already authenticated. - // At stage 2, the client changes clientTrust1 to clientTrust2, and we - // should see it again accepts the connection, since serverPeer2 is trusted - // by clientTrust2. + // At stage 2, the client changes ClientTrust1 to ClientTrust2, and we + // should see it again accepts the connection, since ServerCert2 is trusted + // by ClientTrust2. { desc: "TestServerPeerCertReloadClientTrustCertReload", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { switch stage.read() { case 0, 1: - return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil + return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil default: - return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil + return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil } }, clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { @@ -275,12 +211,12 @@ func (s) TestEnd2End(t *testing.T) { serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { switch stage.read() { case 0: - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil default: - return []*tls.Certificate{&cs.serverPeer2}, nil + return []*tls.Certificate{&cs.ServerCert2}, nil } }, - serverRoot: cs.serverTrust1, + serverRoot: cs.ServerTrust1, serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, @@ -288,26 +224,26 @@ func (s) TestEnd2End(t *testing.T) { }, // Test Scenarios: // At initialization(stage = 0), client will be initialized with cert - // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 - // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the - // custom verification check allows the CommonName on serverPeer1. - // At stage 1, server changes serverPeer1 to serverPeer2, and client - // changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by - // clientTrust2, our authorization check only accepts serverPeer1, and + // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. + // The mutual authentication works at the beginning, since ClientCert1 + // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the + // custom verification check allows the CommonName on ServerCert1. + // At stage 1, server changes ServerCert1 to ServerCert2, and client + // changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by + // ClientTrust2, our authorization check only accepts ServerCert1, and // hence the following calls should fail. Previous connections should // not be affected. // At stage 2, the client changes authorization check to only accept - // serverPeer2. Now we should see the connection becomes normal again. + // ServerCert2. Now we should see the connection becomes normal again. { desc: "TestClientCustomVerification", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { switch stage.read() { case 0: - return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil + return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil default: - return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil + return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil } }, clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { @@ -321,12 +257,12 @@ func (s) TestEnd2End(t *testing.T) { authzCheck := false switch stage.read() { case 0, 1: - // foo.bar.com is the common name on serverPeer1 + // foo.bar.com is the common name on ServerCert1 if cert.Subject.CommonName == "foo.bar.com" { authzCheck = true } default: - // foo.bar.server2.com is the common name on serverPeer2 + // foo.bar.server2.com is the common name on ServerCert2 if cert.Subject.CommonName == "foo.bar.server2.com" { authzCheck = true } @@ -340,12 +276,12 @@ func (s) TestEnd2End(t *testing.T) { serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { switch stage.read() { case 0: - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil default: - return []*tls.Certificate{&cs.serverPeer2}, nil + return []*tls.Certificate{&cs.ServerCert2}, nil } }, - serverRoot: cs.serverTrust1, + serverRoot: cs.ServerTrust1, serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, @@ -353,9 +289,9 @@ func (s) TestEnd2End(t *testing.T) { }, // Test Scenarios: // At initialization(stage = 0), client will be initialized with cert - // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 - // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the + // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. + // The mutual authentication works at the beginning, since ClientCert1 + // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the // custom verification check on server side allows all connections. // At stage 1, server disallows the the connections by setting custom // verification check. The following calls should fail. Previous @@ -364,14 +300,14 @@ func (s) TestEnd2End(t *testing.T) { // authentications should go back to normal. { desc: "TestServerCustomVerification", - clientCert: []tls.Certificate{cs.clientPeer1}, - clientRoot: cs.clientTrust1, + clientCert: []tls.Certificate{cs.ClientCert1}, + clientRoot: cs.ClientTrust1, clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, clientVType: CertVerification, - serverCert: []tls.Certificate{cs.serverPeer1}, - serverRoot: cs.serverTrust1, + serverCert: []tls.Certificate{cs.ServerCert1}, + serverRoot: cs.ServerTrust1, serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { switch stage.read() { case 0, 2: diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index a631ee46..a7ecf276 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -22,21 +22,17 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/pem" "errors" "fmt" - "io/ioutil" "math/big" "net" - "reflect" - "syscall" "testing" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" - "google.golang.org/grpc/security/advancedtls/testdata" + "google.golang.org/grpc/security/advancedtls/internal/testutils" ) type s struct { @@ -65,27 +61,26 @@ func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMateria if f.wantError { return nil, fmt.Errorf("bad fakeProvider") } - cs := &certStore{} - err := cs.loadCerts() - if err != nil { - return nil, fmt.Errorf("failed to load certs: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + return nil, fmt.Errorf("cs.LoadCerts() failed, err: %v", err) } if f.pt == provTypeRoot && f.isClient { - return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil + return &certprovider.KeyMaterial{Roots: cs.ClientTrust1}, nil } if f.pt == provTypeRoot && !f.isClient { - return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil + return &certprovider.KeyMaterial{Roots: cs.ServerTrust1}, nil } if f.pt == provTypeIdentity && f.isClient { if f.wantMultiCert { - return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1, cs.ClientCert2}}, nil } - return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, nil } if f.wantMultiCert { - return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1, cs.ServerCert2}}, nil } - return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1}}, nil } func (f fakeProvider) Close() {} @@ -308,13 +303,12 @@ func (s) TestServerOptionsConfigSuccessCases(t *testing.T) { } func (s) TestClientServerHandshake(t *testing.T) { - cs := &certStore{} - err := cs.loadCerts() - if err != nil { - t.Fatalf("Failed to load certs: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) } getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { - return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil + return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil } clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { if params.ServerName == "" { @@ -331,7 +325,7 @@ func (s) TestClientServerHandshake(t *testing.T) { return nil, fmt.Errorf("custom verification function failed") } getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { - return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil + return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil } serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) { if params.ServerName != "" { @@ -378,7 +372,7 @@ func (s) TestClientServerHandshake(t *testing.T) { desc: "Client has no trust cert with verifyFuncGood; server sends peer cert", clientVerifyFunc: clientVerifyFuncGood, clientVType: SkipVerification, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverVType: CertAndHostVerification, }, // Client: only set clientRoot @@ -389,10 +383,10 @@ func (s) TestClientServerHandshake(t *testing.T) { // this test suites. { desc: "Client has root cert; server sends peer cert", - clientRoot: cs.clientTrust1, + clientRoot: cs.ClientTrust1, clientVType: CertAndHostVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverVType: CertAndHostVerification, serverExpectError: true, }, @@ -407,7 +401,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetRoot: getRootCAsForClient, clientVType: CertAndHostVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverVType: CertAndHostVerification, serverExpectError: true, }, @@ -419,7 +413,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverVType: CertAndHostVerification, }, // Client: set clientGetRoot and bad clientVerifyFunc function @@ -432,7 +426,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientVerifyFunc: verifyFuncBad, clientVType: CertVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverVType: CertVerification, serverExpectError: true, }, @@ -441,13 +435,13 @@ func (s) TestClientServerHandshake(t *testing.T) { // Expected Behavior: success { desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{cs.serverPeer1}, - serverRoot: cs.serverTrust1, + serverCert: []tls.Certificate{cs.ServerCert1}, + serverRoot: cs.ServerTrust1, serverVType: CertVerification, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert @@ -455,12 +449,12 @@ func (s) TestClientServerHandshake(t *testing.T) { // Expected Behavior: success { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverGetRoot: getRootCAsForServer, serverVType: CertVerification, }, @@ -471,12 +465,12 @@ func (s) TestClientServerHandshake(t *testing.T) { // Reason: server side reloading returns failure { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverGetRoot: getRootCAsForServerBad, serverVType: CertVerification, serverExpectError: true, @@ -487,14 +481,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &cs.clientPeer1, nil + return &cs.ClientCert1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -508,14 +502,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &cs.serverPeer1, nil + return &cs.ServerCert1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -529,7 +523,7 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &cs.clientPeer1, nil + return &cs.ClientCert1, nil }, clientGetRoot: getRootCAsForServer, clientVerifyFunc: clientVerifyFuncGood, @@ -537,7 +531,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientExpectHandshakeError: true, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -552,14 +546,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &cs.clientPeer1, nil + return &cs.ClientCert1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&cs.clientPeer1}, nil + return []*tls.Certificate{&cs.ClientCert1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -573,7 +567,7 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &cs.clientPeer1, nil + return &cs.ClientCert1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, @@ -581,7 +575,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientExpectHandshakeError: true, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&cs.serverPeer1}, nil + return []*tls.Certificate{&cs.ServerCert1}, nil }, serverGetRoot: getRootCAsForClient, serverVerifyFunc: serverVerifyFunc, @@ -594,13 +588,13 @@ func (s) TestClientServerHandshake(t *testing.T) { // server custom check fails { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS", - clientCert: []tls.Certificate{cs.clientPeer1}, + clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true, - serverCert: []tls.Certificate{cs.serverPeer1}, + serverCert: []tls.Certificate{cs.ServerCert1}, serverGetRoot: getRootCAsForServer, serverVerifyFunc: verifyFuncBad, serverVType: CertVerification, @@ -776,24 +770,6 @@ func (s) TestClientServerHandshake(t *testing.T) { } } -func readTrustCert(fileName string) (*x509.CertPool, error) { - trustData, err := ioutil.ReadFile(fileName) - if err != nil { - return nil, err - } - trustBlock, _ := pem.Decode(trustData) - if trustBlock == nil { - return nil, err - } - trustCert, err := x509.ParseCertificate(trustBlock.Bytes) - if err != nil { - return nil, err - } - trustPool := x509.NewCertPool() - trustPool.AddCert(trustCert) - return trustPool, nil -} - func compare(a1, a2 credentials.AuthInfo) bool { if a1.AuthType() != a2.AuthType() { return false @@ -816,13 +792,13 @@ func compare(a1, a2 credentials.AuthInfo) bool { func (s) TestAdvancedTLSOverrideServerName(t *testing.T) { expectedServerName := "server.name" - clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) - if err != nil { - t.Fatalf("Client is unable to load trust certs. Error: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) } clientOptions := &ClientOptions{ RootOptions: RootCertificateOptions{ - RootCACerts: clientTrustPool, + RootCACerts: cs.ClientTrust1, }, ServerNameOverride: expectedServerName, } @@ -836,122 +812,33 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) { } } -func (s) TestTLSClone(t *testing.T) { - expectedServerName := "server.name" - clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) - if err != nil { - t.Fatalf("Client is unable to load trust certs. Error: %v", err) - } - clientOptions := &ClientOptions{ - RootOptions: RootCertificateOptions{ - RootCACerts: clientTrustPool, - }, - ServerNameOverride: expectedServerName, - } - c, err := NewClientCreds(clientOptions) - if err != nil { - t.Fatalf("Failed to create new client: %v", err) - } - cc := c.Clone() - if cc.Info().ServerName != expectedServerName { - t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) - } - cc.OverrideServerName("") - if c.Info().ServerName != expectedServerName { - t.Fatalf("Change in clone should not affect the original, "+ - "c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) - } - -} - -func (s) TestAppendH2ToNextProtos(t *testing.T) { - tests := []struct { - name string - ps []string - want []string - }{ - { - name: "empty", - ps: nil, - want: []string{"h2"}, - }, - { - name: "only h2", - ps: []string{"h2"}, - want: []string{"h2"}, - }, - { - name: "with h2", - ps: []string{"alpn", "h2"}, - want: []string{"alpn", "h2"}, - }, - { - name: "no h2", - ps: []string{"alpn"}, - want: []string{"alpn", "h2"}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) { - t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want) - } - }) - } -} - -type nonSyscallConn struct { - net.Conn -} - -func (s) TestWrapSyscallConn(t *testing.T) { - sc := &syscallConn{} - nsc := &nonSyscallConn{} - - wrapConn := WrapSyscallConn(sc, nsc) - if _, ok := wrapConn.(syscall.Conn); !ok { - t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement", - wrapConn) - } -} - func (s) TestGetCertificatesSNI(t *testing.T) { - // Load server certificates for setting the serverGetCert callback function. - serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) } - serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err) - } - serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err) - } - tests := []struct { desc string serverName string wantCert tls.Certificate }{ { - desc: "Select serverCert1", + desc: "Select ServerCert1", // "foo.bar.com" is the common name on server certificate server_cert_1.pem. serverName: "foo.bar.com", - wantCert: serverCert1, + wantCert: cs.ServerCert1, }, { - desc: "Select serverCert2", + desc: "Select ServerCert2", // "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem. serverName: "foo.bar.server2.com", - wantCert: serverCert2, + wantCert: cs.ServerCert2, }, { desc: "Select serverCert3", // "google.com" is one of the DNS names on server certificate server_cert_3.pem. serverName: "google.com", - wantCert: serverCert3, + wantCert: cs.ServerPeer3, }, } for _, test := range tests { @@ -960,7 +847,7 @@ func (s) TestGetCertificatesSNI(t *testing.T) { serverOptions := &ServerOptions{ IdentityOptions: IdentityCertificateOptions{ GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil + return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil }, }, } diff --git a/security/advancedtls/internal/testutils/testutils.go b/security/advancedtls/internal/testutils/testutils.go new file mode 100644 index 00000000..665cc602 --- /dev/null +++ b/security/advancedtls/internal/testutils/testutils.go @@ -0,0 +1,100 @@ +/* + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package testutils contains helper functions for advancedtls. +package testutils + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + + "google.golang.org/grpc/security/advancedtls/testdata" +) + +// CertStore contains all the certificates used in the integration tests. +type CertStore struct { + // ClientCert1 is the certificate sent by client to prove its identity. + // It is trusted by ServerTrust1. + ClientCert1 tls.Certificate + // ClientCert2 is the certificate sent by client to prove its identity. + // It is trusted by ServerTrust2. + ClientCert2 tls.Certificate + // ServerCert1 is the certificate sent by server to prove its identity. + // It is trusted by ClientTrust1. + ServerCert1 tls.Certificate + // ServerCert2 is the certificate sent by server to prove its identity. + // It is trusted by ClientTrust2. + ServerCert2 tls.Certificate + // ServerPeer3 is the certificate sent by server to prove its identity. + ServerPeer3 tls.Certificate + // ClientTrust1 is the root certificate used on the client side. + ClientTrust1 *x509.CertPool + // ClientTrust2 is the root certificate used on the client side. + ClientTrust2 *x509.CertPool + // ServerTrust1 is the root certificate used on the server side. + ServerTrust1 *x509.CertPool + // ServerTrust2 is the root certificate used on the server side. + ServerTrust2 *x509.CertPool +} + +func readTrustCert(fileName string) (*x509.CertPool, error) { + trustData, err := ioutil.ReadFile(fileName) + if err != nil { + return nil, err + } + trustPool := x509.NewCertPool() + if !trustPool.AppendCertsFromPEM(trustData) { + return nil, fmt.Errorf("error loading trust certificates") + } + return trustPool, nil +} + +// LoadCerts function is used to load test certificates at the beginning of +// each integration test. +func (cs *CertStore) LoadCerts() error { + var err error + if cs.ClientCert1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), testdata.Path("client_key_1.pem")); err != nil { + return err + } + if cs.ClientCert2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"), testdata.Path("client_key_2.pem")); err != nil { + return err + } + if cs.ServerCert1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")); err != nil { + return err + } + if cs.ServerCert2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")); err != nil { + return err + } + if cs.ServerPeer3, err = tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")); err != nil { + return err + } + if cs.ClientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem")); err != nil { + return err + } + if cs.ClientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem")); err != nil { + return err + } + if cs.ServerTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem")); err != nil { + return err + } + if cs.ServerTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem")); err != nil { + return err + } + return nil +} diff --git a/security/advancedtls/pemfile_provider_test.go b/security/advancedtls/pemfile_provider_test.go index abc494bc..48e0bd2f 100644 --- a/security/advancedtls/pemfile_provider_test.go +++ b/security/advancedtls/pemfile_provider_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/security/advancedtls/internal/testutils" "google.golang.org/grpc/security/advancedtls/testdata" ) @@ -95,17 +96,17 @@ func (s) TestNewPEMFileProvider(t *testing.T) { // This test overwrites the credential reading function used by the watching // goroutine. It is tested under different stages: -// At stage 0, we force reading function to load clientPeer1 and serverTrust1, +// At stage 0, we force reading function to load ClientCert1 and ServerTrust1, // and see if the credentials are picked up by the watching go routine. // At stage 1, we force reading function to cause an error. The watching go // routine should log the error while leaving the credentials unchanged. -// At stage 2, we force reading function to load clientPeer2 and serverTrust2, +// At stage 2, we force reading function to load ClientCert2 and ServerTrust2, // and see if the new credentials are picked up. func (s) TestWatchingRoutineUpdates(t *testing.T) { // Load certificates. - cs := &certStore{} - if err := cs.loadCerts(); err != nil { - t.Fatalf("cs.loadCerts() failed: %v", err) + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) } tests := []struct { desc string @@ -121,9 +122,9 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) { KeyFile: "not_empty_key_file", TrustFile: "not_empty_trust_file", }, - wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1}, - wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1}, - wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}, Roots: cs.serverTrust2}, + wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1}, + wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1}, + wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2}, }, { desc: "use identity certs only", @@ -131,18 +132,18 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) { CertFile: "not_empty_cert_file", KeyFile: "not_empty_key_file", }, - wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, - wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, - wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}}, + wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, + wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, + wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}}, }, { desc: "use trust certs only", options: PEMFileProviderOptions{ TrustFile: "not_empty_trust_file", }, - wantKmStage0: certprovider.KeyMaterial{Roots: cs.serverTrust1}, - wantKmStage1: certprovider.KeyMaterial{Roots: cs.serverTrust1}, - wantKmStage2: certprovider.KeyMaterial{Roots: cs.serverTrust2}, + wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1}, + wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1}, + wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2}, }, } for _, test := range tests { @@ -155,11 +156,11 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) { readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) { switch stage.read() { case 0: - return cs.clientPeer1, nil + return cs.ClientCert1, nil case 1: return tls.Certificate{}, fmt.Errorf("error occurred while reloading") case 2: - return cs.clientPeer2, nil + return cs.ClientCert2, nil default: return tls.Certificate{}, fmt.Errorf("test stage not supported") } @@ -171,11 +172,11 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) { readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) { switch stage.read() { case 0: - return cs.serverTrust1, nil + return cs.ServerTrust1, nil case 1: return nil, fmt.Errorf("error occurred while reloading") case 2: - return cs.serverTrust2, nil + return cs.ServerTrust2, nil default: return nil, fmt.Errorf("test stage not supported") }