advancedtls: fixed SNI testing and put SNI functions back in advancedtls.go (#3774)
* Fixed sni unit test
This commit is contained in:
@ -506,3 +506,30 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
|||||||
}
|
}
|
||||||
return cfg.Clone()
|
return cfg.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildGetCertificates returns the certificate that matches the SNI field
|
||||||
|
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
|
||||||
|
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
|
||||||
|
if o.GetCertificates == nil {
|
||||||
|
return nil, fmt.Errorf("function GetCertificates must be specified")
|
||||||
|
}
|
||||||
|
certificates, err := o.GetCertificates(clientHello)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(certificates) == 0 {
|
||||||
|
return nil, fmt.Errorf("no certificates configured")
|
||||||
|
}
|
||||||
|
// If users pass in only one certificate, return that certificate.
|
||||||
|
if len(certificates) == 1 {
|
||||||
|
return certificates[0], nil
|
||||||
|
}
|
||||||
|
// Choose the SNI certificate using SupportsCertificate.
|
||||||
|
for _, cert := range certificates {
|
||||||
|
if err := clientHello.SupportsCertificate(cert); err == nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If nothing matches, return the first certificate.
|
||||||
|
return certificates[0], nil
|
||||||
|
}
|
||||||
|
@ -26,11 +26,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/internal/grpctest"
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||||
@ -688,3 +690,73 @@ func (s) TestOptionsConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||||
|
// Load server certificates for setting the serverGetCert callback function.
|
||||||
|
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
|
||||||
|
}
|
||||||
|
serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
|
||||||
|
}
|
||||||
|
serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
serverName string
|
||||||
|
wantCert tls.Certificate
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Select serverCert1",
|
||||||
|
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
|
||||||
|
serverName: "foo.bar.com",
|
||||||
|
wantCert: serverCert1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Select serverCert2",
|
||||||
|
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
|
||||||
|
serverName: "foo.bar.server2.com",
|
||||||
|
wantCert: serverCert2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Select serverCert3",
|
||||||
|
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
|
||||||
|
serverName: "google.com",
|
||||||
|
wantCert: serverCert3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
serverOptions := &ServerOptions{
|
||||||
|
GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||||
|
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
serverConfig, err := serverOptions.config()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("serverOptions.config() failed: %v", err)
|
||||||
|
}
|
||||||
|
pointFormatUncompressed := uint8(0)
|
||||||
|
clientHello := &tls.ClientHelloInfo{
|
||||||
|
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
|
||||||
|
ServerName: test.serverName,
|
||||||
|
SupportedCurves: []tls.CurveID{tls.CurveP256},
|
||||||
|
SupportedPoints: []uint8{pointFormatUncompressed},
|
||||||
|
SupportedVersions: []uint16{tls.VersionTLS10},
|
||||||
|
}
|
||||||
|
gotCertificate, err := serverConfig.GetCertificate(clientHello)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
|
||||||
|
}
|
||||||
|
if !cmp.Equal(*gotCertificate, test.wantCert, cmp.AllowUnexported(big.Int{})) {
|
||||||
|
t.Errorf("GetCertificates() = %v, want %v", *gotCertificate, test.wantCert)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,51 +0,0 @@
|
|||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright 2020 gRPC authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package advancedtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// buildGetCertificates returns the certificate that matches the SNI field
|
|
||||||
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
|
|
||||||
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
|
|
||||||
if o.GetCertificates == nil {
|
|
||||||
return nil, fmt.Errorf("function GetCertificates must be specified")
|
|
||||||
}
|
|
||||||
certificates, err := o.GetCertificates(clientHello)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(certificates) == 0 {
|
|
||||||
return nil, fmt.Errorf("no certificates configured")
|
|
||||||
}
|
|
||||||
// If users pass in only one certificate, return that certificate.
|
|
||||||
if len(certificates) == 1 {
|
|
||||||
return certificates[0], nil
|
|
||||||
}
|
|
||||||
// Choose the SNI certificate using SupportsCertificate.
|
|
||||||
for _, cert := range certificates {
|
|
||||||
if err := clientHello.SupportsCertificate(cert); err == nil {
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If nothing matches, return the first certificate.
|
|
||||||
return certificates[0], nil
|
|
||||||
}
|
|
@ -1,106 +0,0 @@
|
|||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright 2019 gRPC authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package advancedtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestGetCertificatesSNI tests SNI logic for go1.10 and above.
|
|
||||||
func TestGetCertificatesSNI(t *testing.T) {
|
|
||||||
// Load server certificates for setting the serverGetCert callback function.
|
|
||||||
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
|
|
||||||
}
|
|
||||||
serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
|
|
||||||
}
|
|
||||||
serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
desc string
|
|
||||||
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
|
|
||||||
serverName string
|
|
||||||
wantCert tls.Certificate
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "Select serverCert1",
|
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
|
||||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
|
||||||
},
|
|
||||||
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
|
|
||||||
serverName: "foo.bar.com",
|
|
||||||
wantCert: serverCert1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Select serverCert2",
|
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
|
||||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
|
||||||
},
|
|
||||||
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
|
|
||||||
serverName: "foo.bar.server2.com",
|
|
||||||
wantCert: serverCert2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Select serverCert3",
|
|
||||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
|
||||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
|
||||||
},
|
|
||||||
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
|
|
||||||
serverName: "google.com",
|
|
||||||
wantCert: serverCert3,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
test := test
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
serverOptions := &ServerOptions{
|
|
||||||
GetCertificates: test.serverGetCert,
|
|
||||||
}
|
|
||||||
serverConfig, err := serverOptions.config()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("serverOptions.config() failed: %v", err)
|
|
||||||
}
|
|
||||||
pointFormatUncompressed := uint8(0)
|
|
||||||
clientHello := &tls.ClientHelloInfo{
|
|
||||||
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
|
|
||||||
ServerName: test.serverName,
|
|
||||||
SupportedCurves: []tls.CurveID{tls.CurveP256},
|
|
||||||
SupportedPoints: []uint8{pointFormatUncompressed},
|
|
||||||
SupportedVersions: []uint16{tls.VersionTLS10},
|
|
||||||
}
|
|
||||||
gotCertificate, err := serverConfig.GetCertificate(clientHello)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
|
|
||||||
}
|
|
||||||
if !cmp.Equal(gotCertificate, test.wantCert, cmp.AllowUnexported(tls.Certificate{})) {
|
|
||||||
t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.wantCert)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Reference in New Issue
Block a user