From b830b5f361072689b94d99055e8ffc51821eabf3 Mon Sep 17 00:00:00 2001 From: cindyxue <32377977+cindyxue@users.noreply.github.com> Date: Thu, 6 Aug 2020 11:10:47 -0700 Subject: [PATCH] advancedtls: fixed SNI testing and put SNI functions back in advancedtls.go (#3774) * Fixed sni unit test --- security/advancedtls/advancedtls.go | 27 ++++++ security/advancedtls/advancedtls_test.go | 72 +++++++++++++++ security/advancedtls/sni.go | 51 ----------- security/advancedtls/sni_test_disabled.go | 106 ---------------------- 4 files changed, 99 insertions(+), 157 deletions(-) delete mode 100644 security/advancedtls/sni.go delete mode 100644 security/advancedtls/sni_test_disabled.go diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 40a9fd5f..c21ff753 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -506,3 +506,30 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { } return cfg.Clone() } + +// buildGetCertificates returns the certificate that matches the SNI field +// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates. +func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { + if o.GetCertificates == nil { + return nil, fmt.Errorf("function GetCertificates must be specified") + } + certificates, err := o.GetCertificates(clientHello) + if err != nil { + return nil, err + } + if len(certificates) == 0 { + return nil, fmt.Errorf("no certificates configured") + } + // If users pass in only one certificate, return that certificate. + if len(certificates) == 1 { + return certificates[0], nil + } + // Choose the SNI certificate using SupportsCertificate. + for _, cert := range certificates { + if err := clientHello.SupportsCertificate(cert); err == nil { + return cert, nil + } + } + // If nothing matches, return the first certificate. + return certificates[0], nil +} diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 4e79d0d8..8800a51a 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -26,11 +26,13 @@ import ( "errors" "fmt" "io/ioutil" + "math/big" "net" "reflect" "syscall" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/security/advancedtls/testdata" @@ -688,3 +690,73 @@ func (s) TestOptionsConfig(t *testing.T) { }) } } + +func (s) TestGetCertificatesSNI(t *testing.T) { + // Load server certificates for setting the serverGetCert callback function. + serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err) + } + serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err) + } + serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err) + } + + tests := []struct { + desc string + serverName string + wantCert tls.Certificate + }{ + { + desc: "Select serverCert1", + // "foo.bar.com" is the common name on server certificate server_cert_1.pem. + serverName: "foo.bar.com", + wantCert: serverCert1, + }, + { + desc: "Select serverCert2", + // "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem. + serverName: "foo.bar.server2.com", + wantCert: serverCert2, + }, + { + desc: "Select serverCert3", + // "google.com" is one of the DNS names on server certificate server_cert_3.pem. + serverName: "google.com", + wantCert: serverCert3, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + serverOptions := &ServerOptions{ + GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { + return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil + }, + } + serverConfig, err := serverOptions.config() + if err != nil { + t.Fatalf("serverOptions.config() failed: %v", err) + } + pointFormatUncompressed := uint8(0) + clientHello := &tls.ClientHelloInfo{ + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}, + ServerName: test.serverName, + SupportedCurves: []tls.CurveID{tls.CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SupportedVersions: []uint16{tls.VersionTLS10}, + } + gotCertificate, err := serverConfig.GetCertificate(clientHello) + if err != nil { + t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err) + } + if !cmp.Equal(*gotCertificate, test.wantCert, cmp.AllowUnexported(big.Int{})) { + t.Errorf("GetCertificates() = %v, want %v", *gotCertificate, test.wantCert) + } + }) + } +} diff --git a/security/advancedtls/sni.go b/security/advancedtls/sni.go deleted file mode 100644 index 00e551fa..00000000 --- a/security/advancedtls/sni.go +++ /dev/null @@ -1,51 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package advancedtls - -import ( - "crypto/tls" - "fmt" -) - -// buildGetCertificates returns the certificate that matches the SNI field -// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates. -func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - if o.GetCertificates == nil { - return nil, fmt.Errorf("function GetCertificates must be specified") - } - certificates, err := o.GetCertificates(clientHello) - if err != nil { - return nil, err - } - if len(certificates) == 0 { - return nil, fmt.Errorf("no certificates configured") - } - // If users pass in only one certificate, return that certificate. - if len(certificates) == 1 { - return certificates[0], nil - } - // Choose the SNI certificate using SupportsCertificate. - for _, cert := range certificates { - if err := clientHello.SupportsCertificate(cert); err == nil { - return cert, nil - } - } - // If nothing matches, return the first certificate. - return certificates[0], nil -} diff --git a/security/advancedtls/sni_test_disabled.go b/security/advancedtls/sni_test_disabled.go deleted file mode 100644 index 3e9e19c1..00000000 --- a/security/advancedtls/sni_test_disabled.go +++ /dev/null @@ -1,106 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package advancedtls - -import ( - "crypto/tls" - "testing" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/security/advancedtls/testdata" -) - -// TestGetCertificatesSNI tests SNI logic for go1.10 and above. -func TestGetCertificatesSNI(t *testing.T) { - // Load server certificates for setting the serverGetCert callback function. - serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err) - } - serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err) - } - serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")) - if err != nil { - t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err) - } - - tests := []struct { - desc string - serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) - serverName string - wantCert tls.Certificate - }{ - { - desc: "Select serverCert1", - serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil - }, - // "foo.bar.com" is the common name on server certificate server_cert_1.pem. - serverName: "foo.bar.com", - wantCert: serverCert1, - }, - { - desc: "Select serverCert2", - serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil - }, - // "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem. - serverName: "foo.bar.server2.com", - wantCert: serverCert2, - }, - { - desc: "Select serverCert3", - serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil - }, - // "google.com" is one of the DNS names on server certificate server_cert_3.pem. - serverName: "google.com", - wantCert: serverCert3, - }, - } - for _, test := range tests { - test := test - t.Run(test.desc, func(t *testing.T) { - serverOptions := &ServerOptions{ - GetCertificates: test.serverGetCert, - } - serverConfig, err := serverOptions.config() - if err != nil { - t.Fatalf("serverOptions.config() failed: %v", err) - } - pointFormatUncompressed := uint8(0) - clientHello := &tls.ClientHelloInfo{ - CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}, - ServerName: test.serverName, - SupportedCurves: []tls.CurveID{tls.CurveP256}, - SupportedPoints: []uint8{pointFormatUncompressed}, - SupportedVersions: []uint16{tls.VersionTLS10}, - } - gotCertificate, err := serverConfig.GetCertificate(clientHello) - if err != nil { - t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err) - } - if !cmp.Equal(gotCertificate, test.wantCert, cmp.AllowUnexported(tls.Certificate{})) { - t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.wantCert) - } - }) - } -}