advancedtls: clean up test files and shared code (#3897)
* advancedtls: clean up test files and shared code
This commit is contained in:
@ -28,7 +28,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
@ -374,7 +373,7 @@ func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
|
|||||||
|
|
||||||
func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
// Use local cfg to avoid clobbering ServerName if using multiple endpoints.
|
// Use local cfg to avoid clobbering ServerName if using multiple endpoints.
|
||||||
cfg := cloneTLSConfig(c.config)
|
cfg := credinternal.CloneTLSConfig(c.config)
|
||||||
// We return the full authority name to users if ServerName is empty without
|
// We return the full authority name to users if ServerName is empty without
|
||||||
// stripping the trailing port.
|
// stripping the trailing port.
|
||||||
if cfg.ServerName == "" {
|
if cfg.ServerName == "" {
|
||||||
@ -404,11 +403,11 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
||||||
return WrapSyscallConn(rawConn, conn), info, nil
|
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
cfg := cloneTLSConfig(c.config)
|
cfg := credinternal.CloneTLSConfig(c.config)
|
||||||
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
|
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
|
||||||
conn := tls.Server(rawConn, cfg)
|
conn := tls.Server(rawConn, cfg)
|
||||||
if err := conn.Handshake(); err != nil {
|
if err := conn.Handshake(); err != nil {
|
||||||
@ -422,12 +421,12 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
||||||
return WrapSyscallConn(rawConn, conn), info, nil
|
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
|
func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
|
||||||
return &advancedTLSCreds{
|
return &advancedTLSCreds{
|
||||||
config: cloneTLSConfig(c.config),
|
config: credinternal.CloneTLSConfig(c.config),
|
||||||
verifyFunc: c.verifyFunc,
|
verifyFunc: c.verifyFunc,
|
||||||
getRootCAs: c.getRootCAs,
|
getRootCAs: c.getRootCAs,
|
||||||
isClient: c.isClient,
|
isClient: c.isClient,
|
||||||
@ -530,7 +529,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
|
|||||||
verifyFunc: o.VerifyPeer,
|
verifyFunc: o.VerifyPeer,
|
||||||
vType: o.VType,
|
vType: o.VType,
|
||||||
}
|
}
|
||||||
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
|
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
|
||||||
return tc, nil
|
return tc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -548,64 +547,6 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
|
|||||||
verifyFunc: o.VerifyPeer,
|
verifyFunc: o.VerifyPeer,
|
||||||
vType: o.VType,
|
vType: o.VType,
|
||||||
}
|
}
|
||||||
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
|
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
|
||||||
return tc, nil
|
return tc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(ZhenLian): The code below are duplicates with gRPC-Go under
|
|
||||||
// credentials/internal. Consider refactoring in the future.
|
|
||||||
const alpnProtoStrH2 = "h2"
|
|
||||||
|
|
||||||
func appendH2ToNextProtos(ps []string) []string {
|
|
||||||
for _, p := range ps {
|
|
||||||
if p == alpnProtoStrH2 {
|
|
||||||
return ps
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ret := make([]string, 0, len(ps)+1)
|
|
||||||
ret = append(ret, ps...)
|
|
||||||
return append(ret, alpnProtoStrH2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We give syscall.Conn a new name here since syscall.Conn and net.Conn used
|
|
||||||
// below have the same names.
|
|
||||||
type sysConn = syscall.Conn
|
|
||||||
|
|
||||||
// syscallConn keeps reference of rawConn to support syscall.Conn for channelz.
|
|
||||||
// SyscallConn() (the method in interface syscall.Conn) is explicitly
|
|
||||||
// implemented on this type,
|
|
||||||
//
|
|
||||||
// Interface syscall.Conn is implemented by most net.Conn implementations (e.g.
|
|
||||||
// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns
|
|
||||||
// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn
|
|
||||||
// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't
|
|
||||||
// help here).
|
|
||||||
type syscallConn struct {
|
|
||||||
net.Conn
|
|
||||||
// sysConn is a type alias of syscall.Conn. It's necessary because the name
|
|
||||||
// `Conn` collides with `net.Conn`.
|
|
||||||
sysConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that
|
|
||||||
// implements syscall.Conn. rawConn will be used to support syscall, and newConn
|
|
||||||
// will be used for read/write.
|
|
||||||
//
|
|
||||||
// This function returns newConn if rawConn doesn't implement syscall.Conn.
|
|
||||||
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
|
|
||||||
sysConn, ok := rawConn.(syscall.Conn)
|
|
||||||
if !ok {
|
|
||||||
return newConn
|
|
||||||
}
|
|
||||||
return &syscallConn{
|
|
||||||
Conn: newConn,
|
|
||||||
sysConn: sysConn,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
|
||||||
if cfg == nil {
|
|
||||||
return &tls.Config{}
|
|
||||||
}
|
|
||||||
return cfg.Clone()
|
|
||||||
}
|
|
||||||
|
@ -31,7 +31,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -67,69 +67,6 @@ func (s *stageInfo) reset() {
|
|||||||
s.stage = 0
|
s.stage = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// certStore contains all the certificates used in the integration tests.
|
|
||||||
type certStore struct {
|
|
||||||
// clientPeer1 is the certificate sent by client to prove its identity.
|
|
||||||
// It is trusted by serverTrust1.
|
|
||||||
clientPeer1 tls.Certificate
|
|
||||||
// clientPeer2 is the certificate sent by client to prove its identity.
|
|
||||||
// It is trusted by serverTrust2.
|
|
||||||
clientPeer2 tls.Certificate
|
|
||||||
// serverPeer1 is the certificate sent by server to prove its identity.
|
|
||||||
// It is trusted by clientTrust1.
|
|
||||||
serverPeer1 tls.Certificate
|
|
||||||
// serverPeer2 is the certificate sent by server to prove its identity.
|
|
||||||
// It is trusted by clientTrust2.
|
|
||||||
serverPeer2 tls.Certificate
|
|
||||||
clientTrust1 *x509.CertPool
|
|
||||||
clientTrust2 *x509.CertPool
|
|
||||||
serverTrust1 *x509.CertPool
|
|
||||||
serverTrust2 *x509.CertPool
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadCerts function is used to load test certificates at the beginning of
|
|
||||||
// each integration test.
|
|
||||||
func (cs *certStore) loadCerts() error {
|
|
||||||
var err error
|
|
||||||
cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
|
|
||||||
testdata.Path("client_key_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.clientPeer2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"),
|
|
||||||
testdata.Path("client_key_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.serverPeer1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
|
||||||
testdata.Path("server_key_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.serverPeer2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"),
|
|
||||||
testdata.Path("server_key_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.clientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.clientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.serverTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.serverTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type greeterServer struct {
|
type greeterServer struct {
|
||||||
pb.UnimplementedGreeterServer
|
pb.UnimplementedGreeterServer
|
||||||
}
|
}
|
||||||
@ -183,10 +120,9 @@ func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds cred
|
|||||||
// (could be change the client's trust certificate, or change custom
|
// (could be change the client's trust certificate, or change custom
|
||||||
// verification function, etc)
|
// verification function, etc)
|
||||||
func (s) TestEnd2End(t *testing.T) {
|
func (s) TestEnd2End(t *testing.T) {
|
||||||
cs := &certStore{}
|
cs := &testutils.CertStore{}
|
||||||
err := cs.loadCerts()
|
if err := cs.LoadCerts(); err != nil {
|
||||||
if err != nil {
|
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
||||||
t.Fatalf("failed to load certs: %v", err)
|
|
||||||
}
|
}
|
||||||
stage := &stageInfo{}
|
stage := &stageInfo{}
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
@ -206,38 +142,38 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
// Test Scenarios:
|
// Test Scenarios:
|
||||||
// At initialization(stage = 0), client will be initialized with cert
|
// At initialization(stage = 0), client will be initialized with cert
|
||||||
// clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
|
||||||
// The mutual authentication works at the beginning, since clientPeer1 is
|
// The mutual authentication works at the beginning, since ClientCert1 is
|
||||||
// trusted by serverTrust1, and serverPeer1 by clientTrust1.
|
// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
|
||||||
// At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2
|
// At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2
|
||||||
// is not trusted by serverTrust1, following rpc calls are expected to
|
// is not trusted by ServerTrust1, following rpc calls are expected to
|
||||||
// fail, while the previous rpc calls are still good because those are
|
// fail, while the previous rpc calls are still good because those are
|
||||||
// already authenticated.
|
// already authenticated.
|
||||||
// At stage 2, the server changes serverTrust1 to serverTrust2, and we
|
// At stage 2, the server changes ServerTrust1 to ServerTrust2, and we
|
||||||
// should see it again accepts the connection, since clientPeer2 is trusted
|
// should see it again accepts the connection, since ClientCert2 is trusted
|
||||||
// by serverTrust2.
|
// by ServerTrust2.
|
||||||
{
|
{
|
||||||
desc: "TestClientPeerCertReloadServerTrustCertReload",
|
desc: "TestClientPeerCertReloadServerTrustCertReload",
|
||||||
clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return &cs.clientPeer1, nil
|
return &cs.ClientCert1, nil
|
||||||
default:
|
default:
|
||||||
return &cs.clientPeer2, nil
|
return &cs.ClientCert2, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
clientRoot: cs.clientTrust1,
|
clientRoot: cs.ClientTrust1,
|
||||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
return &VerificationResults{}, nil
|
return &VerificationResults{}, nil
|
||||||
},
|
},
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0, 1:
|
case 0, 1:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
|
||||||
default:
|
default:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
@ -247,25 +183,25 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
},
|
},
|
||||||
// Test Scenarios:
|
// Test Scenarios:
|
||||||
// At initialization(stage = 0), client will be initialized with cert
|
// At initialization(stage = 0), client will be initialized with cert
|
||||||
// clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
|
||||||
// The mutual authentication works at the beginning, since clientPeer1 is
|
// The mutual authentication works at the beginning, since ClientCert1 is
|
||||||
// trusted by serverTrust1, and serverPeer1 by clientTrust1.
|
// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
|
||||||
// At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2
|
// At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2
|
||||||
// is not trusted by clientTrust1, following rpc calls are expected to
|
// is not trusted by ClientTrust1, following rpc calls are expected to
|
||||||
// fail, while the previous rpc calls are still good because those are
|
// fail, while the previous rpc calls are still good because those are
|
||||||
// already authenticated.
|
// already authenticated.
|
||||||
// At stage 2, the client changes clientTrust1 to clientTrust2, and we
|
// At stage 2, the client changes ClientTrust1 to ClientTrust2, and we
|
||||||
// should see it again accepts the connection, since serverPeer2 is trusted
|
// should see it again accepts the connection, since ServerCert2 is trusted
|
||||||
// by clientTrust2.
|
// by ClientTrust2.
|
||||||
{
|
{
|
||||||
desc: "TestServerPeerCertReloadClientTrustCertReload",
|
desc: "TestServerPeerCertReloadClientTrustCertReload",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0, 1:
|
case 0, 1:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
|
||||||
default:
|
default:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
@ -275,12 +211,12 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
default:
|
default:
|
||||||
return []*tls.Certificate{&cs.serverPeer2}, nil
|
return []*tls.Certificate{&cs.ServerCert2}, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
serverRoot: cs.serverTrust1,
|
serverRoot: cs.ServerTrust1,
|
||||||
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
return &VerificationResults{}, nil
|
return &VerificationResults{}, nil
|
||||||
},
|
},
|
||||||
@ -288,26 +224,26 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
},
|
},
|
||||||
// Test Scenarios:
|
// Test Scenarios:
|
||||||
// At initialization(stage = 0), client will be initialized with cert
|
// At initialization(stage = 0), client will be initialized with cert
|
||||||
// clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
|
||||||
// The mutual authentication works at the beginning, since clientPeer1
|
// The mutual authentication works at the beginning, since ClientCert1
|
||||||
// trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
|
// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
|
||||||
// custom verification check allows the CommonName on serverPeer1.
|
// custom verification check allows the CommonName on ServerCert1.
|
||||||
// At stage 1, server changes serverPeer1 to serverPeer2, and client
|
// At stage 1, server changes ServerCert1 to ServerCert2, and client
|
||||||
// changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by
|
// changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by
|
||||||
// clientTrust2, our authorization check only accepts serverPeer1, and
|
// ClientTrust2, our authorization check only accepts ServerCert1, and
|
||||||
// hence the following calls should fail. Previous connections should
|
// hence the following calls should fail. Previous connections should
|
||||||
// not be affected.
|
// not be affected.
|
||||||
// At stage 2, the client changes authorization check to only accept
|
// At stage 2, the client changes authorization check to only accept
|
||||||
// serverPeer2. Now we should see the connection becomes normal again.
|
// ServerCert2. Now we should see the connection becomes normal again.
|
||||||
{
|
{
|
||||||
desc: "TestClientCustomVerification",
|
desc: "TestClientCustomVerification",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
|
||||||
default:
|
default:
|
||||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
@ -321,12 +257,12 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
authzCheck := false
|
authzCheck := false
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0, 1:
|
case 0, 1:
|
||||||
// foo.bar.com is the common name on serverPeer1
|
// foo.bar.com is the common name on ServerCert1
|
||||||
if cert.Subject.CommonName == "foo.bar.com" {
|
if cert.Subject.CommonName == "foo.bar.com" {
|
||||||
authzCheck = true
|
authzCheck = true
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// foo.bar.server2.com is the common name on serverPeer2
|
// foo.bar.server2.com is the common name on ServerCert2
|
||||||
if cert.Subject.CommonName == "foo.bar.server2.com" {
|
if cert.Subject.CommonName == "foo.bar.server2.com" {
|
||||||
authzCheck = true
|
authzCheck = true
|
||||||
}
|
}
|
||||||
@ -340,12 +276,12 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
default:
|
default:
|
||||||
return []*tls.Certificate{&cs.serverPeer2}, nil
|
return []*tls.Certificate{&cs.ServerCert2}, nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
serverRoot: cs.serverTrust1,
|
serverRoot: cs.ServerTrust1,
|
||||||
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
return &VerificationResults{}, nil
|
return &VerificationResults{}, nil
|
||||||
},
|
},
|
||||||
@ -353,9 +289,9 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
},
|
},
|
||||||
// Test Scenarios:
|
// Test Scenarios:
|
||||||
// At initialization(stage = 0), client will be initialized with cert
|
// At initialization(stage = 0), client will be initialized with cert
|
||||||
// clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
|
||||||
// The mutual authentication works at the beginning, since clientPeer1
|
// The mutual authentication works at the beginning, since ClientCert1
|
||||||
// trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
|
// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
|
||||||
// custom verification check on server side allows all connections.
|
// custom verification check on server side allows all connections.
|
||||||
// At stage 1, server disallows the the connections by setting custom
|
// At stage 1, server disallows the the connections by setting custom
|
||||||
// verification check. The following calls should fail. Previous
|
// verification check. The following calls should fail. Previous
|
||||||
@ -364,14 +300,14 @@ func (s) TestEnd2End(t *testing.T) {
|
|||||||
// authentications should go back to normal.
|
// authentications should go back to normal.
|
||||||
{
|
{
|
||||||
desc: "TestServerCustomVerification",
|
desc: "TestServerCustomVerification",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientRoot: cs.clientTrust1,
|
clientRoot: cs.ClientTrust1,
|
||||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
return &VerificationResults{}, nil
|
return &VerificationResults{}, nil
|
||||||
},
|
},
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverRoot: cs.serverTrust1,
|
serverRoot: cs.ServerTrust1,
|
||||||
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0, 2:
|
case 0, 2:
|
||||||
|
@ -22,21 +22,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||||
"google.golang.org/grpc/internal/grpctest"
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type s struct {
|
type s struct {
|
||||||
@ -65,27 +61,26 @@ func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMateria
|
|||||||
if f.wantError {
|
if f.wantError {
|
||||||
return nil, fmt.Errorf("bad fakeProvider")
|
return nil, fmt.Errorf("bad fakeProvider")
|
||||||
}
|
}
|
||||||
cs := &certStore{}
|
cs := &testutils.CertStore{}
|
||||||
err := cs.loadCerts()
|
if err := cs.LoadCerts(); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("cs.LoadCerts() failed, err: %v", err)
|
||||||
return nil, fmt.Errorf("failed to load certs: %v", err)
|
|
||||||
}
|
}
|
||||||
if f.pt == provTypeRoot && f.isClient {
|
if f.pt == provTypeRoot && f.isClient {
|
||||||
return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil
|
return &certprovider.KeyMaterial{Roots: cs.ClientTrust1}, nil
|
||||||
}
|
}
|
||||||
if f.pt == provTypeRoot && !f.isClient {
|
if f.pt == provTypeRoot && !f.isClient {
|
||||||
return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil
|
return &certprovider.KeyMaterial{Roots: cs.ServerTrust1}, nil
|
||||||
}
|
}
|
||||||
if f.pt == provTypeIdentity && f.isClient {
|
if f.pt == provTypeIdentity && f.isClient {
|
||||||
if f.wantMultiCert {
|
if f.wantMultiCert {
|
||||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil
|
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1, cs.ClientCert2}}, nil
|
||||||
}
|
}
|
||||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil
|
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, nil
|
||||||
}
|
}
|
||||||
if f.wantMultiCert {
|
if f.wantMultiCert {
|
||||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil
|
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1, cs.ServerCert2}}, nil
|
||||||
}
|
}
|
||||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil
|
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f fakeProvider) Close() {}
|
func (f fakeProvider) Close() {}
|
||||||
@ -308,13 +303,12 @@ func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestClientServerHandshake(t *testing.T) {
|
func (s) TestClientServerHandshake(t *testing.T) {
|
||||||
cs := &certStore{}
|
cs := &testutils.CertStore{}
|
||||||
err := cs.loadCerts()
|
if err := cs.LoadCerts(); err != nil {
|
||||||
if err != nil {
|
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
||||||
t.Fatalf("Failed to load certs: %v", err)
|
|
||||||
}
|
}
|
||||||
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
|
||||||
}
|
}
|
||||||
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
if params.ServerName == "" {
|
if params.ServerName == "" {
|
||||||
@ -331,7 +325,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
return nil, fmt.Errorf("custom verification function failed")
|
return nil, fmt.Errorf("custom verification function failed")
|
||||||
}
|
}
|
||||||
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
|
return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
|
||||||
}
|
}
|
||||||
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||||
if params.ServerName != "" {
|
if params.ServerName != "" {
|
||||||
@ -378,7 +372,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
|
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: SkipVerification,
|
clientVType: SkipVerification,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverVType: CertAndHostVerification,
|
serverVType: CertAndHostVerification,
|
||||||
},
|
},
|
||||||
// Client: only set clientRoot
|
// Client: only set clientRoot
|
||||||
@ -389,10 +383,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
// this test suites.
|
// this test suites.
|
||||||
{
|
{
|
||||||
desc: "Client has root cert; server sends peer cert",
|
desc: "Client has root cert; server sends peer cert",
|
||||||
clientRoot: cs.clientTrust1,
|
clientRoot: cs.ClientTrust1,
|
||||||
clientVType: CertAndHostVerification,
|
clientVType: CertAndHostVerification,
|
||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverVType: CertAndHostVerification,
|
serverVType: CertAndHostVerification,
|
||||||
serverExpectError: true,
|
serverExpectError: true,
|
||||||
},
|
},
|
||||||
@ -407,7 +401,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVType: CertAndHostVerification,
|
clientVType: CertAndHostVerification,
|
||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverVType: CertAndHostVerification,
|
serverVType: CertAndHostVerification,
|
||||||
serverExpectError: true,
|
serverExpectError: true,
|
||||||
},
|
},
|
||||||
@ -419,7 +413,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverVType: CertAndHostVerification,
|
serverVType: CertAndHostVerification,
|
||||||
},
|
},
|
||||||
// Client: set clientGetRoot and bad clientVerifyFunc function
|
// Client: set clientGetRoot and bad clientVerifyFunc function
|
||||||
@ -432,7 +426,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
clientVerifyFunc: verifyFuncBad,
|
clientVerifyFunc: verifyFuncBad,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverVType: CertVerification,
|
serverVType: CertVerification,
|
||||||
serverExpectError: true,
|
serverExpectError: true,
|
||||||
},
|
},
|
||||||
@ -441,13 +435,13 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
// Expected Behavior: success
|
// Expected Behavior: success
|
||||||
{
|
{
|
||||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
|
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverRoot: cs.serverTrust1,
|
serverRoot: cs.ServerTrust1,
|
||||||
serverVType: CertVerification,
|
serverVType: CertVerification,
|
||||||
},
|
},
|
||||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||||
@ -455,12 +449,12 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
// Expected Behavior: success
|
// Expected Behavior: success
|
||||||
{
|
{
|
||||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
|
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVType: CertVerification,
|
serverVType: CertVerification,
|
||||||
},
|
},
|
||||||
@ -471,12 +465,12 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
// Reason: server side reloading returns failure
|
// Reason: server side reloading returns failure
|
||||||
{
|
{
|
||||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
|
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverGetRoot: getRootCAsForServerBad,
|
serverGetRoot: getRootCAsForServerBad,
|
||||||
serverVType: CertVerification,
|
serverVType: CertVerification,
|
||||||
serverExpectError: true,
|
serverExpectError: true,
|
||||||
@ -487,14 +481,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return &cs.clientPeer1, nil
|
return &cs.ClientCert1, nil
|
||||||
},
|
},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
},
|
},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVerifyFunc: serverVerifyFunc,
|
serverVerifyFunc: serverVerifyFunc,
|
||||||
@ -508,14 +502,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return &cs.serverPeer1, nil
|
return &cs.ServerCert1, nil
|
||||||
},
|
},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
},
|
},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVerifyFunc: serverVerifyFunc,
|
serverVerifyFunc: serverVerifyFunc,
|
||||||
@ -529,7 +523,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return &cs.clientPeer1, nil
|
return &cs.ClientCert1, nil
|
||||||
},
|
},
|
||||||
clientGetRoot: getRootCAsForServer,
|
clientGetRoot: getRootCAsForServer,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
@ -537,7 +531,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
},
|
},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVerifyFunc: serverVerifyFunc,
|
serverVerifyFunc: serverVerifyFunc,
|
||||||
@ -552,14 +546,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
|
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
|
||||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return &cs.clientPeer1, nil
|
return &cs.ClientCert1, nil
|
||||||
},
|
},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&cs.clientPeer1}, nil
|
return []*tls.Certificate{&cs.ClientCert1}, nil
|
||||||
},
|
},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVerifyFunc: serverVerifyFunc,
|
serverVerifyFunc: serverVerifyFunc,
|
||||||
@ -573,7 +567,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
|
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
|
||||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return &cs.clientPeer1, nil
|
return &cs.ClientCert1, nil
|
||||||
},
|
},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
@ -581,7 +575,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
return []*tls.Certificate{&cs.ServerCert1}, nil
|
||||||
},
|
},
|
||||||
serverGetRoot: getRootCAsForClient,
|
serverGetRoot: getRootCAsForClient,
|
||||||
serverVerifyFunc: serverVerifyFunc,
|
serverVerifyFunc: serverVerifyFunc,
|
||||||
@ -594,13 +588,13 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
// server custom check fails
|
// server custom check fails
|
||||||
{
|
{
|
||||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
|
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
|
||||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||||
clientGetRoot: getRootCAsForClient,
|
clientGetRoot: getRootCAsForClient,
|
||||||
clientVerifyFunc: clientVerifyFuncGood,
|
clientVerifyFunc: clientVerifyFuncGood,
|
||||||
clientVType: CertVerification,
|
clientVType: CertVerification,
|
||||||
clientExpectHandshakeError: true,
|
clientExpectHandshakeError: true,
|
||||||
serverMutualTLS: true,
|
serverMutualTLS: true,
|
||||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
serverCert: []tls.Certificate{cs.ServerCert1},
|
||||||
serverGetRoot: getRootCAsForServer,
|
serverGetRoot: getRootCAsForServer,
|
||||||
serverVerifyFunc: verifyFuncBad,
|
serverVerifyFunc: verifyFuncBad,
|
||||||
serverVType: CertVerification,
|
serverVType: CertVerification,
|
||||||
@ -776,24 +770,6 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func readTrustCert(fileName string) (*x509.CertPool, error) {
|
|
||||||
trustData, err := ioutil.ReadFile(fileName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
trustBlock, _ := pem.Decode(trustData)
|
|
||||||
if trustBlock == nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
trustCert, err := x509.ParseCertificate(trustBlock.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
trustPool := x509.NewCertPool()
|
|
||||||
trustPool.AddCert(trustCert)
|
|
||||||
return trustPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func compare(a1, a2 credentials.AuthInfo) bool {
|
func compare(a1, a2 credentials.AuthInfo) bool {
|
||||||
if a1.AuthType() != a2.AuthType() {
|
if a1.AuthType() != a2.AuthType() {
|
||||||
return false
|
return false
|
||||||
@ -816,13 +792,13 @@ func compare(a1, a2 credentials.AuthInfo) bool {
|
|||||||
|
|
||||||
func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
|
func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
|
||||||
expectedServerName := "server.name"
|
expectedServerName := "server.name"
|
||||||
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
cs := &testutils.CertStore{}
|
||||||
if err != nil {
|
if err := cs.LoadCerts(); err != nil {
|
||||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
||||||
}
|
}
|
||||||
clientOptions := &ClientOptions{
|
clientOptions := &ClientOptions{
|
||||||
RootOptions: RootCertificateOptions{
|
RootOptions: RootCertificateOptions{
|
||||||
RootCACerts: clientTrustPool,
|
RootCACerts: cs.ClientTrust1,
|
||||||
},
|
},
|
||||||
ServerNameOverride: expectedServerName,
|
ServerNameOverride: expectedServerName,
|
||||||
}
|
}
|
||||||
@ -836,122 +812,33 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestTLSClone(t *testing.T) {
|
|
||||||
expectedServerName := "server.name"
|
|
||||||
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
|
||||||
}
|
|
||||||
clientOptions := &ClientOptions{
|
|
||||||
RootOptions: RootCertificateOptions{
|
|
||||||
RootCACerts: clientTrustPool,
|
|
||||||
},
|
|
||||||
ServerNameOverride: expectedServerName,
|
|
||||||
}
|
|
||||||
c, err := NewClientCreds(clientOptions)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create new client: %v", err)
|
|
||||||
}
|
|
||||||
cc := c.Clone()
|
|
||||||
if cc.Info().ServerName != expectedServerName {
|
|
||||||
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
|
|
||||||
}
|
|
||||||
cc.OverrideServerName("")
|
|
||||||
if c.Info().ServerName != expectedServerName {
|
|
||||||
t.Fatalf("Change in clone should not affect the original, "+
|
|
||||||
"c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestAppendH2ToNextProtos(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ps []string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
ps: nil,
|
|
||||||
want: []string{"h2"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only h2",
|
|
||||||
ps: []string{"h2"},
|
|
||||||
want: []string{"h2"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with h2",
|
|
||||||
ps: []string{"alpn", "h2"},
|
|
||||||
want: []string{"alpn", "h2"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no h2",
|
|
||||||
ps: []string{"alpn"},
|
|
||||||
want: []string{"alpn", "h2"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type nonSyscallConn struct {
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestWrapSyscallConn(t *testing.T) {
|
|
||||||
sc := &syscallConn{}
|
|
||||||
nsc := &nonSyscallConn{}
|
|
||||||
|
|
||||||
wrapConn := WrapSyscallConn(sc, nsc)
|
|
||||||
if _, ok := wrapConn.(syscall.Conn); !ok {
|
|
||||||
t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement",
|
|
||||||
wrapConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestGetCertificatesSNI(t *testing.T) {
|
func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||||
// Load server certificates for setting the serverGetCert callback function.
|
cs := &testutils.CertStore{}
|
||||||
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
if err := cs.LoadCerts(); err != nil {
|
||||||
if err != nil {
|
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
|
|
||||||
}
|
}
|
||||||
serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
|
|
||||||
}
|
|
||||||
serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
desc string
|
desc string
|
||||||
serverName string
|
serverName string
|
||||||
wantCert tls.Certificate
|
wantCert tls.Certificate
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "Select serverCert1",
|
desc: "Select ServerCert1",
|
||||||
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
|
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
|
||||||
serverName: "foo.bar.com",
|
serverName: "foo.bar.com",
|
||||||
wantCert: serverCert1,
|
wantCert: cs.ServerCert1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "Select serverCert2",
|
desc: "Select ServerCert2",
|
||||||
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
|
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
|
||||||
serverName: "foo.bar.server2.com",
|
serverName: "foo.bar.server2.com",
|
||||||
wantCert: serverCert2,
|
wantCert: cs.ServerCert2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "Select serverCert3",
|
desc: "Select serverCert3",
|
||||||
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
|
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
|
||||||
serverName: "google.com",
|
serverName: "google.com",
|
||||||
wantCert: serverCert3,
|
wantCert: cs.ServerPeer3,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@ -960,7 +847,7 @@ func (s) TestGetCertificatesSNI(t *testing.T) {
|
|||||||
serverOptions := &ServerOptions{
|
serverOptions := &ServerOptions{
|
||||||
IdentityOptions: IdentityCertificateOptions{
|
IdentityOptions: IdentityCertificateOptions{
|
||||||
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
100
security/advancedtls/internal/testutils/testutils.go
Normal file
100
security/advancedtls/internal/testutils/testutils.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package testutils contains helper functions for advancedtls.
|
||||||
|
package testutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CertStore contains all the certificates used in the integration tests.
|
||||||
|
type CertStore struct {
|
||||||
|
// ClientCert1 is the certificate sent by client to prove its identity.
|
||||||
|
// It is trusted by ServerTrust1.
|
||||||
|
ClientCert1 tls.Certificate
|
||||||
|
// ClientCert2 is the certificate sent by client to prove its identity.
|
||||||
|
// It is trusted by ServerTrust2.
|
||||||
|
ClientCert2 tls.Certificate
|
||||||
|
// ServerCert1 is the certificate sent by server to prove its identity.
|
||||||
|
// It is trusted by ClientTrust1.
|
||||||
|
ServerCert1 tls.Certificate
|
||||||
|
// ServerCert2 is the certificate sent by server to prove its identity.
|
||||||
|
// It is trusted by ClientTrust2.
|
||||||
|
ServerCert2 tls.Certificate
|
||||||
|
// ServerPeer3 is the certificate sent by server to prove its identity.
|
||||||
|
ServerPeer3 tls.Certificate
|
||||||
|
// ClientTrust1 is the root certificate used on the client side.
|
||||||
|
ClientTrust1 *x509.CertPool
|
||||||
|
// ClientTrust2 is the root certificate used on the client side.
|
||||||
|
ClientTrust2 *x509.CertPool
|
||||||
|
// ServerTrust1 is the root certificate used on the server side.
|
||||||
|
ServerTrust1 *x509.CertPool
|
||||||
|
// ServerTrust2 is the root certificate used on the server side.
|
||||||
|
ServerTrust2 *x509.CertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTrustCert(fileName string) (*x509.CertPool, error) {
|
||||||
|
trustData, err := ioutil.ReadFile(fileName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
trustPool := x509.NewCertPool()
|
||||||
|
if !trustPool.AppendCertsFromPEM(trustData) {
|
||||||
|
return nil, fmt.Errorf("error loading trust certificates")
|
||||||
|
}
|
||||||
|
return trustPool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCerts function is used to load test certificates at the beginning of
|
||||||
|
// each integration test.
|
||||||
|
func (cs *CertStore) LoadCerts() error {
|
||||||
|
var err error
|
||||||
|
if cs.ClientCert1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), testdata.Path("client_key_1.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ClientCert2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"), testdata.Path("client_key_2.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ServerCert1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ServerCert2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ServerPeer3, err = tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ClientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ClientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ServerTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cs.ServerTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||||
|
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -95,17 +96,17 @@ func (s) TestNewPEMFileProvider(t *testing.T) {
|
|||||||
|
|
||||||
// This test overwrites the credential reading function used by the watching
|
// This test overwrites the credential reading function used by the watching
|
||||||
// goroutine. It is tested under different stages:
|
// goroutine. It is tested under different stages:
|
||||||
// At stage 0, we force reading function to load clientPeer1 and serverTrust1,
|
// At stage 0, we force reading function to load ClientCert1 and ServerTrust1,
|
||||||
// and see if the credentials are picked up by the watching go routine.
|
// and see if the credentials are picked up by the watching go routine.
|
||||||
// At stage 1, we force reading function to cause an error. The watching go
|
// At stage 1, we force reading function to cause an error. The watching go
|
||||||
// routine should log the error while leaving the credentials unchanged.
|
// routine should log the error while leaving the credentials unchanged.
|
||||||
// At stage 2, we force reading function to load clientPeer2 and serverTrust2,
|
// At stage 2, we force reading function to load ClientCert2 and ServerTrust2,
|
||||||
// and see if the new credentials are picked up.
|
// and see if the new credentials are picked up.
|
||||||
func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
||||||
// Load certificates.
|
// Load certificates.
|
||||||
cs := &certStore{}
|
cs := &testutils.CertStore{}
|
||||||
if err := cs.loadCerts(); err != nil {
|
if err := cs.LoadCerts(); err != nil {
|
||||||
t.Fatalf("cs.loadCerts() failed: %v", err)
|
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
desc string
|
desc string
|
||||||
@ -121,9 +122,9 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
|||||||
KeyFile: "not_empty_key_file",
|
KeyFile: "not_empty_key_file",
|
||||||
TrustFile: "not_empty_trust_file",
|
TrustFile: "not_empty_trust_file",
|
||||||
},
|
},
|
||||||
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
|
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
|
||||||
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
|
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
|
||||||
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}, Roots: cs.serverTrust2},
|
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "use identity certs only",
|
desc: "use identity certs only",
|
||||||
@ -131,18 +132,18 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
|||||||
CertFile: "not_empty_cert_file",
|
CertFile: "not_empty_cert_file",
|
||||||
KeyFile: "not_empty_key_file",
|
KeyFile: "not_empty_key_file",
|
||||||
},
|
},
|
||||||
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
|
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
|
||||||
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
|
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
|
||||||
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}},
|
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "use trust certs only",
|
desc: "use trust certs only",
|
||||||
options: PEMFileProviderOptions{
|
options: PEMFileProviderOptions{
|
||||||
TrustFile: "not_empty_trust_file",
|
TrustFile: "not_empty_trust_file",
|
||||||
},
|
},
|
||||||
wantKmStage0: certprovider.KeyMaterial{Roots: cs.serverTrust1},
|
wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
|
||||||
wantKmStage1: certprovider.KeyMaterial{Roots: cs.serverTrust1},
|
wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
|
||||||
wantKmStage2: certprovider.KeyMaterial{Roots: cs.serverTrust2},
|
wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@ -155,11 +156,11 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
|||||||
readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
|
readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return cs.clientPeer1, nil
|
return cs.ClientCert1, nil
|
||||||
case 1:
|
case 1:
|
||||||
return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
|
return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
|
||||||
case 2:
|
case 2:
|
||||||
return cs.clientPeer2, nil
|
return cs.ClientCert2, nil
|
||||||
default:
|
default:
|
||||||
return tls.Certificate{}, fmt.Errorf("test stage not supported")
|
return tls.Certificate{}, fmt.Errorf("test stage not supported")
|
||||||
}
|
}
|
||||||
@ -171,11 +172,11 @@ func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
|||||||
readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
|
readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
|
||||||
switch stage.read() {
|
switch stage.read() {
|
||||||
case 0:
|
case 0:
|
||||||
return cs.serverTrust1, nil
|
return cs.ServerTrust1, nil
|
||||||
case 1:
|
case 1:
|
||||||
return nil, fmt.Errorf("error occurred while reloading")
|
return nil, fmt.Errorf("error occurred while reloading")
|
||||||
case 2:
|
case 2:
|
||||||
return cs.serverTrust2, nil
|
return cs.ServerTrust2, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("test stage not supported")
|
return nil, fmt.Errorf("test stage not supported")
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user