interop: use credentials.NewTLS() when possible (#4390)

This commit is contained in:
Easwar Swaminathan
2021-05-12 10:17:13 -07:00
committed by GitHub
parent a95a5c3bac
commit aa59641d5d

View File

@ -20,7 +20,10 @@
package main package main
import ( import (
"crypto/tls"
"crypto/x509"
"flag" "flag"
"io/ioutil"
"net" "net"
"strconv" "strconv"
@ -57,7 +60,7 @@ var (
serverHost = flag.String("server_host", "localhost", "The server host name") serverHost = flag.String("server_host", "localhost", "The server host name")
serverPort = flag.Int("server_port", 10000, "The server port number") serverPort = flag.Int("server_port", 10000, "The server port number")
serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.") serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.")
tlsServerName = flag.String("server_host_override", "", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") tlsServerName = flag.String("server_host_override", "", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
testCase = flag.String("test_case", "large_unary", testCase = flag.String("test_case", "large_unary",
`Configure different test cases. Valid options are: `Configure different test cases. Valid options are:
empty_unary : empty (zero bytes) request and response; empty_unary : empty (zero bytes) request and response;
@ -135,22 +138,25 @@ func main() {
var opts []grpc.DialOption var opts []grpc.DialOption
switch credsChosen { switch credsChosen {
case credsTLS: case credsTLS:
var sn string var roots *x509.CertPool
if *tlsServerName != "" {
sn = *tlsServerName
}
var creds credentials.TransportCredentials
if *testCA { if *testCA {
var err error
if *caFile == "" { if *caFile == "" {
*caFile = testdata.Path("ca.pem") *caFile = testdata.Path("ca.pem")
} }
creds, err = credentials.NewClientTLSFromFile(*caFile, sn) b, err := ioutil.ReadFile(*caFile)
if err != nil { if err != nil {
logger.Fatalf("Failed to create TLS credentials %v", err) logger.Fatalf("Failed to read root certificate file %q: %v", *caFile, err)
} }
roots = x509.NewCertPool()
if !roots.AppendCertsFromPEM(b) {
logger.Fatalf("Failed to append certificates: %s", string(b))
}
}
var creds credentials.TransportCredentials
if *tlsServerName != "" {
creds = credentials.NewClientTLSFromCert(roots, *tlsServerName)
} else { } else {
creds = credentials.NewClientTLSFromCert(nil, sn) creds = credentials.NewTLS(&tls.Config{RootCAs: roots})
} }
opts = append(opts, grpc.WithTransportCredentials(creds)) opts = append(opts, grpc.WithTransportCredentials(creds))
case credsALTS: case credsALTS: