From f313ade84ce22334b1411af48361ab7fcf1c91df Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Wed, 22 Apr 2020 13:44:18 -0700 Subject: [PATCH] advancedtls: add fine-grained verification levels in XXXOptions (#3454) --- security/advancedtls/advancedtls.go | 314 ++++++++++-------- .../advancedtls_integration_test.go | 168 +++++++--- security/advancedtls/advancedtls_test.go | 143 ++++++-- 3 files changed, 421 insertions(+), 204 deletions(-) diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index eec489c0..529ecc76 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -17,8 +17,8 @@ */ // Package advancedtls is a utility library containing functions to construct -// credentials.TransportCredentials that can perform credential reloading and custom -// server authorization. +// credentials.TransportCredentials that can perform credential reloading and +// custom verification check. package advancedtls import ( @@ -33,98 +33,155 @@ import ( "google.golang.org/grpc/credentials" ) -// VerificationFuncParams contains the parameters available to users when implementing CustomVerificationFunc. +// VerificationFuncParams contains parameters available to users when +// implementing CustomVerificationFunc. +// The fields in this struct are read-only. type VerificationFuncParams struct { - ServerName string - RawCerts [][]byte + // The target server name that the client connects to when establishing the + // connection. This field is only meaningful for client side. On server side, + // this field would be an empty string. + ServerName string + // The raw certificates sent from peer. + RawCerts [][]byte + // The verification chain obtained by checking peer RawCerts against the + // trust certificate bundle(s), if applicable. VerifiedChains [][]*x509.Certificate } -// VerificationResults contains the information about results of CustomVerificationFunc. -// VerificationResults is an empty struct for now. It may be extended in the future to include more information. +// VerificationResults contains the information about results of +// CustomVerificationFunc. +// VerificationResults is an empty struct for now. It may be extended in the +// future to include more information. type VerificationResults struct{} -// CustomVerificationFunc is the function defined by users to perform custom server authorization. -// CustomVerificationFunc returns nil if the authorization fails; otherwise returns an empty struct. +// CustomVerificationFunc is the function defined by users to perform custom +// verification check. +// CustomVerificationFunc returns nil if the authorization fails; otherwise +// returns an empty struct. type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error) -// GetRootCAsParams contains the parameters available to users when implementing GetRootCAs. +// GetRootCAsParams contains the parameters available to users when +// implementing GetRootCAs. type GetRootCAsParams struct { RawConn net.Conn RawCerts [][]byte } // GetRootCAsResults contains the results of GetRootCAs. -// If users want to reload the root trust certificate, it is required to return the proper TrustCerts in GetRootCAs. +// If users want to reload the root trust certificate, it is required to return +// the proper TrustCerts in GetRootCAs. type GetRootCAsResults struct { TrustCerts *x509.CertPool } -// RootCertificateOptions contains a field and a function for obtaining root trust certificates. -// It is used by both ClientOptions and ServerOptions. Note that RootCertificateOptions is required -// to be correctly set on client side; on server side, it is only required when mutual TLS is -// enabled(RequireClientCert in ServerOptions is true). +// RootCertificateOptions contains a field and a function for obtaining root +// trust certificates. +// It is used by both ClientOptions and ServerOptions. type RootCertificateOptions struct { - // If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts will be used - // every time when verifying the peer certificates, without performing root certificate reloading. + // If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts + // will be used every time when verifying the peer certificates, without + // performing root certificate reloading. RootCACerts *x509.CertPool - // If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked every time - // asked to check certificates sent from the server when a new connection is established. + // If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked + // every time asked to check certificates sent from the server when a new + // connection is established. // This is known as root CA certificate reloading. GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) } -// ClientOptions contains all the fields and functions needed to be filled by the client. +// VerificationType is the enum type that represents different levels of +// verification users could set, both on client side and on server side. +type VerificationType int + +const ( + // CertAndHostVerification indicates doing both certificate signature check + // and hostname check. + CertAndHostVerification VerificationType = iota + // CertVerification indicates doing certificate signature check only. Setting + // this field without proper custom verification check would leave the + // application susceptible to the MITM attack. + CertVerification + // SkipVerification indicates skipping both certificate signature check and + // hostname check. If setting this field, proper custom verification needs to + // be implemented in order to complete the authentication. Setting this field + // with a nil custom verification would raise an error. + SkipVerification +) + +// ClientOptions contains all the fields and functions needed to be filled by +// the client. // General rules for certificate setting on client side: -// Certificates or GetClientCertificate indicates the certificates sent from the client to the -// server to prove client's identities. The rules for setting these two fields are: +// Certificates or GetClientCertificate indicates the certificates sent from +// the client to the server to prove client's identities. The rules for setting +// these two fields are: // If requiring mutual authentication on server side: -// Either Certificates or GetClientCertificate must be set; the other will be ignored +// Either Certificates or GetClientCertificate must be set; the other will +// be ignored. // Otherwise: -// Nothing needed(the two fields will be ignored) +// Nothing needed(the two fields will be ignored). type ClientOptions struct { - // If field Certificates is set, field GetClientCertificate will be ignored. The client will use - // Certificates every time when asked for a certificate, without performing certificate reloading. + // If field Certificates is set, field GetClientCertificate will be ignored. + // The client will use Certificates every time when asked for a certificate, + // without performing certificate reloading. Certificates []tls.Certificate - // If GetClientCertificate is set and Certificates is nil, the client will invoke this - // function every time asked to present certificates to the server when a new connection is - // established. This is known as peer certificate reloading. + // If GetClientCertificate is set and Certificates is nil, the client will + // invoke this function every time asked to present certificates to the + // server when a new connection is established. This is known as peer + // certificate reloading. GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) - // VerifyPeer is a custom server authorization checking after certificate signature check. - // If this is set, we will replace the hostname check with this customized authorization check. - // If this is nil, we fall back to typical hostname check. + // VerifyPeer is a custom verification check after certificate signature + // check. + // If this is set, we will perform this customized check after doing the + // normal check(s) indicated by setting VType. VerifyPeer CustomVerificationFunc // ServerNameOverride is for testing only. If set to a non-empty string, - // it will override the virtual host name of authority (e.g. :authority header field) in requests. + // it will override the virtual host name of authority (e.g. :authority + // header field) in requests. ServerNameOverride string + // RootCertificateOptions is REQUIRED to be correctly set on client side. RootCertificateOptions + // VType is the verification type on the client side. + VType VerificationType } -// ServerOptions contains all the fields and functions needed to be filled by the client. +// ServerOptions contains all the fields and functions needed to be filled by +// the client. // General rules for certificate setting on server side: -// Certificates or GetClientCertificate indicates the certificates sent from the server to -// the client to prove server's identities. The rules for setting these two fields are: -// Either Certificates or GetCertificate must be set; the other will be ignored +// Certificates or GetClientCertificate indicates the certificates sent from +// the server to the client to prove server's identities. The rules for setting +// these two fields are: +// Either Certificates or GetCertificate must be set; the other will be ignored. type ServerOptions struct { - // If field Certificates is set, field GetClientCertificate will be ignored. The server will use - // Certificates every time when asked for a certificate, without performing certificate reloading. + // If field Certificates is set, field GetClientCertificate will be ignored. + // The server will use Certificates every time when asked for a certificate, + // without performing certificate reloading. Certificates []tls.Certificate - // If GetClientCertificate is set and Certificates is nil, the server will invoke this - // function every time asked to present certificates to the client when a new connection is - // established. This is known as peer certificate reloading. + // If GetClientCertificate is set and Certificates is nil, the server will + // invoke this function every time asked to present certificates to the + // client when a new connection is established. This is known as peer + // certificate reloading. GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error) + // VerifyPeer is a custom verification check after certificate signature + // check. + // If this is set, we will perform this customized check after doing the + // normal check(s) indicated by setting VType. + VerifyPeer CustomVerificationFunc + // RootCertificateOptions is only required when mutual TLS is + // enabled(RequireClientCert is true). RootCertificateOptions // If the server want the client to send certificates. RequireClientCert bool + // VType is the verification type on the server side. + VType VerificationType } func (o *ClientOptions) config() (*tls.Config, error) { - if o.RootCACerts == nil && o.GetRootCAs == nil && o.VerifyPeer == nil { + if o.VType == SkipVerification && o.VerifyPeer == nil { return nil, fmt.Errorf( - "client needs to provide root CA certs, or a custom verification function") + "client needs to provide custom verification mechanism if choose to skip default verification") } - // We have to set InsecureSkipVerify to true to skip the default checks and use the - // verification function we built from buildVerifyFunc. + // We have to set InsecureSkipVerify to true to skip the default checks and + // use the verification function we built from buildVerifyFunc. config := &tls.Config{ ServerName: o.ServerNameOverride, Certificates: o.Certificates, @@ -139,21 +196,16 @@ func (o *ServerOptions) config() (*tls.Config, error) { if o.Certificates == nil && o.GetCertificate == nil { return nil, fmt.Errorf("either Certificates or GetCertificate must be specified") } - if o.RequireClientCert && o.GetRootCAs == nil && o.RootCACerts == nil { - return nil, fmt.Errorf("server needs to provide root CA certs if requiring client cert") + if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil { + return nil, fmt.Errorf( + "server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)") } clientAuth := tls.NoClientCert if o.RequireClientCert { - // We fall back to normal config settings if users don't need to reload root certificates. - // If using RequireAndVerifyClientCert, the underlying stack would use the default - // checking and ignore the verification function we built from buildVerifyFunc. - // If using RequireAnyClientCert, the code would skip all the checks and use the - // function from buildVerifyFunc. - if o.RootCACerts != nil { - clientAuth = tls.RequireAndVerifyClientCert - } else { - clientAuth = tls.RequireAnyClientCert - } + // We have to set clientAuth to RequireAnyClientCert to force underlying + // TLS package to use the verification function we built from + // buildVerifyFunc. + clientAuth = tls.RequireAnyClientCert } config := &tls.Config{ ClientAuth: clientAuth, @@ -166,12 +218,14 @@ func (o *ServerOptions) config() (*tls.Config, error) { return config, nil } -// advancedTLSCreds is the credentials required for authenticating a connection using TLS. +// advancedTLSCreds is the credentials required for authenticating a connection +// using TLS. type advancedTLSCreds struct { config *tls.Config verifyFunc CustomVerificationFunc getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) isClient bool + vType VerificationType } func (c advancedTLSCreds) Info() credentials.ProtocolInfo { @@ -218,10 +272,7 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { cfg := cloneTLSConfig(c.config) - // We build server side verification function only when root cert reloading is needed. - if c.getRootCAs != nil { - cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn) - } + cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn) conn := tls.Server(rawConn, cfg) if err := conn.Handshake(); err != nil { conn.Close() @@ -250,91 +301,82 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error { return nil } -// The function buildVerifyFunc is used when users want root cert reloading, and possibly custom -// server authorization check. -// We have to build our own verification function here because current tls module: -// 1. does not have a good support on root cert reloading -// 2. will ignore basic certificate check when setting InsecureSkipVerify to true +// The function buildVerifyFunc is used when users want root cert reloading, +// and possibly custom verification check. +// We have to build our own verification function here because current +// tls module: +// 1. does not have a good support on root cert reloading. +// 2. will ignore basic certificate check when setting InsecureSkipVerify +// to true. func buildVerifyFunc(c *advancedTLSCreds, serverName string, rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - // If users didn't specify either rootCAs or getRootCAs on client side, - // as we see some use cases such as https://github.com/grpc/grpc/pull/20530, - // instead of failing, we just don't validate the server cert and let - // application decide via VerifyPeer - if c.isClient && c.config.RootCAs == nil && c.getRootCAs == nil { - if c.verifyFunc != nil { - _, err := c.verifyFunc(&VerificationFuncParams{ - ServerName: serverName, - RawCerts: rawCerts, - VerifiedChains: verifiedChains, + chains := verifiedChains + if c.vType == CertAndHostVerification || c.vType == CertVerification { + // perform possible trust credential reloading and certificate check + rootCAs := c.config.RootCAs + if !c.isClient { + rootCAs = c.config.ClientCAs + } + // Reload root CA certs. + if rootCAs == nil && c.getRootCAs != nil { + results, err := c.getRootCAs(&GetRootCAsParams{ + RawConn: rawConn, + RawCerts: rawCerts, }) + if err != nil { + return err + } + rootCAs = results.TrustCerts + } + // Verify peers' certificates against RootCAs and get verifiedChains. + certs := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return err + } + certs[i] = cert + } + keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + if !c.isClient { + keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + } + opts := x509.VerifyOptions{ + Roots: rootCAs, + CurrentTime: time.Now(), + Intermediates: x509.NewCertPool(), + KeyUsages: keyUsages, + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + // Perform default hostname check if specified. + if c.isClient && c.vType == CertAndHostVerification && serverName != "" { + opts.DNSName = serverName + } + var err error + chains, err = certs[0].Verify(opts) + if err != nil { return err } } - var rootCAs *x509.CertPool - if c.isClient { - rootCAs = c.config.RootCAs - } else { - rootCAs = c.config.ClientCAs - } - // reload root CA certs - if rootCAs == nil && c.getRootCAs != nil { - results, err := c.getRootCAs(&GetRootCAsParams{ - RawConn: rawConn, - RawCerts: rawCerts, + // Perform custom verification check if specified. + if c.verifyFunc != nil { + _, err := c.verifyFunc(&VerificationFuncParams{ + ServerName: serverName, + RawCerts: rawCerts, + VerifiedChains: chains, }) - if err != nil { - return err - } - rootCAs = results.TrustCerts - } - // verify peers' certificates against RootCAs and get verifiedChains - certs := make([]*x509.Certificate, len(rawCerts)) - for i, asn1Data := range rawCerts { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - return err - } - certs[i] = cert - } - opts := x509.VerifyOptions{ - Roots: rootCAs, - CurrentTime: time.Now(), - Intermediates: x509.NewCertPool(), - } - if !c.isClient { - opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} - } else { - opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} - } - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - // We use default hostname check if users don't specify verifyFunc function - if c.isClient && c.verifyFunc == nil && serverName != "" { - opts.DNSName = serverName - } - verifiedChains, err := certs[0].Verify(opts) - if err != nil { return err } - if c.isClient && c.verifyFunc != nil { - if c.verifyFunc != nil { - _, err := c.verifyFunc(&VerificationFuncParams{ - ServerName: serverName, - RawCerts: rawCerts, - VerifiedChains: verifiedChains, - }) - return err - } - } return nil } } -// NewClientCreds uses ClientOptions to construct a TransportCredentials based on TLS. +// NewClientCreds uses ClientOptions to construct a TransportCredentials based +// on TLS. func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) { conf, err := o.config() if err != nil { @@ -345,12 +387,14 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) isClient: true, getRootCAs: o.GetRootCAs, verifyFunc: o.VerifyPeer, + vType: o.VType, } tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) return tc, nil } -// NewServerCreds uses ServerOptions to construct a TransportCredentials based on TLS. +// NewServerCreds uses ServerOptions to construct a TransportCredentials based +// on TLS. func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) { conf, err := o.config() if err != nil { @@ -360,6 +404,8 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) config: conf, isClient: false, getRootCAs: o.GetRootCAs, + verifyFunc: o.VerifyPeer, + vType: o.VType, } tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) return tc, nil diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 0a7efe43..c259e114 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -39,9 +39,11 @@ var ( port = ":50051" ) -// stageInfo contains a stage number indicating the current phase of each integration test, and a mutex. -// Based on the stage number of current test, we will use different certificates and server authorization -// functions to check if our tests behave as expected. +// stageInfo contains a stage number indicating the current phase of each +// integration test, and a mutex. +// Based on the stage number of current test, we will use different +// certificates and custom verification functions to check if our tests behave +// as expected. type stageInfo struct { mutex sync.Mutex stage int @@ -67,13 +69,17 @@ func (s *stageInfo) reset() { // certStore contains all the certificates used in the integration tests. type certStore struct { - // clientPeer1 is the certificate sent by client to prove its identity. It is trusted by serverTrust1. + // clientPeer1 is the certificate sent by client to prove its identity. + // It is trusted by serverTrust1. clientPeer1 tls.Certificate - // clientPeer2 is the certificate sent by client to prove its identity. It is trusted by serverTrust2. + // clientPeer2 is the certificate sent by client to prove its identity. + // It is trusted by serverTrust2. clientPeer2 tls.Certificate - // serverPeer1 is the certificate sent by server to prove its identity. It is trusted by clientTrust1. + // serverPeer1 is the certificate sent by server to prove its identity. + // It is trusted by clientTrust1. serverPeer1 tls.Certificate - // serverPeer2 is the certificate sent by server to prove its identity. It is trusted by clientTrust2. + // serverPeer2 is the certificate sent by server to prove its identity. + // It is trusted by clientTrust2. serverPeer2 tls.Certificate clientTrust1 *x509.CertPool clientTrust2 *x509.CertPool @@ -81,7 +87,8 @@ type certStore struct { serverTrust2 *x509.CertPool } -// loadCerts function is used to load test certificates at the beginning of each integration test. +// loadCerts function is used to load test certificates at the beginning of +// each integration test. func (cs *certStore) loadCerts() error { var err error cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), @@ -144,7 +151,8 @@ func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error { func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) { var conn *grpc.ClientConn var err error - // If we want the test to fail, we establish a non-blocking connection to avoid it hangs and killed by the context. + // If we want the test to fail, we establish a non-blocking connection to + // avoid it hangs and killed by the context. if shouldFail { conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds)) if err != nil { @@ -166,10 +174,13 @@ func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds cred // The advanced TLS features are tested in different stages. // At stage 0, we establish a good connection between client and server. -// At stage 1, we change one factor(it could be we change the server's certificate, or server authorization function, etc), -// and test if the following connections would be dropped. -// At stage 2, we re-establish the connection by changing the counterpart of the factor we modified in stage 1. -// (could be change the client's trust certificate, or change server authorization function, etc) +// At stage 1, we change one factor(it could be we change the server's +// certificate, or custom verification function, etc), and test if the +// following connections would be dropped. +// At stage 2, we re-establish the connection by changing the counterpart of +// the factor we modified in stage 1. +// (could be change the client's trust certificate, or change custom +// verification function, etc) func TestEnd2End(t *testing.T) { cs := &certStore{} err := cs.loadCerts() @@ -184,17 +195,25 @@ func TestEnd2End(t *testing.T) { clientRoot *x509.CertPool clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) clientVerifyFunc CustomVerificationFunc + clientVType VerificationType serverCert []tls.Certificate serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error) serverRoot *x509.CertPool serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) + serverVerifyFunc CustomVerificationFunc + serverVType VerificationType }{ // Test Scenarios: - // At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1. - // At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2 is not trusted by serverTrust1, following rpc calls are expected - // to fail, while the previous rpc calls are still good because those are already authenticated. - // At stage 2, the server changes serverTrust1 to serverTrust2, and we should see it again accepts the connection, since clientPeer2 is trusted + // At initialization(stage = 0), client will be initialized with cert + // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. + // The mutual authentication works at the beginning, since clientPeer1 is + // trusted by serverTrust1, and serverPeer1 by clientTrust1. + // At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2 + // is not trusted by serverTrust1, following rpc calls are expected to + // fail, while the previous rpc calls are still good because those are + // already authenticated. + // At stage 2, the server changes serverTrust1 to serverTrust2, and we + // should see it again accepts the connection, since clientPeer2 is trusted // by serverTrust2. { desc: "TestClientPeerCertReloadServerTrustCertReload", @@ -212,6 +231,7 @@ func TestEnd2End(t *testing.T) { clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, + clientVType: CertVerification, serverCert: []tls.Certificate{cs.serverPeer1}, serverGetCert: nil, serverRoot: nil, @@ -223,13 +243,22 @@ func TestEnd2End(t *testing.T) { return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil } }, + serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + serverVType: CertVerification, }, // Test Scenarios: - // At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1. - // At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2 is not trusted by clientTrust1, following rpc calls are expected - // to fail, while the previous rpc calls are still good because those are already authenticated. - // At stage 2, the client changes clientTrust1 to clientTrust2, and we should see it again accepts the connection, since serverPeer2 is trusted + // At initialization(stage = 0), client will be initialized with cert + // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. + // The mutual authentication works at the beginning, since clientPeer1 is + // trusted by serverTrust1, and serverPeer1 by clientTrust1. + // At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2 + // is not trusted by clientTrust1, following rpc calls are expected to + // fail, while the previous rpc calls are still good because those are + // already authenticated. + // At stage 2, the client changes clientTrust1 to clientTrust2, and we + // should see it again accepts the connection, since serverPeer2 is trusted // by clientTrust2. { desc: "TestServerPeerCertReloadClientTrustCertReload", @@ -247,7 +276,8 @@ func TestEnd2End(t *testing.T) { clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { return &VerificationResults{}, nil }, - serverCert: nil, + clientVType: CertVerification, + serverCert: nil, serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { switch stage.read() { case 0: @@ -258,17 +288,26 @@ func TestEnd2End(t *testing.T) { }, serverRoot: cs.serverTrust1, serverGetRoot: nil, + serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + serverVType: CertVerification, }, // Test Scenarios: - // At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. - // The mutual authentication works at the beginning, since clientPeer1 trusted by serverTrust1, serverPeer1 by clientTrust1, and also the - // custom server authorization check allows the CommonName on serverPeer1. - // At stage 1, server changes serverPeer1 to serverPeer2, and client changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by - // clientTrust2, our authorization check only accepts serverPeer1, and hence the following calls should fail. Previous connections should + // At initialization(stage = 0), client will be initialized with cert + // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. + // The mutual authentication works at the beginning, since clientPeer1 + // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the + // custom verification check allows the CommonName on serverPeer1. + // At stage 1, server changes serverPeer1 to serverPeer2, and client + // changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by + // clientTrust2, our authorization check only accepts serverPeer1, and + // hence the following calls should fail. Previous connections should // not be affected. - // At stage 2, the client changes authorization check to only accept serverPeer2. Now we should see the connection becomes normal again. + // At stage 2, the client changes authorization check to only accept + // serverPeer2. Now we should see the connection becomes normal again. { - desc: "TestClientCustomServerAuthz", + desc: "TestClientCustomVerification", clientCert: []tls.Certificate{cs.clientPeer1}, clientGetCert: nil, clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { @@ -306,7 +345,8 @@ func TestEnd2End(t *testing.T) { } return nil, fmt.Errorf("custom authz check fails") }, - serverCert: nil, + clientVType: CertVerification, + serverCert: nil, serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { switch stage.read() { case 0: @@ -317,6 +357,47 @@ func TestEnd2End(t *testing.T) { }, serverRoot: cs.serverTrust1, serverGetRoot: nil, + serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + serverVType: CertVerification, + }, + // Test Scenarios: + // At initialization(stage = 0), client will be initialized with cert + // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. + // The mutual authentication works at the beginning, since clientPeer1 + // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the + // custom verification check on server side allows all connections. + // At stage 1, server disallows the the connections by setting custom + // verification check. The following calls should fail. Previous + // connections should not be affected. + // At stage 2, server allows all the connections again and the + // authentications should go back to normal. + { + desc: "TestServerCustomVerification", + clientCert: []tls.Certificate{cs.clientPeer1}, + clientGetCert: nil, + clientGetRoot: nil, + clientRoot: cs.clientTrust1, + clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + clientVType: CertVerification, + serverCert: []tls.Certificate{cs.serverPeer1}, + serverGetCert: nil, + serverRoot: cs.serverTrust1, + serverGetRoot: nil, + serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + switch stage.read() { + case 0, 2: + return &VerificationResults{}, nil + case 1: + return nil, fmt.Errorf("custom authz check fails") + default: + return nil, fmt.Errorf("custom authz check fails") + } + }, + serverVType: CertVerification, }, } { test := test @@ -330,6 +411,8 @@ func TestEnd2End(t *testing.T) { GetRootCAs: test.serverGetRoot, }, RequireClientCert: true, + VerifyPeer: test.serverVerifyFunc, + VType: test.serverVType, } serverTLSCreds, err := NewServerCreds(serverOptions) if err != nil { @@ -356,49 +439,50 @@ func TestEnd2End(t *testing.T) { RootCACerts: test.clientRoot, GetRootCAs: test.clientGetRoot, }, + VType: test.clientVType, } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { t.Fatalf("clientTLSCreds failed to create") } - // ------------------------Scenario 1----------------------------------------- + // ------------------------Scenario 1------------------------------------ // stage = 0, initial connection should succeed ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel1() conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false) - defer conn.Close() if err != nil { t.Fatal(err) } - // --------------------------------------------------------------------------- + defer conn.Close() + // ---------------------------------------------------------------------- stage.increase() - // ------------------------Scenario 2----------------------------------------- + // ------------------------Scenario 2------------------------------------ // stage = 1, previous connection should still succeed err = callAndVerify("rpc call 2", greetClient, false) if err != nil { t.Fatal(err) } - // ------------------------Scenario 3----------------------------------------- + // ------------------------Scenario 3------------------------------------ // stage = 1, new connection should fail ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel2() conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true) - defer conn2.Close() if err != nil { t.Fatal(err) } - //// --------------------------------------------------------------------------- + defer conn2.Close() + //// -------------------------------------------------------------------- stage.increase() - // ------------------------Scenario 4----------------------------------------- + // ------------------------Scenario 4------------------------------------ // stage = 2, new connection should succeed ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel3() conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false) - defer conn3.Close() if err != nil { t.Fatal(err) } - // --------------------------------------------------------------------------- + defer conn3.Close() + // ---------------------------------------------------------------------- stage.reset() }) } diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 509befa0..6c1909c4 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -77,6 +77,7 @@ func TestClientServerHandshake(t *testing.T) { clientRoot *x509.CertPool clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) clientVerifyFunc CustomVerificationFunc + clientVType VerificationType clientExpectCreateError bool clientExpectHandshakeError bool serverMutualTLS bool @@ -84,13 +85,17 @@ func TestClientServerHandshake(t *testing.T) { serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error) serverRoot *x509.CertPool serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) + serverVerifyFunc CustomVerificationFunc + serverVType VerificationType serverExpectError bool }{ // Client: nil setting // Server: only set serverCert with mutual TLS off // Expected Behavior: server side failure - // Reason: if either clientCert or clientGetClientCert is not set and - // verifyFunc is not set, we will fail directly + // Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client + // side doesn't provide any verification mechanism. We don't allow this + // even setting vType to SkipVerification. Clients should at least provide + // their own verification logic. { "Client_no_trust_cert_Server_peer_cert", nil, @@ -98,6 +103,7 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + SkipVerification, true, false, false, @@ -105,6 +111,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertAndHostVerification, true, }, // Client: nil setting except verifyFuncGood @@ -119,6 +127,7 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, verifyFuncGood, + SkipVerification, false, false, false, @@ -126,13 +135,16 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertAndHostVerification, false, }, // Client: only set clientRoot // Server: only set serverCert with mutual TLS off // Expected Behavior: server side failure and client handshake failure - // Reason: not setting advanced TLS features will fall back to normal check, and will hence fail - // on default host name check. All the default hostname checks will fail in this test suites. + // Reason: client side sets vType to CertAndHostVerification, and will do + // default hostname check. All the default hostname checks will fail in + // this test suites. { "Client_root_cert_Server_peer_cert", nil, @@ -140,6 +152,7 @@ func TestClientServerHandshake(t *testing.T) { clientTrustPool, nil, nil, + CertAndHostVerification, false, true, false, @@ -147,13 +160,16 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertAndHostVerification, true, }, // Client: only set clientGetRoot // Server: only set serverCert with mutual TLS off // Expected Behavior: server side failure and client handshake failure - // Reason: setting root reloading function without custom verifyFunc will also fail, - // since it will also fall back to default host name check + // Reason: client side sets vType to CertAndHostVerification, and will do + // default hostname check. All the default hostname checks will fail in + // this test suites. { "Client_reload_root_Server_peer_cert", nil, @@ -161,6 +177,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, nil, + CertAndHostVerification, false, true, false, @@ -168,6 +185,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertAndHostVerification, true, }, // Client: set clientGetRoot and clientVerifyFunc @@ -180,6 +199,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, false, @@ -187,6 +207,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertAndHostVerification, false, }, // Client: set clientGetRoot and bad clientVerifyFunc function @@ -200,6 +222,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncBad, + CertVerification, false, true, false, @@ -207,6 +230,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, + nil, + CertVerification, true, }, // Client: set clientGetRoot and clientVerifyFunc @@ -220,6 +245,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, false, @@ -227,26 +253,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, nil, - true, - }, - // Client: set clientGetRoot and clientVerifyFunc - // Server: only set serverCert with mutual TLS on - // Expected Behavior: server side failure - // Reason: server side must either set serverRoot or serverGetRoot when using mutual TLS - { - "Client_reload_root_verifyFuncGood_Server_peer_cert_no_root_cert_mutualTLS", - nil, - nil, - nil, - getRootCAsForClient, - verifyFuncGood, - false, - false, - true, - []tls.Certificate{serverPeerCert}, - nil, - nil, nil, + CertVerification, true, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert @@ -259,6 +267,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -266,9 +275,37 @@ func TestClientServerHandshake(t *testing.T) { nil, serverTrustPool, nil, + nil, + CertVerification, false, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert + // Server: set serverCert, but not setting any of serverRoot, serverGetRoot + // or serverVerifyFunc, with mutual TLS on + // Expected Behavior: server side failure + // Reason: server side needs to provide any verification mechanism when + // mTLS in on, even setting vType to SkipVerification. Servers should at + // least provide their own verification logic. + { + "Client_peer_cert_reload_root_verifyFuncGood_Server_no_verification_mutualTLS", + []tls.Certificate{clientPeerCert}, + nil, + nil, + getRootCAsForClient, + verifyFuncGood, + CertVerification, + false, + true, + true, + []tls.Certificate{serverPeerCert}, + nil, + nil, + nil, + nil, + SkipVerification, + true, + }, + // Client: set clientGetRoot, clientVerifyFunc and clientCert // Server: set serverGetRoot and serverCert with mutual TLS on // Expected Behavior: success { @@ -278,6 +315,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -285,10 +323,13 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, getRootCAsForServer, + nil, + CertVerification, false, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert - // Server: set serverGetRoot returning error and serverCert with mutual TLS on + // Server: set serverGetRoot returning error and serverCert with mutual + // TLS on // Expected Behavior: server side failure // Reason: server side reloading returns failure { @@ -298,6 +339,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -305,6 +347,8 @@ func TestClientServerHandshake(t *testing.T) { nil, nil, getRootCAsForServerBad, + nil, + CertVerification, true, }, // Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert @@ -319,6 +363,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -328,9 +373,12 @@ func TestClientServerHandshake(t *testing.T) { }, nil, getRootCAsForServer, + verifyFuncGood, + CertVerification, false, }, - // Client: set everything but with the wrong peer cert not trusted by server + // Client: set everything but with the wrong peer cert not trusted by + // server // Server: set serverGetRoot and serverGetCert with mutual TLS on // Expected Behavior: server side returns failure because of // certificate mismatch @@ -343,6 +391,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -352,6 +401,8 @@ func TestClientServerHandshake(t *testing.T) { }, nil, getRootCAsForServer, + verifyFuncGood, + CertVerification, true, }, // Client: set everything but with the wrong trust cert not trusting server @@ -367,6 +418,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForServer, verifyFuncGood, + CertVerification, false, true, true, @@ -376,10 +428,13 @@ func TestClientServerHandshake(t *testing.T) { }, nil, getRootCAsForServer, + verifyFuncGood, + CertVerification, true, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert - // Server: set everything but with the wrong peer cert not trusted by client + // Server: set everything but with the wrong peer cert not trusted by + // client // Expected Behavior: server side and client side return failure due to // certificate mismatch and handshake failure { @@ -391,6 +446,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, false, true, @@ -400,6 +456,8 @@ func TestClientServerHandshake(t *testing.T) { }, nil, getRootCAsForServer, + verifyFuncGood, + CertVerification, true, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert @@ -415,6 +473,7 @@ func TestClientServerHandshake(t *testing.T) { nil, getRootCAsForClient, verifyFuncGood, + CertVerification, false, true, true, @@ -424,6 +483,31 @@ func TestClientServerHandshake(t *testing.T) { }, nil, getRootCAsForClient, + verifyFuncGood, + CertVerification, + true, + }, + // Client: set clientGetRoot, clientVerifyFunc and clientCert + // Server: set serverGetRoot and serverCert, but with bad verifyFunc + // Expected Behavior: server side and client side return failure due to + // server custom check fails + { + "Client_peer_cert_reload_root_verifyFuncGood_Server_bad_custom_verification_mutualTLS", + []tls.Certificate{clientPeerCert}, + nil, + nil, + getRootCAsForClient, + verifyFuncGood, + CertVerification, + false, + true, + true, + []tls.Certificate{serverPeerCert}, + nil, + nil, + getRootCAsForServer, + verifyFuncBad, + CertVerification, true, }, } { @@ -443,6 +527,8 @@ func TestClientServerHandshake(t *testing.T) { GetRootCAs: test.serverGetRoot, }, RequireClientCert: test.serverMutualTLS, + VerifyPeer: test.serverVerifyFunc, + VType: test.serverVType, } go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { serverRawConn, err := lis.Accept() @@ -480,6 +566,7 @@ func TestClientServerHandshake(t *testing.T) { RootCACerts: test.clientRoot, GetRootCAs: test.clientGetRoot, }, + VType: test.clientVType, } clientTLS, newClientErr := NewClientCreds(clientOptions) if newClientErr != nil && test.clientExpectCreateError {