advancedtls: add fine-grained verification levels in XXXOptions (#3454)

This commit is contained in:
ZhenLian
2020-04-22 13:44:18 -07:00
committed by GitHub
parent 843b06d549
commit f313ade84c
3 changed files with 421 additions and 204 deletions

View File

@ -17,8 +17,8 @@
*/ */
// Package advancedtls is a utility library containing functions to construct // Package advancedtls is a utility library containing functions to construct
// credentials.TransportCredentials that can perform credential reloading and custom // credentials.TransportCredentials that can perform credential reloading and
// server authorization. // custom verification check.
package advancedtls package advancedtls
import ( import (
@ -33,98 +33,155 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// VerificationFuncParams contains the parameters available to users when implementing CustomVerificationFunc. // VerificationFuncParams contains parameters available to users when
// implementing CustomVerificationFunc.
// The fields in this struct are read-only.
type VerificationFuncParams struct { type VerificationFuncParams struct {
ServerName string // The target server name that the client connects to when establishing the
RawCerts [][]byte // connection. This field is only meaningful for client side. On server side,
// this field would be an empty string.
ServerName string
// The raw certificates sent from peer.
RawCerts [][]byte
// The verification chain obtained by checking peer RawCerts against the
// trust certificate bundle(s), if applicable.
VerifiedChains [][]*x509.Certificate VerifiedChains [][]*x509.Certificate
} }
// VerificationResults contains the information about results of CustomVerificationFunc. // VerificationResults contains the information about results of
// VerificationResults is an empty struct for now. It may be extended in the future to include more information. // CustomVerificationFunc.
// VerificationResults is an empty struct for now. It may be extended in the
// future to include more information.
type VerificationResults struct{} type VerificationResults struct{}
// CustomVerificationFunc is the function defined by users to perform custom server authorization. // CustomVerificationFunc is the function defined by users to perform custom
// CustomVerificationFunc returns nil if the authorization fails; otherwise returns an empty struct. // verification check.
// CustomVerificationFunc returns nil if the authorization fails; otherwise
// returns an empty struct.
type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error) type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error)
// GetRootCAsParams contains the parameters available to users when implementing GetRootCAs. // GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
type GetRootCAsParams struct { type GetRootCAsParams struct {
RawConn net.Conn RawConn net.Conn
RawCerts [][]byte RawCerts [][]byte
} }
// GetRootCAsResults contains the results of GetRootCAs. // GetRootCAsResults contains the results of GetRootCAs.
// If users want to reload the root trust certificate, it is required to return the proper TrustCerts in GetRootCAs. // If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
type GetRootCAsResults struct { type GetRootCAsResults struct {
TrustCerts *x509.CertPool TrustCerts *x509.CertPool
} }
// RootCertificateOptions contains a field and a function for obtaining root trust certificates. // RootCertificateOptions contains a field and a function for obtaining root
// It is used by both ClientOptions and ServerOptions. Note that RootCertificateOptions is required // trust certificates.
// to be correctly set on client side; on server side, it is only required when mutual TLS is // It is used by both ClientOptions and ServerOptions.
// enabled(RequireClientCert in ServerOptions is true).
type RootCertificateOptions struct { type RootCertificateOptions struct {
// If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts will be used // If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts
// every time when verifying the peer certificates, without performing root certificate reloading. // will be used every time when verifying the peer certificates, without
// performing root certificate reloading.
RootCACerts *x509.CertPool RootCACerts *x509.CertPool
// If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked every time // If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked
// asked to check certificates sent from the server when a new connection is established. // every time asked to check certificates sent from the server when a new
// connection is established.
// This is known as root CA certificate reloading. // This is known as root CA certificate reloading.
GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
} }
// ClientOptions contains all the fields and functions needed to be filled by the client. // VerificationType is the enum type that represents different levels of
// verification users could set, both on client side and on server side.
type VerificationType int
const (
// CertAndHostVerification indicates doing both certificate signature check
// and hostname check.
CertAndHostVerification VerificationType = iota
// CertVerification indicates doing certificate signature check only. Setting
// this field without proper custom verification check would leave the
// application susceptible to the MITM attack.
CertVerification
// SkipVerification indicates skipping both certificate signature check and
// hostname check. If setting this field, proper custom verification needs to
// be implemented in order to complete the authentication. Setting this field
// with a nil custom verification would raise an error.
SkipVerification
)
// ClientOptions contains all the fields and functions needed to be filled by
// the client.
// General rules for certificate setting on client side: // General rules for certificate setting on client side:
// Certificates or GetClientCertificate indicates the certificates sent from the client to the // Certificates or GetClientCertificate indicates the certificates sent from
// server to prove client's identities. The rules for setting these two fields are: // the client to the server to prove client's identities. The rules for setting
// these two fields are:
// If requiring mutual authentication on server side: // If requiring mutual authentication on server side:
// Either Certificates or GetClientCertificate must be set; the other will be ignored // Either Certificates or GetClientCertificate must be set; the other will
// be ignored.
// Otherwise: // Otherwise:
// Nothing needed(the two fields will be ignored) // Nothing needed(the two fields will be ignored).
type ClientOptions struct { type ClientOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored. The client will use // If field Certificates is set, field GetClientCertificate will be ignored.
// Certificates every time when asked for a certificate, without performing certificate reloading. // The client will use Certificates every time when asked for a certificate,
// without performing certificate reloading.
Certificates []tls.Certificate Certificates []tls.Certificate
// If GetClientCertificate is set and Certificates is nil, the client will invoke this // If GetClientCertificate is set and Certificates is nil, the client will
// function every time asked to present certificates to the server when a new connection is // invoke this function every time asked to present certificates to the
// established. This is known as peer certificate reloading. // server when a new connection is established. This is known as peer
// certificate reloading.
GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
// VerifyPeer is a custom server authorization checking after certificate signature check. // VerifyPeer is a custom verification check after certificate signature
// If this is set, we will replace the hostname check with this customized authorization check. // check.
// If this is nil, we fall back to typical hostname check. // If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VType.
VerifyPeer CustomVerificationFunc VerifyPeer CustomVerificationFunc
// ServerNameOverride is for testing only. If set to a non-empty string, // ServerNameOverride is for testing only. If set to a non-empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests. // it will override the virtual host name of authority (e.g. :authority
// header field) in requests.
ServerNameOverride string ServerNameOverride string
// RootCertificateOptions is REQUIRED to be correctly set on client side.
RootCertificateOptions RootCertificateOptions
// VType is the verification type on the client side.
VType VerificationType
} }
// ServerOptions contains all the fields and functions needed to be filled by the client. // ServerOptions contains all the fields and functions needed to be filled by
// the client.
// General rules for certificate setting on server side: // General rules for certificate setting on server side:
// Certificates or GetClientCertificate indicates the certificates sent from the server to // Certificates or GetClientCertificate indicates the certificates sent from
// the client to prove server's identities. The rules for setting these two fields are: // the server to the client to prove server's identities. The rules for setting
// Either Certificates or GetCertificate must be set; the other will be ignored // these two fields are:
// Either Certificates or GetCertificate must be set; the other will be ignored.
type ServerOptions struct { type ServerOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored. The server will use // If field Certificates is set, field GetClientCertificate will be ignored.
// Certificates every time when asked for a certificate, without performing certificate reloading. // The server will use Certificates every time when asked for a certificate,
// without performing certificate reloading.
Certificates []tls.Certificate Certificates []tls.Certificate
// If GetClientCertificate is set and Certificates is nil, the server will invoke this // If GetClientCertificate is set and Certificates is nil, the server will
// function every time asked to present certificates to the client when a new connection is // invoke this function every time asked to present certificates to the
// established. This is known as peer certificate reloading. // client when a new connection is established. This is known as peer
// certificate reloading.
GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error) GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VType.
VerifyPeer CustomVerificationFunc
// RootCertificateOptions is only required when mutual TLS is
// enabled(RequireClientCert is true).
RootCertificateOptions RootCertificateOptions
// If the server want the client to send certificates. // If the server want the client to send certificates.
RequireClientCert bool RequireClientCert bool
// VType is the verification type on the server side.
VType VerificationType
} }
func (o *ClientOptions) config() (*tls.Config, error) { func (o *ClientOptions) config() (*tls.Config, error) {
if o.RootCACerts == nil && o.GetRootCAs == nil && o.VerifyPeer == nil { if o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"client needs to provide root CA certs, or a custom verification function") "client needs to provide custom verification mechanism if choose to skip default verification")
} }
// We have to set InsecureSkipVerify to true to skip the default checks and use the // We have to set InsecureSkipVerify to true to skip the default checks and
// verification function we built from buildVerifyFunc. // use the verification function we built from buildVerifyFunc.
config := &tls.Config{ config := &tls.Config{
ServerName: o.ServerNameOverride, ServerName: o.ServerNameOverride,
Certificates: o.Certificates, Certificates: o.Certificates,
@ -139,21 +196,16 @@ func (o *ServerOptions) config() (*tls.Config, error) {
if o.Certificates == nil && o.GetCertificate == nil { if o.Certificates == nil && o.GetCertificate == nil {
return nil, fmt.Errorf("either Certificates or GetCertificate must be specified") return nil, fmt.Errorf("either Certificates or GetCertificate must be specified")
} }
if o.RequireClientCert && o.GetRootCAs == nil && o.RootCACerts == nil { if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf("server needs to provide root CA certs if requiring client cert") return nil, fmt.Errorf(
"server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
} }
clientAuth := tls.NoClientCert clientAuth := tls.NoClientCert
if o.RequireClientCert { if o.RequireClientCert {
// We fall back to normal config settings if users don't need to reload root certificates. // We have to set clientAuth to RequireAnyClientCert to force underlying
// If using RequireAndVerifyClientCert, the underlying stack would use the default // TLS package to use the verification function we built from
// checking and ignore the verification function we built from buildVerifyFunc. // buildVerifyFunc.
// If using RequireAnyClientCert, the code would skip all the checks and use the clientAuth = tls.RequireAnyClientCert
// function from buildVerifyFunc.
if o.RootCACerts != nil {
clientAuth = tls.RequireAndVerifyClientCert
} else {
clientAuth = tls.RequireAnyClientCert
}
} }
config := &tls.Config{ config := &tls.Config{
ClientAuth: clientAuth, ClientAuth: clientAuth,
@ -166,12 +218,14 @@ func (o *ServerOptions) config() (*tls.Config, error) {
return config, nil return config, nil
} }
// advancedTLSCreds is the credentials required for authenticating a connection using TLS. // advancedTLSCreds is the credentials required for authenticating a connection
// using TLS.
type advancedTLSCreds struct { type advancedTLSCreds struct {
config *tls.Config config *tls.Config
verifyFunc CustomVerificationFunc verifyFunc CustomVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool isClient bool
vType VerificationType
} }
func (c advancedTLSCreds) Info() credentials.ProtocolInfo { func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
@ -218,10 +272,7 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
cfg := cloneTLSConfig(c.config) cfg := cloneTLSConfig(c.config)
// We build server side verification function only when root cert reloading is needed. cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
if c.getRootCAs != nil {
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
}
conn := tls.Server(rawConn, cfg) conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
conn.Close() conn.Close()
@ -250,91 +301,82 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
return nil return nil
} }
// The function buildVerifyFunc is used when users want root cert reloading, and possibly custom // The function buildVerifyFunc is used when users want root cert reloading,
// server authorization check. // and possibly custom verification check.
// We have to build our own verification function here because current tls module: // We have to build our own verification function here because current
// 1. does not have a good support on root cert reloading // tls module:
// 2. will ignore basic certificate check when setting InsecureSkipVerify to true // 1. does not have a good support on root cert reloading.
// 2. will ignore basic certificate check when setting InsecureSkipVerify
// to true.
func buildVerifyFunc(c *advancedTLSCreds, func buildVerifyFunc(c *advancedTLSCreds,
serverName string, serverName string,
rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// If users didn't specify either rootCAs or getRootCAs on client side, chains := verifiedChains
// as we see some use cases such as https://github.com/grpc/grpc/pull/20530, if c.vType == CertAndHostVerification || c.vType == CertVerification {
// instead of failing, we just don't validate the server cert and let // perform possible trust credential reloading and certificate check
// application decide via VerifyPeer rootCAs := c.config.RootCAs
if c.isClient && c.config.RootCAs == nil && c.getRootCAs == nil { if !c.isClient {
if c.verifyFunc != nil { rootCAs = c.config.ClientCAs
_, err := c.verifyFunc(&VerificationFuncParams{ }
ServerName: serverName, // Reload root CA certs.
RawCerts: rawCerts, if rootCAs == nil && c.getRootCAs != nil {
VerifiedChains: verifiedChains, results, err := c.getRootCAs(&GetRootCAsParams{
RawConn: rawConn,
RawCerts: rawCerts,
}) })
if err != nil {
return err
}
rootCAs = results.TrustCerts
}
// Verify peers' certificates against RootCAs and get verifiedChains.
certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
certs[i] = cert
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
if !c.isClient {
keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
// Perform default hostname check if specified.
if c.isClient && c.vType == CertAndHostVerification && serverName != "" {
opts.DNSName = serverName
}
var err error
chains, err = certs[0].Verify(opts)
if err != nil {
return err return err
} }
} }
var rootCAs *x509.CertPool // Perform custom verification check if specified.
if c.isClient { if c.verifyFunc != nil {
rootCAs = c.config.RootCAs _, err := c.verifyFunc(&VerificationFuncParams{
} else { ServerName: serverName,
rootCAs = c.config.ClientCAs RawCerts: rawCerts,
} VerifiedChains: chains,
// reload root CA certs
if rootCAs == nil && c.getRootCAs != nil {
results, err := c.getRootCAs(&GetRootCAsParams{
RawConn: rawConn,
RawCerts: rawCerts,
}) })
if err != nil {
return err
}
rootCAs = results.TrustCerts
}
// verify peers' certificates against RootCAs and get verifiedChains
certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
certs[i] = cert
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
}
if !c.isClient {
opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
} else {
opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
// We use default hostname check if users don't specify verifyFunc function
if c.isClient && c.verifyFunc == nil && serverName != "" {
opts.DNSName = serverName
}
verifiedChains, err := certs[0].Verify(opts)
if err != nil {
return err return err
} }
if c.isClient && c.verifyFunc != nil {
if c.verifyFunc != nil {
_, err := c.verifyFunc(&VerificationFuncParams{
ServerName: serverName,
RawCerts: rawCerts,
VerifiedChains: verifiedChains,
})
return err
}
}
return nil return nil
} }
} }
// NewClientCreds uses ClientOptions to construct a TransportCredentials based on TLS. // NewClientCreds uses ClientOptions to construct a TransportCredentials based
// on TLS.
func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) { func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) {
conf, err := o.config() conf, err := o.config()
if err != nil { if err != nil {
@ -345,12 +387,14 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
isClient: true, isClient: true,
getRootCAs: o.GetRootCAs, getRootCAs: o.GetRootCAs,
verifyFunc: o.VerifyPeer, verifyFunc: o.VerifyPeer,
vType: o.VType,
} }
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
return tc, nil return tc, nil
} }
// NewServerCreds uses ServerOptions to construct a TransportCredentials based on TLS. // NewServerCreds uses ServerOptions to construct a TransportCredentials based
// on TLS.
func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) { func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) {
conf, err := o.config() conf, err := o.config()
if err != nil { if err != nil {
@ -360,6 +404,8 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
config: conf, config: conf,
isClient: false, isClient: false,
getRootCAs: o.GetRootCAs, getRootCAs: o.GetRootCAs,
verifyFunc: o.VerifyPeer,
vType: o.VType,
} }
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
return tc, nil return tc, nil

View File

@ -39,9 +39,11 @@ var (
port = ":50051" port = ":50051"
) )
// stageInfo contains a stage number indicating the current phase of each integration test, and a mutex. // stageInfo contains a stage number indicating the current phase of each
// Based on the stage number of current test, we will use different certificates and server authorization // integration test, and a mutex.
// functions to check if our tests behave as expected. // Based on the stage number of current test, we will use different
// certificates and custom verification functions to check if our tests behave
// as expected.
type stageInfo struct { type stageInfo struct {
mutex sync.Mutex mutex sync.Mutex
stage int stage int
@ -67,13 +69,17 @@ func (s *stageInfo) reset() {
// certStore contains all the certificates used in the integration tests. // certStore contains all the certificates used in the integration tests.
type certStore struct { type certStore struct {
// clientPeer1 is the certificate sent by client to prove its identity. It is trusted by serverTrust1. // clientPeer1 is the certificate sent by client to prove its identity.
// It is trusted by serverTrust1.
clientPeer1 tls.Certificate clientPeer1 tls.Certificate
// clientPeer2 is the certificate sent by client to prove its identity. It is trusted by serverTrust2. // clientPeer2 is the certificate sent by client to prove its identity.
// It is trusted by serverTrust2.
clientPeer2 tls.Certificate clientPeer2 tls.Certificate
// serverPeer1 is the certificate sent by server to prove its identity. It is trusted by clientTrust1. // serverPeer1 is the certificate sent by server to prove its identity.
// It is trusted by clientTrust1.
serverPeer1 tls.Certificate serverPeer1 tls.Certificate
// serverPeer2 is the certificate sent by server to prove its identity. It is trusted by clientTrust2. // serverPeer2 is the certificate sent by server to prove its identity.
// It is trusted by clientTrust2.
serverPeer2 tls.Certificate serverPeer2 tls.Certificate
clientTrust1 *x509.CertPool clientTrust1 *x509.CertPool
clientTrust2 *x509.CertPool clientTrust2 *x509.CertPool
@ -81,7 +87,8 @@ type certStore struct {
serverTrust2 *x509.CertPool serverTrust2 *x509.CertPool
} }
// loadCerts function is used to load test certificates at the beginning of each integration test. // loadCerts function is used to load test certificates at the beginning of
// each integration test.
func (cs *certStore) loadCerts() error { func (cs *certStore) loadCerts() error {
var err error var err error
cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
@ -144,7 +151,8 @@ func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) { func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) {
var conn *grpc.ClientConn var conn *grpc.ClientConn
var err error var err error
// If we want the test to fail, we establish a non-blocking connection to avoid it hangs and killed by the context. // If we want the test to fail, we establish a non-blocking connection to
// avoid it hangs and killed by the context.
if shouldFail { if shouldFail {
conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds)) conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
@ -166,10 +174,13 @@ func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds cred
// The advanced TLS features are tested in different stages. // The advanced TLS features are tested in different stages.
// At stage 0, we establish a good connection between client and server. // At stage 0, we establish a good connection between client and server.
// At stage 1, we change one factor(it could be we change the server's certificate, or server authorization function, etc), // At stage 1, we change one factor(it could be we change the server's
// and test if the following connections would be dropped. // certificate, or custom verification function, etc), and test if the
// At stage 2, we re-establish the connection by changing the counterpart of the factor we modified in stage 1. // following connections would be dropped.
// (could be change the client's trust certificate, or change server authorization function, etc) // At stage 2, we re-establish the connection by changing the counterpart of
// the factor we modified in stage 1.
// (could be change the client's trust certificate, or change custom
// verification function, etc)
func TestEnd2End(t *testing.T) { func TestEnd2End(t *testing.T) {
cs := &certStore{} cs := &certStore{}
err := cs.loadCerts() err := cs.loadCerts()
@ -184,17 +195,25 @@ func TestEnd2End(t *testing.T) {
clientRoot *x509.CertPool clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc clientVerifyFunc CustomVerificationFunc
clientVType VerificationType
serverCert []tls.Certificate serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error) serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
serverRoot *x509.CertPool serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVType VerificationType
}{ }{
// Test Scenarios: // Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. // At initialization(stage = 0), client will be initialized with cert
// The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1. // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
// At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2 is not trusted by serverTrust1, following rpc calls are expected // The mutual authentication works at the beginning, since clientPeer1 is
// to fail, while the previous rpc calls are still good because those are already authenticated. // trusted by serverTrust1, and serverPeer1 by clientTrust1.
// At stage 2, the server changes serverTrust1 to serverTrust2, and we should see it again accepts the connection, since clientPeer2 is trusted // At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2
// is not trusted by serverTrust1, following rpc calls are expected to
// fail, while the previous rpc calls are still good because those are
// already authenticated.
// At stage 2, the server changes serverTrust1 to serverTrust2, and we
// should see it again accepts the connection, since clientPeer2 is trusted
// by serverTrust2. // by serverTrust2.
{ {
desc: "TestClientPeerCertReloadServerTrustCertReload", desc: "TestClientPeerCertReloadServerTrustCertReload",
@ -212,6 +231,7 @@ func TestEnd2End(t *testing.T) {
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil return &VerificationResults{}, nil
}, },
clientVType: CertVerification,
serverCert: []tls.Certificate{cs.serverPeer1}, serverCert: []tls.Certificate{cs.serverPeer1},
serverGetCert: nil, serverGetCert: nil,
serverRoot: nil, serverRoot: nil,
@ -223,13 +243,22 @@ func TestEnd2End(t *testing.T) {
return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil
} }
}, },
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
}, },
// Test Scenarios: // Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. // At initialization(stage = 0), client will be initialized with cert
// The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1. // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
// At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2 is not trusted by clientTrust1, following rpc calls are expected // The mutual authentication works at the beginning, since clientPeer1 is
// to fail, while the previous rpc calls are still good because those are already authenticated. // trusted by serverTrust1, and serverPeer1 by clientTrust1.
// At stage 2, the client changes clientTrust1 to clientTrust2, and we should see it again accepts the connection, since serverPeer2 is trusted // At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2
// is not trusted by clientTrust1, following rpc calls are expected to
// fail, while the previous rpc calls are still good because those are
// already authenticated.
// At stage 2, the client changes clientTrust1 to clientTrust2, and we
// should see it again accepts the connection, since serverPeer2 is trusted
// by clientTrust2. // by clientTrust2.
{ {
desc: "TestServerPeerCertReloadClientTrustCertReload", desc: "TestServerPeerCertReloadClientTrustCertReload",
@ -247,7 +276,8 @@ func TestEnd2End(t *testing.T) {
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil return &VerificationResults{}, nil
}, },
serverCert: nil, clientVType: CertVerification,
serverCert: nil,
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
switch stage.read() { switch stage.read() {
case 0: case 0:
@ -258,17 +288,26 @@ func TestEnd2End(t *testing.T) {
}, },
serverRoot: cs.serverTrust1, serverRoot: cs.serverTrust1,
serverGetRoot: nil, serverGetRoot: nil,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
}, },
// Test Scenarios: // Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1. // At initialization(stage = 0), client will be initialized with cert
// The mutual authentication works at the beginning, since clientPeer1 trusted by serverTrust1, serverPeer1 by clientTrust1, and also the // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
// custom server authorization check allows the CommonName on serverPeer1. // The mutual authentication works at the beginning, since clientPeer1
// At stage 1, server changes serverPeer1 to serverPeer2, and client changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
// clientTrust2, our authorization check only accepts serverPeer1, and hence the following calls should fail. Previous connections should // custom verification check allows the CommonName on serverPeer1.
// At stage 1, server changes serverPeer1 to serverPeer2, and client
// changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by
// clientTrust2, our authorization check only accepts serverPeer1, and
// hence the following calls should fail. Previous connections should
// not be affected. // not be affected.
// At stage 2, the client changes authorization check to only accept serverPeer2. Now we should see the connection becomes normal again. // At stage 2, the client changes authorization check to only accept
// serverPeer2. Now we should see the connection becomes normal again.
{ {
desc: "TestClientCustomServerAuthz", desc: "TestClientCustomVerification",
clientCert: []tls.Certificate{cs.clientPeer1}, clientCert: []tls.Certificate{cs.clientPeer1},
clientGetCert: nil, clientGetCert: nil,
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
@ -306,7 +345,8 @@ func TestEnd2End(t *testing.T) {
} }
return nil, fmt.Errorf("custom authz check fails") return nil, fmt.Errorf("custom authz check fails")
}, },
serverCert: nil, clientVType: CertVerification,
serverCert: nil,
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
switch stage.read() { switch stage.read() {
case 0: case 0:
@ -317,6 +357,47 @@ func TestEnd2End(t *testing.T) {
}, },
serverRoot: cs.serverTrust1, serverRoot: cs.serverTrust1,
serverGetRoot: nil, serverGetRoot: nil,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
// clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
// The mutual authentication works at the beginning, since clientPeer1
// trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
// custom verification check on server side allows all connections.
// At stage 1, server disallows the the connections by setting custom
// verification check. The following calls should fail. Previous
// connections should not be affected.
// At stage 2, server allows all the connections again and the
// authentications should go back to normal.
{
desc: "TestServerCustomVerification",
clientCert: []tls.Certificate{cs.clientPeer1},
clientGetCert: nil,
clientGetRoot: nil,
clientRoot: cs.clientTrust1,
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
serverCert: []tls.Certificate{cs.serverPeer1},
serverGetCert: nil,
serverRoot: cs.serverTrust1,
serverGetRoot: nil,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
switch stage.read() {
case 0, 2:
return &VerificationResults{}, nil
case 1:
return nil, fmt.Errorf("custom authz check fails")
default:
return nil, fmt.Errorf("custom authz check fails")
}
},
serverVType: CertVerification,
}, },
} { } {
test := test test := test
@ -330,6 +411,8 @@ func TestEnd2End(t *testing.T) {
GetRootCAs: test.serverGetRoot, GetRootCAs: test.serverGetRoot,
}, },
RequireClientCert: true, RequireClientCert: true,
VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType,
} }
serverTLSCreds, err := NewServerCreds(serverOptions) serverTLSCreds, err := NewServerCreds(serverOptions)
if err != nil { if err != nil {
@ -356,49 +439,50 @@ func TestEnd2End(t *testing.T) {
RootCACerts: test.clientRoot, RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot, GetRootCAs: test.clientGetRoot,
}, },
VType: test.clientVType,
} }
clientTLSCreds, err := NewClientCreds(clientOptions) clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil { if err != nil {
t.Fatalf("clientTLSCreds failed to create") t.Fatalf("clientTLSCreds failed to create")
} }
// ------------------------Scenario 1----------------------------------------- // ------------------------Scenario 1------------------------------------
// stage = 0, initial connection should succeed // stage = 0, initial connection should succeed
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1() defer cancel1()
conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false) conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false)
defer conn.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// --------------------------------------------------------------------------- defer conn.Close()
// ----------------------------------------------------------------------
stage.increase() stage.increase()
// ------------------------Scenario 2----------------------------------------- // ------------------------Scenario 2------------------------------------
// stage = 1, previous connection should still succeed // stage = 1, previous connection should still succeed
err = callAndVerify("rpc call 2", greetClient, false) err = callAndVerify("rpc call 2", greetClient, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// ------------------------Scenario 3----------------------------------------- // ------------------------Scenario 3------------------------------------
// stage = 1, new connection should fail // stage = 1, new connection should fail
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2() defer cancel2()
conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true) conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true)
defer conn2.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
//// --------------------------------------------------------------------------- defer conn2.Close()
//// --------------------------------------------------------------------
stage.increase() stage.increase()
// ------------------------Scenario 4----------------------------------------- // ------------------------Scenario 4------------------------------------
// stage = 2, new connection should succeed // stage = 2, new connection should succeed
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel3() defer cancel3()
conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false) conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false)
defer conn3.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// --------------------------------------------------------------------------- defer conn3.Close()
// ----------------------------------------------------------------------
stage.reset() stage.reset()
}) })
} }

