advancedtls: add fields for root and identity providers in API (#3863)
* add provider in advancedtls API for pem file reloading
This commit is contained in:
@ -34,6 +34,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
)
|
||||
@ -46,14 +47,274 @@ func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
func (s) TestClientServerHandshake(t *testing.T) {
|
||||
// ------------------Load Client Trust Cert and Peer Cert-------------------
|
||||
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
||||
type provType int
|
||||
|
||||
const (
|
||||
provTypeRoot provType = iota
|
||||
provTypeIdentity
|
||||
)
|
||||
|
||||
type fakeProvider struct {
|
||||
pt provType
|
||||
isClient bool
|
||||
wantMultiCert bool
|
||||
wantError bool
|
||||
}
|
||||
|
||||
func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
|
||||
if f.wantError {
|
||||
return nil, fmt.Errorf("bad fakeProvider")
|
||||
}
|
||||
cs := &certStore{}
|
||||
err := cs.loadCerts()
|
||||
if err != nil {
|
||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
return nil, fmt.Errorf("failed to load certs: %v", err)
|
||||
}
|
||||
if f.pt == provTypeRoot && f.isClient {
|
||||
return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil
|
||||
}
|
||||
if f.pt == provTypeRoot && !f.isClient {
|
||||
return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil
|
||||
}
|
||||
if f.pt == provTypeIdentity && f.isClient {
|
||||
if f.wantMultiCert {
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil
|
||||
}
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil
|
||||
}
|
||||
if f.wantMultiCert {
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil
|
||||
}
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil
|
||||
}
|
||||
|
||||
func (f fakeProvider) Close() {}
|
||||
|
||||
func (s) TestClientOptionsConfigErrorCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Skip default verification and provide no root credentials",
|
||||
clientVType: SkipVerification,
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in RootCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: x509.NewCertPool(),
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in IdentityCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Specify GetIdentityCertificatesForServer",
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
_, err := clientOptions.config()
|
||||
if err == nil {
|
||||
t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", clientOptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientOptionsConfigSuccessCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Use system default if no fields in RootCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
},
|
||||
{
|
||||
desc: "Good case with mutual TLS",
|
||||
clientVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
clientConfig, err := clientOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", clientOptions, err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in clientOptions.
|
||||
if clientOptions.RootOptions.RootCACerts == nil &&
|
||||
clientOptions.RootOptions.GetRootCertificates == nil && clientOptions.RootOptions.RootProvider == nil {
|
||||
if clientConfig.RootCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the client side.")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
requireClientCert bool
|
||||
serverVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Skip default verification and provide no root credentials",
|
||||
requireClientCert: true,
|
||||
serverVType: SkipVerification,
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in RootCertificateOptions is specified",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: x509.NewCertPool(),
|
||||
GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in IdentityCertificateOptions is specified",
|
||||
serverVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: []tls.Certificate{},
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no field in IdentityCertificateOptions is specified",
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
{
|
||||
desc: "Specify GetIdentityCertificatesForClient",
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
VType: test.serverVType,
|
||||
RequireClientCert: test.requireClientCert,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
_, err := serverOptions.config()
|
||||
if err == nil {
|
||||
t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
requireClientCert bool
|
||||
serverVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Use system default if no fields in RootCertificateOptions is specified",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: []tls.Certificate{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Good case with mutual TLS",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
VType: test.serverVType,
|
||||
RequireClientCert: test.requireClientCert,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in serverOptions.
|
||||
if serverOptions.RootOptions.RootCACerts == nil &&
|
||||
serverOptions.RootOptions.GetRootCertificates == nil && serverOptions.RootOptions.RootProvider == nil {
|
||||
if serverConfig.ClientCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the server side.")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientServerHandshake(t *testing.T) {
|
||||
cs := &certStore{}
|
||||
err := cs.loadCerts()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load certs: %v", err)
|
||||
}
|
||||
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
||||
}
|
||||
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
if params.ServerName == "" {
|
||||
@ -69,18 +330,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return nil, fmt.Errorf("custom verification function failed")
|
||||
}
|
||||
clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
|
||||
testdata.Path("client_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
// ------------------Load Server Trust Cert and Peer Cert-------------------
|
||||
serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
|
||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
|
||||
}
|
||||
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
if params.ServerName != "" {
|
||||
@ -93,11 +344,6 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
|
||||
return &VerificationResults{}, nil
|
||||
}
|
||||
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
||||
testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return nil, fmt.Errorf("bad root certificate reloading")
|
||||
}
|
||||
@ -109,7 +355,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
clientVerifyFunc CustomVerificationFunc
|
||||
clientVType VerificationType
|
||||
clientExpectCreateError bool
|
||||
clientRootProvider certprovider.Provider
|
||||
clientIdentityProvider certprovider.Provider
|
||||
clientExpectHandshakeError bool
|
||||
serverMutualTLS bool
|
||||
serverCert []tls.Certificate
|
||||
@ -118,23 +365,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
serverVerifyFunc CustomVerificationFunc
|
||||
serverVType VerificationType
|
||||
serverRootProvider certprovider.Provider
|
||||
serverIdentityProvider certprovider.Provider
|
||||
serverExpectError bool
|
||||
}{
|
||||
// Client: nil setting
|
||||
// Server: only set serverCert with mutual TLS off
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client
|
||||
// side doesn't provide any verification mechanism. We don't allow this
|
||||
// even setting vType to SkipVerification. Clients should at least provide
|
||||
// their own verification logic.
|
||||
{
|
||||
desc: "Client has no trust cert; server sends peer cert",
|
||||
clientVType: SkipVerification,
|
||||
clientExpectCreateError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: nil setting except verifyFuncGood
|
||||
// Server: only set serverCert with mutual TLS off
|
||||
// Expected Behavior: success
|
||||
@ -144,7 +378,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: SkipVerification,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
// Client: only set clientRoot
|
||||
@ -155,10 +389,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
// this test suites.
|
||||
{
|
||||
desc: "Client has root cert; server sends peer cert",
|
||||
clientRoot: clientTrustPool,
|
||||
clientRoot: cs.clientTrust1,
|
||||
clientVType: CertAndHostVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
@ -173,7 +407,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVType: CertAndHostVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
@ -185,7 +419,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
// Client: set clientGetRoot and bad clientVerifyFunc function
|
||||
@ -198,66 +432,35 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientVerifyFunc: verifyFuncBad,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot and clientVerifyFunc
|
||||
// Server: nil setting
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: server side must either set serverCert or serverGetCert
|
||||
{
|
||||
desc: "Client sets reload root function with verifyFuncGood; server sets nil",
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverRoot and serverCert with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverRoot: serverTrustPool,
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverRoot: cs.serverTrust1,
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverCert, but not setting any of serverRoot, serverGetRoot
|
||||
// or serverVerifyFunc, with mutual TLS on
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: server side needs to provide any verification mechanism when
|
||||
// mTLS in on, even setting vType to SkipVerification. Servers should at
|
||||
// least provide their own verification logic.
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: SkipVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverGetRoot and serverCert with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
@ -268,12 +471,12 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
// Reason: server side reloading returns failure
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServerBad,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
@ -284,14 +487,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
@ -305,14 +508,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
{
|
||||
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &serverPeerCert, nil
|
||||
return &cs.serverPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
@ -326,7 +529,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
{
|
||||
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForServer,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
@ -334,7 +537,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
@ -349,14 +552,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&clientPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.clientPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
@ -370,7 +573,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
@ -378,7 +581,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForClient,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
@ -391,18 +594,92 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
// server custom check fails
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: verifyFuncBad,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set a clientIdentityProvider which will get multiple cert chains
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to multiple cert chains in
|
||||
// clientIdentityProvider
|
||||
{
|
||||
desc: "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set a bad clientIdentityProvider
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to bad clientIdentityProvider
|
||||
{
|
||||
desc: "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set bad serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to bad serverRootProvider
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets bad root provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false, wantError: true},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: success, because server side has SNI
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
} {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
@ -413,11 +690,15 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
}
|
||||
// Start a server using ServerOptions in another goroutine.
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetCertificates: test.serverGetCert,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCAs: test.serverGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetIdentityCertificatesForServer: test.serverGetCert,
|
||||
IdentityProvider: test.serverIdentityProvider,
|
||||
},
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCertificates: test.serverGetRoot,
|
||||
RootProvider: test.serverRootProvider,
|
||||
},
|
||||
RequireClientCert: test.serverMutualTLS,
|
||||
VerifyPeer: test.serverVerifyFunc,
|
||||
@ -452,23 +733,22 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
||||
}
|
||||
defer conn.Close()
|
||||
clientOptions := &ClientOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetClientCertificate: test.clientGetCert,
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCAs: test.clientGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetIdentityCertificatesForClient: test.clientGetCert,
|
||||
IdentityProvider: test.clientIdentityProvider,
|
||||
},
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCertificates: test.clientGetRoot,
|
||||
RootProvider: test.clientRootProvider,
|
||||
},
|
||||
VType: test.clientVType,
|
||||
}
|
||||
clientTLS, newClientErr := NewClientCreds(clientOptions)
|
||||
if newClientErr != nil && test.clientExpectCreateError {
|
||||
return
|
||||
}
|
||||
if newClientErr != nil && !test.clientExpectCreateError ||
|
||||
newClientErr == nil && test.clientExpectCreateError {
|
||||
t.Fatalf("Expect error: %v, but err is %v",
|
||||
test.clientExpectCreateError, newClientErr)
|
||||
clientTLS, err := NewClientCreds(clientOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientCreds failed: %v", err)
|
||||
}
|
||||
_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
|
||||
lisAddr, conn)
|
||||
@ -541,7 +821,7 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
|
||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: clientTrustPool,
|
||||
},
|
||||
ServerNameOverride: expectedServerName,
|
||||
@ -563,7 +843,7 @@ func (s) TestTLSClone(t *testing.T) {
|
||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: clientTrustPool,
|
||||
},
|
||||
ServerNameOverride: expectedServerName,
|
||||
@ -635,62 +915,6 @@ func (s) TestWrapSyscallConn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestOptionsConfig(t *testing.T) {
|
||||
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
||||
testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
serverMutualTLS bool
|
||||
serverCert []tls.Certificate
|
||||
serverVType VerificationType
|
||||
}{
|
||||
{
|
||||
desc: "Client uses system-provided RootCAs; server uses system-provided ClientCAs",
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
RequireClientCert: test.serverMutualTLS,
|
||||
VType: test.serverVType,
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate serverConfig. Error: %v", err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in serverOptions.
|
||||
if serverOptions.RootCACerts == nil && serverOptions.GetRootCAs == nil &&
|
||||
serverOptions.RequireClientCert && serverConfig.ClientCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the server side.")
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
}
|
||||
clientConfig, err := clientOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate clientConfig. Error: %v", err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in clientOptions.
|
||||
if clientOptions.RootCACerts == nil && clientOptions.GetRootCAs == nil &&
|
||||
clientConfig.RootCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the client side.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||
// Load server certificates for setting the serverGetCert callback function.
|
||||
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
||||
@ -734,8 +958,10 @@ func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
|
Reference in New Issue
Block a user