advancedtls: add end to end tests (#3318)
This commit is contained in:
402
security/advancedtls/advancedtls_integration_test.go
Normal file
402
security/advancedtls/advancedtls_integration_test.go
Normal file
@ -0,0 +1,402 @@
|
||||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package advancedtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
)
|
||||
|
||||
var (
|
||||
address = "localhost:50051"
|
||||
port = ":50051"
|
||||
)
|
||||
|
||||
// stageInfo contains a stage number indicating the current phase of each integration test, and a mutex.
|
||||
// Based on the stage number of current test, we will use different certificates and server authorization
|
||||
// functions to check if our tests behave as expected.
|
||||
type stageInfo struct {
|
||||
mutex sync.Mutex
|
||||
stage int
|
||||
}
|
||||
|
||||
func (s *stageInfo) increase() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.stage = s.stage + 1
|
||||
}
|
||||
|
||||
func (s *stageInfo) read() int {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
return s.stage
|
||||
}
|
||||
|
||||
func (s *stageInfo) reset() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.stage = 0
|
||||
}
|
||||
|
||||
// certStore contains all the certificates used in the integration tests.
|
||||
type certStore struct {
|
||||
// clientPeer1 is the certificate sent by client to prove its identity. It is trusted by serverTrust1.
|
||||
clientPeer1 tls.Certificate
|
||||
// clientPeer2 is the certificate sent by client to prove its identity. It is trusted by serverTrust2.
|
||||
clientPeer2 tls.Certificate
|
||||
// serverPeer1 is the certificate sent by server to prove its identity. It is trusted by clientTrust1.
|
||||
serverPeer1 tls.Certificate
|
||||
// serverPeer2 is the certificate sent by server to prove its identity. It is trusted by clientTrust2.
|
||||
serverPeer2 tls.Certificate
|
||||
clientTrust1 *x509.CertPool
|
||||
clientTrust2 *x509.CertPool
|
||||
serverTrust1 *x509.CertPool
|
||||
serverTrust2 *x509.CertPool
|
||||
}
|
||||
|
||||
// loadCerts function is used to load test certificates at the beginning of each integration test.
|
||||
func (cs *certStore) loadCerts() error {
|
||||
var err error
|
||||
cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
|
||||
testdata.Path("client_key_1.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.clientPeer2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"),
|
||||
testdata.Path("client_key_2.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.serverPeer1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
||||
testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.serverPeer2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"),
|
||||
testdata.Path("server_key_2.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.clientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.clientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.serverTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.serverTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverImpl is used to implement pb.GreeterServer.
|
||||
type serverImpl struct{}
|
||||
|
||||
// SayHello is a simple implementation of pb.GreeterServer.
|
||||
func (s *serverImpl) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name}, nil
|
||||
}
|
||||
|
||||
func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err := client.SayHello(ctx, &pb.HelloRequest{Name: msg})
|
||||
if want, got := shouldFail == true, err != nil; got != want {
|
||||
return fmt.Errorf("want and got mismatch, want shouldFail=%v, got fail=%v, rpc error: %v", want, got, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
var err error
|
||||
// If we want the test to fail, we establish a non-blocking connection to avoid it hangs and killed by the context.
|
||||
if shouldFail {
|
||||
conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
|
||||
}
|
||||
} else {
|
||||
conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds), grpc.WithBlock())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
|
||||
}
|
||||
}
|
||||
greetClient := pb.NewGreeterClient(conn)
|
||||
err = callAndVerify(msg, greetClient, shouldFail)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, greetClient, nil
|
||||
}
|
||||
|
||||
// The advanced TLS features are tested in different stages.
|
||||
// At stage 0, we establish a good connection between client and server.
|
||||
// At stage 1, we change one factor(it could be we change the server's certificate, or server authorization function, etc),
|
||||
// and test if the following connections would be dropped.
|
||||
// At stage 2, we re-establish the connection by changing the counterpart of the factor we modified in stage 1.
|
||||
// (could be change the client's trust certificate, or change server authorization function, etc)
|
||||
func TestEnd2End(t *testing.T) {
|
||||
cs := &certStore{}
|
||||
cs.loadCerts()
|
||||
stage := &stageInfo{}
|
||||
for _, test := range []struct {
|
||||
desc string
|
||||
clientCert []tls.Certificate
|
||||
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
|
||||
clientRoot *x509.CertPool
|
||||
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
clientVerifyFunc CustomVerificationFunc
|
||||
serverCert []tls.Certificate
|
||||
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
serverRoot *x509.CertPool
|
||||
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
}{
|
||||
// Test Scenarios:
|
||||
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
||||
// The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1.
|
||||
// At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2 is not trusted by serverTrust1, following rpc calls are expected
|
||||
// to fail, while the previous rpc calls are still good because those are already authenticated.
|
||||
// At stage 2, the server changes serverTrust1 to serverTrust2, and we should see it again accepts the connection, since clientPeer2 is trusted
|
||||
// by serverTrust2.
|
||||
{
|
||||
desc: "TestClientPeerCertReloadServerTrustCertReload",
|
||||
clientCert: nil,
|
||||
clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
switch stage.read() {
|
||||
case 0:
|
||||
return &cs.clientPeer1, nil
|
||||
default:
|
||||
return &cs.clientPeer2, nil
|
||||
}
|
||||
},
|
||||
clientGetRoot: nil,
|
||||
clientRoot: cs.clientTrust1,
|
||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return &VerificationResults{}, nil
|
||||
},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetCert: nil,
|
||||
serverRoot: nil,
|
||||
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
switch stage.read() {
|
||||
case 0, 1:
|
||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
|
||||
default:
|
||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil
|
||||
}
|
||||
},
|
||||
},
|
||||
// Test Scenarios:
|
||||
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
||||
// The mutual authentication works at the beginning, since clientPeer1 is trusted by serverTrust1, and serverPeer1 by clientTrust1.
|
||||
// At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2 is not trusted by clientTrust1, following rpc calls are expected
|
||||
// to fail, while the previous rpc calls are still good because those are already authenticated.
|
||||
// At stage 2, the client changes clientTrust1 to clientTrust2, and we should see it again accepts the connection, since serverPeer2 is trusted
|
||||
// by clientTrust2.
|
||||
{
|
||||
desc: "TestServerPeerCertReloadClientTrustCertReload",
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetCert: nil,
|
||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
switch stage.read() {
|
||||
case 0, 1:
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
||||
default:
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
|
||||
}
|
||||
},
|
||||
clientRoot: nil,
|
||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return &VerificationResults{}, nil
|
||||
},
|
||||
serverCert: nil,
|
||||
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
switch stage.read() {
|
||||
case 0:
|
||||
return &cs.serverPeer1, nil
|
||||
default:
|
||||
return &cs.serverPeer2, nil
|
||||
}
|
||||
},
|
||||
serverRoot: cs.serverTrust1,
|
||||
serverGetRoot: nil,
|
||||
},
|
||||
// Test Scenarios:
|
||||
// At initialization(stage = 0), client will be initialized with cert clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
|
||||
// The mutual authentication works at the beginning, since clientPeer1 trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
|
||||
// custom server authorization check allows the CommonName on serverPeer1.
|
||||
// At stage 1, server changes serverPeer1 to serverPeer2, and client changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by
|
||||
// clientTrust2, our authorization check only accepts serverPeer1, and hence the following calls should fail. Previous connections should
|
||||
// not be affected.
|
||||
// At stage 2, the client changes authorization check to only accept serverPeer2. Now we should see the connection becomes normal again.
|
||||
{
|
||||
desc: "TestClientCustomServerAuthz",
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetCert: nil,
|
||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
switch stage.read() {
|
||||
case 0:
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
||||
default:
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
|
||||
}
|
||||
},
|
||||
clientRoot: nil,
|
||||
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
if len(params.RawCerts) == 0 {
|
||||
return nil, fmt.Errorf("no peer certs")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(params.RawCerts[0])
|
||||
if err != nil || cert == nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: " + err.Error())
|
||||
}
|
||||
authzCheck := false
|
||||
switch stage.read() {
|
||||
case 0, 1:
|
||||
// foo.bar.com is the common name on serverPeer1
|
||||
if cert.Subject.CommonName == "foo.bar.com" {
|
||||
authzCheck = true
|
||||
}
|
||||
default:
|
||||
// foo.bar.server2.com is the common name on serverPeer2
|
||||
if cert.Subject.CommonName == "foo.bar.server2.com" {
|
||||
authzCheck = true
|
||||
}
|
||||
}
|
||||
if authzCheck {
|
||||
return &VerificationResults{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("custom authz check fails")
|
||||
},
|
||||
serverCert: nil,
|
||||
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
switch stage.read() {
|
||||
case 0:
|
||||
return &cs.serverPeer1, nil
|
||||
default:
|
||||
return &cs.serverPeer2, nil
|
||||
}
|
||||
},
|
||||
serverRoot: cs.serverTrust1,
|
||||
serverGetRoot: nil,
|
||||
},
|
||||
} {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
// Start a server using ServerOptions in another goroutine.
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetCertificate: test.serverGetCert,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCAs: test.serverGetRoot,
|
||||
},
|
||||
RequireClientCert: true,
|
||||
}
|
||||
serverTLSCreds, err := NewServerCreds(serverOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server creds: %v", err)
|
||||
}
|
||||
s := grpc.NewServer(grpc.Creds(serverTLSCreds))
|
||||
defer s.Stop()
|
||||
go func(s *grpc.Server) {
|
||||
lis, err := net.Listen("tcp", port)
|
||||
// defer lis.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
pb.RegisterGreeterServer(s, &serverImpl{})
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}(s)
|
||||
clientOptions := &ClientOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetClientCertificate: test.clientGetCert,
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCAs: test.clientGetRoot,
|
||||
},
|
||||
}
|
||||
clientTLSCreds, err := NewClientCreds(clientOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("clientTLSCreds failed to create")
|
||||
}
|
||||
// ------------------------Scenario 1-----------------------------------------
|
||||
// stage = 0, initial connection should succeed
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false)
|
||||
defer conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
stage.increase()
|
||||
// ------------------------Scenario 2-----------------------------------------
|
||||
// stage = 1, previous connection should still succeed
|
||||
err = callAndVerify("rpc call 2", greetClient, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ------------------------Scenario 3-----------------------------------------
|
||||
// stage = 1, new connection should fail
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true)
|
||||
defer conn2.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//// ---------------------------------------------------------------------------
|
||||
stage.increase()
|
||||
// ------------------------Scenario 4-----------------------------------------
|
||||
// stage = 2, new connection should succeed
|
||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel3()
|
||||
conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false)
|
||||
defer conn3.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
stage.reset()
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user