View File

@ -77,6 +77,7 @@ func TestClientServerHandshake(t *testing.T) {
clientRoot *x509.CertPool clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc clientVerifyFunc CustomVerificationFunc
clientVType VerificationType
clientExpectCreateError bool clientExpectCreateError bool
clientExpectHandshakeError bool clientExpectHandshakeError bool
serverMutualTLS bool serverMutualTLS bool
@ -84,13 +85,17 @@ func TestClientServerHandshake(t *testing.T) {
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error) serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
serverRoot *x509.CertPool serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVType VerificationType
serverExpectError bool serverExpectError bool
}{ }{
// Client: nil setting // Client: nil setting
// Server: only set serverCert with mutual TLS off // Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure // Expected Behavior: server side failure
// Reason: if either clientCert or clientGetClientCert is not set and // Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client
// verifyFunc is not set, we will fail directly // side doesn't provide any verification mechanism. We don't allow this
// even setting vType to SkipVerification. Clients should at least provide
// their own verification logic.
{ {
"Client_no_trust_cert_Server_peer_cert", "Client_no_trust_cert_Server_peer_cert",
nil, nil,
@ -98,6 +103,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
SkipVerification,
true, true,
false, false,
false, false,
@ -105,6 +111,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertAndHostVerification,
true, true,
}, },
// Client: nil setting except verifyFuncGood // Client: nil setting except verifyFuncGood
@ -119,6 +127,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
verifyFuncGood, verifyFuncGood,
SkipVerification,
false, false,
false, false,
false, false,
@ -126,13 +135,16 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertAndHostVerification,
false, false,
}, },
// Client: only set clientRoot // Client: only set clientRoot
// Server: only set serverCert with mutual TLS off // Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure and client handshake failure // Expected Behavior: server side failure and client handshake failure
// Reason: not setting advanced TLS features will fall back to normal check, and will hence fail // Reason: client side sets vType to CertAndHostVerification, and will do
// on default host name check. All the default hostname checks will fail in this test suites. // default hostname check. All the default hostname checks will fail in
// this test suites.
{ {
"Client_root_cert_Server_peer_cert", "Client_root_cert_Server_peer_cert",
nil, nil,
@ -140,6 +152,7 @@ func TestClientServerHandshake(t *testing.T) {
clientTrustPool, clientTrustPool,
nil, nil,
nil, nil,
CertAndHostVerification,
false, false,
true, true,
false, false,
@ -147,13 +160,16 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertAndHostVerification,
true, true,
}, },
// Client: only set clientGetRoot // Client: only set clientGetRoot
// Server: only set serverCert with mutual TLS off // Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure and client handshake failure // Expected Behavior: server side failure and client handshake failure
// Reason: setting root reloading function without custom verifyFunc will also fail, // Reason: client side sets vType to CertAndHostVerification, and will do
// since it will also fall back to default host name check // default hostname check. All the default hostname checks will fail in
// this test suites.
{ {
"Client_reload_root_Server_peer_cert", "Client_reload_root_Server_peer_cert",
nil, nil,
@ -161,6 +177,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
nil, nil,
CertAndHostVerification,
false, false,
true, true,
false, false,
@ -168,6 +185,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertAndHostVerification,
true, true,
}, },
// Client: set clientGetRoot and clientVerifyFunc // Client: set clientGetRoot and clientVerifyFunc
@ -180,6 +199,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
false, false,
@ -187,6 +207,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertAndHostVerification,
false, false,
}, },
// Client: set clientGetRoot and bad clientVerifyFunc function // Client: set clientGetRoot and bad clientVerifyFunc function
@ -200,6 +222,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncBad, verifyFuncBad,
CertVerification,
false, false,
true, true,
false, false,
@ -207,6 +230,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
CertVerification,
true, true,
}, },
// Client: set clientGetRoot and clientVerifyFunc // Client: set clientGetRoot and clientVerifyFunc
@ -220,6 +245,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
false, false,
@ -227,26 +253,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
true,
},
// Client: set clientGetRoot and clientVerifyFunc
// Server: only set serverCert with mutual TLS on
// Expected Behavior: server side failure
// Reason: server side must either set serverRoot or serverGetRoot when using mutual TLS
{
"Client_reload_root_verifyFuncGood_Server_peer_cert_no_root_cert_mutualTLS",
nil,
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil, nil,
CertVerification,
true, true,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientCert // Client: set clientGetRoot, clientVerifyFunc and clientCert
@ -259,6 +267,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -266,9 +275,37 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
serverTrustPool, serverTrustPool,
nil, nil,
nil,
CertVerification,
false, false,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientCert // Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverCert, but not setting any of serverRoot, serverGetRoot
// or serverVerifyFunc, with mutual TLS on
// Expected Behavior: server side failure
// Reason: server side needs to provide any verification mechanism when
// mTLS in on, even setting vType to SkipVerification. Servers should at
// least provide their own verification logic.
{
"Client_peer_cert_reload_root_verifyFuncGood_Server_no_verification_mutualTLS",
[]tls.Certificate{clientPeerCert},
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
CertVerification,
false,
true,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
nil,
SkipVerification,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot and serverCert with mutual TLS on // Server: set serverGetRoot and serverCert with mutual TLS on
// Expected Behavior: success // Expected Behavior: success
{ {
@ -278,6 +315,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -285,10 +323,13 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
getRootCAsForServer, getRootCAsForServer,
nil,
CertVerification,
false, false,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientCert // Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot returning error and serverCert with mutual TLS on // Server: set serverGetRoot returning error and serverCert with mutual
// TLS on
// Expected Behavior: server side failure // Expected Behavior: server side failure
// Reason: server side reloading returns failure // Reason: server side reloading returns failure
{ {
@ -298,6 +339,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -305,6 +347,8 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
nil, nil,
getRootCAsForServerBad, getRootCAsForServerBad,
nil,
CertVerification,
true, true,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert // Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert
@ -319,6 +363,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -328,9 +373,12 @@ func TestClientServerHandshake(t *testing.T) {
}, },
nil, nil,
getRootCAsForServer, getRootCAsForServer,
verifyFuncGood,
CertVerification,
false, false,
}, },
// Client: set everything but with the wrong peer cert not trusted by server // Client: set everything but with the wrong peer cert not trusted by
// server
// Server: set serverGetRoot and serverGetCert with mutual TLS on // Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: server side returns failure because of // Expected Behavior: server side returns failure because of
// certificate mismatch // certificate mismatch
@ -343,6 +391,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -352,6 +401,8 @@ func TestClientServerHandshake(t *testing.T) {
}, },
nil, nil,
getRootCAsForServer, getRootCAsForServer,
verifyFuncGood,
CertVerification,
true, true,
}, },
// Client: set everything but with the wrong trust cert not trusting server // Client: set everything but with the wrong trust cert not trusting server
@ -367,6 +418,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForServer, getRootCAsForServer,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
true, true,
true, true,
@ -376,10 +428,13 @@ func TestClientServerHandshake(t *testing.T) {
}, },
nil, nil,
getRootCAsForServer, getRootCAsForServer,
verifyFuncGood,
CertVerification,
true, true,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientCert // Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set everything but with the wrong peer cert not trusted by client // Server: set everything but with the wrong peer cert not trusted by
// client
// Expected Behavior: server side and client side return failure due to // Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure // certificate mismatch and handshake failure
{ {
@ -391,6 +446,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
false, false,
true, true,
@ -400,6 +456,8 @@ func TestClientServerHandshake(t *testing.T) {
}, },
nil, nil,
getRootCAsForServer, getRootCAsForServer,
verifyFuncGood,
CertVerification,
true, true,
}, },
// Client: set clientGetRoot, clientVerifyFunc and clientCert // Client: set clientGetRoot, clientVerifyFunc and clientCert
@ -415,6 +473,7 @@ func TestClientServerHandshake(t *testing.T) {
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood, verifyFuncGood,
CertVerification,
false, false,
true, true,
true, true,
@ -424,6 +483,31 @@ func TestClientServerHandshake(t *testing.T) {
}, },
nil, nil,
getRootCAsForClient, getRootCAsForClient,
verifyFuncGood,
CertVerification,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot and serverCert, but with bad verifyFunc
// Expected Behavior: server side and client side return failure due to
// server custom check fails
{
"Client_peer_cert_reload_root_verifyFuncGood_Server_bad_custom_verification_mutualTLS",
[]tls.Certificate{clientPeerCert},
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
CertVerification,
false,
true,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
getRootCAsForServer,
verifyFuncBad,
CertVerification,
true, true,
}, },
} { } {
@ -443,6 +527,8 @@ func TestClientServerHandshake(t *testing.T) {
GetRootCAs: test.serverGetRoot, GetRootCAs: test.serverGetRoot,
}, },
RequireClientCert: test.serverMutualTLS, RequireClientCert: test.serverMutualTLS,
VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType,
} }
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept() serverRawConn, err := lis.Accept()
@ -480,6 +566,7 @@ func TestClientServerHandshake(t *testing.T) {
RootCACerts: test.clientRoot, RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot, GetRootCAs: test.clientGetRoot,
}, },
VType: test.clientVType,
} }
clientTLS, newClientErr := NewClientCreds(clientOptions) clientTLS, newClientErr := NewClientCreds(clientOptions)
if newClientErr != nil && test.clientExpectCreateError { if newClientErr != nil && test.clientExpectCreateError {