advancedtls: fixed SNI testing and put SNI functions back in advancedtls.go (#3774)
* Fixed sni unit test
This commit is contained in:
@ -26,11 +26,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"reflect"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
@ -688,3 +690,73 @@ func (s) TestOptionsConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||
// Load server certificates for setting the serverGetCert callback function.
|
||||
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
|
||||
}
|
||||
serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
|
||||
}
|
||||
serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
serverName string
|
||||
wantCert tls.Certificate
|
||||
}{
|
||||
{
|
||||
desc: "Select serverCert1",
|
||||
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
|
||||
serverName: "foo.bar.com",
|
||||
wantCert: serverCert1,
|
||||
},
|
||||
{
|
||||
desc: "Select serverCert2",
|
||||
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
|
||||
serverName: "foo.bar.server2.com",
|
||||
wantCert: serverCert2,
|
||||
},
|
||||
{
|
||||
desc: "Select serverCert3",
|
||||
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
|
||||
serverName: "google.com",
|
||||
wantCert: serverCert3,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||
},
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("serverOptions.config() failed: %v", err)
|
||||
}
|
||||
pointFormatUncompressed := uint8(0)
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
|
||||
ServerName: test.serverName,
|
||||
SupportedCurves: []tls.CurveID{tls.CurveP256},
|
||||
SupportedPoints: []uint8{pointFormatUncompressed},
|
||||
SupportedVersions: []uint16{tls.VersionTLS10},
|
||||
}
|
||||
gotCertificate, err := serverConfig.GetCertificate(clientHello)
|
||||
if err != nil {
|
||||
t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
|
||||
}
|
||||
if !cmp.Equal(*gotCertificate, test.wantCert, cmp.AllowUnexported(big.Int{})) {
|
||||
t.Errorf("GetCertificates() = %v, want %v", *gotCertificate, test.wantCert)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user