Overwrite authority if creds servername is specified
This commit is contained in:
@ -322,11 +322,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||
if ok {
|
||||
go cc.lbWatcher()
|
||||
}
|
||||
colonPos := strings.LastIndex(target, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(target)
|
||||
creds := cc.dopts.copts.TransportCredentials
|
||||
if creds != nil && creds.Info().ServerName != "" {
|
||||
cc.authority = creds.Info().ServerName
|
||||
} else {
|
||||
colonPos := strings.LastIndex(target, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(target)
|
||||
}
|
||||
cc.authority = target[:colonPos]
|
||||
}
|
||||
cc.authority = target[:colonPos]
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,23 @@ func TestTLSDialTimeout(t *testing.T) {
|
||||
conn.Close()
|
||||
}
|
||||
if err != ErrClientConnTimeout {
|
||||
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
|
||||
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSServerNameOverwrite(t *testing.T) {
|
||||
overwriteServerName := "over.write.server.name"
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
|
||||
}
|
||||
conn.Close()
|
||||
if conn.authority != overwriteServerName {
|
||||
t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,7 +89,7 @@ func TestDialContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
|
||||
t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
|
||||
t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ type PerRPCCredentials interface {
|
||||
}
|
||||
|
||||
// ProtocolInfo provides information regarding the gRPC wire protocol version,
|
||||
// security protocol, security protocol version in use, etc.
|
||||
// security protocol, security protocol version in use, server name, etc.
|
||||
type ProtocolInfo struct {
|
||||
// ProtocolVersion is the gRPC wire protocol version.
|
||||
ProtocolVersion string
|
||||
@ -80,6 +80,8 @@ type ProtocolInfo struct {
|
||||
SecurityProtocol string
|
||||
// SecurityVersion is the security protocol version.
|
||||
SecurityVersion string
|
||||
// ServerName is the user-configured server name.
|
||||
ServerName string
|
||||
}
|
||||
|
||||
// AuthInfo defines the common interface for the auth information the users are interested in.
|
||||
@ -130,6 +132,7 @@ func (c tlsCreds) Info() ProtocolInfo {
|
||||
return ProtocolInfo{
|
||||
SecurityProtocol: "tls",
|
||||
SecurityVersion: "1.2",
|
||||
ServerName: c.config.ServerName,
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,12 +190,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
|
||||
}
|
||||
|
||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
|
||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
||||
// serverNameOverwrite is for testing only. If set to a non empty string,
|
||||
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
|
||||
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials {
|
||||
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp})
|
||||
}
|
||||
|
||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
|
||||
// serverNameOverwrite is for testing only. If set to a non empty string,
|
||||
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
|
||||
func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) {
|
||||
b, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -201,7 +208,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er
|
||||
if !cp.AppendCertsFromPEM(b) {
|
||||
return nil, fmt.Errorf("credentials: failed to append certificates")
|
||||
}
|
||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
|
||||
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil
|
||||
}
|
||||
|
||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||
|
Reference in New Issue
Block a user