interop: use credentials.NewTLS() when possible (#4390)
This commit is contained in:

committed by
GitHub

parent
a95a5c3bac
commit
aa59641d5d
@ -20,7 +20,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"flag"
|
"flag"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -57,7 +60,7 @@ var (
|
|||||||
serverHost = flag.String("server_host", "localhost", "The server host name")
|
serverHost = flag.String("server_host", "localhost", "The server host name")
|
||||||
serverPort = flag.Int("server_port", 10000, "The server port number")
|
serverPort = flag.Int("server_port", 10000, "The server port number")
|
||||||
serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.")
|
serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.")
|
||||||
tlsServerName = flag.String("server_host_override", "", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
|
tlsServerName = flag.String("server_host_override", "", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
|
||||||
testCase = flag.String("test_case", "large_unary",
|
testCase = flag.String("test_case", "large_unary",
|
||||||
`Configure different test cases. Valid options are:
|
`Configure different test cases. Valid options are:
|
||||||
empty_unary : empty (zero bytes) request and response;
|
empty_unary : empty (zero bytes) request and response;
|
||||||
@ -135,22 +138,25 @@ func main() {
|
|||||||
var opts []grpc.DialOption
|
var opts []grpc.DialOption
|
||||||
switch credsChosen {
|
switch credsChosen {
|
||||||
case credsTLS:
|
case credsTLS:
|
||||||
var sn string
|
var roots *x509.CertPool
|
||||||
if *tlsServerName != "" {
|
|
||||||
sn = *tlsServerName
|
|
||||||
}
|
|
||||||
var creds credentials.TransportCredentials
|
|
||||||
if *testCA {
|
if *testCA {
|
||||||
var err error
|
|
||||||
if *caFile == "" {
|
if *caFile == "" {
|
||||||
*caFile = testdata.Path("ca.pem")
|
*caFile = testdata.Path("ca.pem")
|
||||||
}
|
}
|
||||||
creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
|
b, err := ioutil.ReadFile(*caFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to create TLS credentials %v", err)
|
logger.Fatalf("Failed to read root certificate file %q: %v", *caFile, err)
|
||||||
}
|
}
|
||||||
|
roots = x509.NewCertPool()
|
||||||
|
if !roots.AppendCertsFromPEM(b) {
|
||||||
|
logger.Fatalf("Failed to append certificates: %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var creds credentials.TransportCredentials
|
||||||
|
if *tlsServerName != "" {
|
||||||
|
creds = credentials.NewClientTLSFromCert(roots, *tlsServerName)
|
||||||
} else {
|
} else {
|
||||||
creds = credentials.NewClientTLSFromCert(nil, sn)
|
creds = credentials.NewTLS(&tls.Config{RootCAs: roots})
|
||||||
}
|
}
|
||||||
opts = append(opts, grpc.WithTransportCredentials(creds))
|
opts = append(opts, grpc.WithTransportCredentials(creds))
|
||||||
case credsALTS:
|
case credsALTS:
|
||||||
|
Reference in New Issue
Block a user