From eedec2c1c3fe8a87d6c966dba773c3c6bb984b12 Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Wed, 29 Apr 2020 11:00:02 -0700 Subject: [PATCH] advancedtls: add leaf cert in verify params (#3571) --- security/advancedtls/advancedtls.go | 7 +++ security/advancedtls/advancedtls_test.go | 58 ++++++++++++++++-------- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 529ecc76..db78a566 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -46,6 +46,10 @@ type VerificationFuncParams struct { // The verification chain obtained by checking peer RawCerts against the // trust certificate bundle(s), if applicable. VerifiedChains [][]*x509.Certificate + // The leaf certificate sent from peer, if choosing to verify the peer + // certificate(s) and that verification passed. This field would be nil if + // either user chose not to verify or the verification failed. + Leaf *x509.Certificate } // VerificationResults contains the information about results of @@ -313,6 +317,7 @@ func buildVerifyFunc(c *advancedTLSCreds, rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { chains := verifiedChains + var leafCert *x509.Certificate if c.vType == CertAndHostVerification || c.vType == CertVerification { // perform possible trust credential reloading and certificate check rootCAs := c.config.RootCAs @@ -361,6 +366,7 @@ func buildVerifyFunc(c *advancedTLSCreds, if err != nil { return err } + leafCert = certs[0] } // Perform custom verification check if specified. if c.verifyFunc != nil { @@ -368,6 +374,7 @@ func buildVerifyFunc(c *advancedTLSCreds, ServerName: serverName, RawCerts: rawCerts, VerifiedChains: chains, + Leaf: leafCert, }) return err } diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 8dea2e49..ab6ba590 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -23,6 +23,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "errors" "fmt" "io/ioutil" "net" @@ -43,7 +44,15 @@ func TestClientServerHandshake(t *testing.T) { getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil } - verifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { + clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { + if params.ServerName == "" { + return nil, errors.New("client side server name should have a value") + } + // "foo.bar.com" is the common name on server certificate server_cert_1.pem. + if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.com") { + return nil, errors.New("client side params parsing error") + } + return &VerificationResults{}, nil } verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) { @@ -62,6 +71,17 @@ func TestClientServerHandshake(t *testing.T) { getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil } + serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) { + if params.ServerName != "" { + return nil, errors.New("server side server name should not have a value") + } + // "foo.bar.hoo.com" is the common name on client certificate client_cert_1.pem. + if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.hoo.com") { + return nil, errors.New("server side params parsing error") + } + + return &VerificationResults{}, nil + } serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) if err != nil { @@ -111,7 +131,7 @@ func TestClientServerHandshake(t *testing.T) { // if either clientCert or clientGetCert is not set { desc: "Client has no trust cert with verifyFuncGood; server sends peer cert", - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: SkipVerification, serverCert: []tls.Certificate{serverPeerCert}, serverVType: CertAndHostVerification, @@ -152,7 +172,7 @@ func TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload root function with verifyFuncGood; server sends peer cert", clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverCert: []tls.Certificate{serverPeerCert}, serverVType: CertAndHostVerification, @@ -178,7 +198,7 @@ func TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload root function with verifyFuncGood; server sets nil", clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverVType: CertVerification, serverExpectError: true, @@ -190,7 +210,7 @@ func TestClientServerHandshake(t *testing.T) { desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS", clientCert: []tls.Certificate{clientPeerCert}, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverCert: []tls.Certificate{serverPeerCert}, @@ -208,7 +228,7 @@ func TestClientServerHandshake(t *testing.T) { desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS", clientCert: []tls.Certificate{clientPeerCert}, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true, @@ -223,7 +243,7 @@ func TestClientServerHandshake(t *testing.T) { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", clientCert: []tls.Certificate{clientPeerCert}, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverCert: []tls.Certificate{serverPeerCert}, @@ -239,7 +259,7 @@ func TestClientServerHandshake(t *testing.T) { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS", clientCert: []tls.Certificate{clientPeerCert}, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverCert: []tls.Certificate{serverPeerCert}, @@ -256,14 +276,14 @@ func TestClientServerHandshake(t *testing.T) { return &clientPeerCert, nil }, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return &serverPeerCert, nil }, serverGetRoot: getRootCAsForServer, - serverVerifyFunc: verifyFuncGood, + serverVerifyFunc: serverVerifyFunc, serverVType: CertVerification, }, // Client: set everything but with the wrong peer cert not trusted by @@ -277,14 +297,14 @@ func TestClientServerHandshake(t *testing.T) { return &serverPeerCert, nil }, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return &serverPeerCert, nil }, serverGetRoot: getRootCAsForServer, - serverVerifyFunc: verifyFuncGood, + serverVerifyFunc: serverVerifyFunc, serverVType: CertVerification, serverExpectError: true, }, @@ -298,7 +318,7 @@ func TestClientServerHandshake(t *testing.T) { return &clientPeerCert, nil }, clientGetRoot: getRootCAsForServer, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true, @@ -306,7 +326,7 @@ func TestClientServerHandshake(t *testing.T) { return &serverPeerCert, nil }, serverGetRoot: getRootCAsForServer, - serverVerifyFunc: verifyFuncGood, + serverVerifyFunc: serverVerifyFunc, serverVType: CertVerification, serverExpectError: true, }, @@ -321,14 +341,14 @@ func TestClientServerHandshake(t *testing.T) { return &clientPeerCert, nil }, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return &clientPeerCert, nil }, serverGetRoot: getRootCAsForServer, - serverVerifyFunc: verifyFuncGood, + serverVerifyFunc: serverVerifyFunc, serverVType: CertVerification, serverExpectError: true, }, @@ -342,7 +362,7 @@ func TestClientServerHandshake(t *testing.T) { return &clientPeerCert, nil }, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true, @@ -350,7 +370,7 @@ func TestClientServerHandshake(t *testing.T) { return &serverPeerCert, nil }, serverGetRoot: getRootCAsForClient, - serverVerifyFunc: verifyFuncGood, + serverVerifyFunc: serverVerifyFunc, serverVType: CertVerification, serverExpectError: true, }, @@ -362,7 +382,7 @@ func TestClientServerHandshake(t *testing.T) { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS", clientCert: []tls.Certificate{clientPeerCert}, clientGetRoot: getRootCAsForClient, - clientVerifyFunc: verifyFuncGood, + clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true,