Overwrite authority if creds servername is specified

This commit is contained in:
Menghan Li
2016-09-02 16:54:17 -07:00
parent 52f6504dc2
commit a00cbfeab5
3 changed files with 39 additions and 11 deletions

View File

@ -322,11 +322,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if ok {
go cc.lbWatcher()
}
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else {
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
}
cc.authority = target[:colonPos]
return cc, nil
}

View File

@ -65,7 +65,23 @@ func TestTLSDialTimeout(t *testing.T) {
conn.Close()
}
if err != ErrClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
}
}
func TestTLSServerNameOverwrite(t *testing.T) {
overwriteServerName := "over.write.server.name"
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName)
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
}
conn.Close()
if conn.authority != overwriteServerName {
t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
}
}
@ -73,7 +89,7 @@ func TestDialContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
}
}

View File

@ -72,7 +72,7 @@ type PerRPCCredentials interface {
}
// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, etc.
// security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct {
// ProtocolVersion is the gRPC wire protocol version.
ProtocolVersion string
@ -80,6 +80,8 @@ type ProtocolInfo struct {
SecurityProtocol string
// SecurityVersion is the security protocol version.
SecurityVersion string
// ServerName is the user-configured server name.
ServerName string
}
// AuthInfo defines the common interface for the auth information the users are interested in.
@ -130,6 +132,7 @@ func (c tlsCreds) Info() ProtocolInfo {
return ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2",
ServerName: c.config.ServerName,
}
}
@ -187,12 +190,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
}
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp})
}
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) {
b, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, err
@ -201,7 +208,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("credentials: failed to append certificates")
}
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil
}
// NewServerTLSFromCert constructs a TLS from the input certificate for server.