Add Clone() and OverrideServerName() to TransportCredentials
This commit is contained in:
@ -109,6 +109,12 @@ type TransportCredentials interface {
|
|||||||
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
|
// Clone makes a copy of this TransportCredentials.
|
||||||
|
Clone() TransportCredentials
|
||||||
|
// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
|
||||||
|
// gRPC internals also use it to override the virtual hosting name if it is set.
|
||||||
|
// It must be called before dialing. Currently, this is only used by grpclb.
|
||||||
|
OverrideServerName(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||||
@ -136,16 +142,6 @@ func (c tlsCreds) Info() ProtocolInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
|
||||||
// metadata.
|
|
||||||
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tlsCreds) RequireTransportSecurity() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
|
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
|
||||||
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||||
cfg := cloneTLSConfig(c.config)
|
cfg := cloneTLSConfig(c.config)
|
||||||
@ -182,6 +178,15 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
|||||||
return conn, TLSInfo{conn.ConnectionState()}, nil
|
return conn, TLSInfo{conn.ConnectionState()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) Clone() TransportCredentials {
|
||||||
|
return NewTLS(c.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
|
||||||
|
c.config.ServerName = serverNameOverride
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
||||||
func NewTLS(c *tls.Config) TransportCredentials {
|
func NewTLS(c *tls.Config) TransportCredentials {
|
||||||
tc := &tlsCreds{cloneTLSConfig(c)}
|
tc := &tlsCreds{cloneTLSConfig(c)}
|
||||||
@ -190,16 +195,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||||
// serverNameOverwrite is for testing only. If set to a non empty string,
|
// serverNameOverride is for testing only. If set to a non empty string,
|
||||||
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
|
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
|
||||||
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials {
|
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
|
||||||
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp})
|
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||||
// serverNameOverwrite is for testing only. If set to a non empty string,
|
// serverNameOverride is for testing only. If set to a non empty string,
|
||||||
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
|
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
|
||||||
func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) {
|
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
|
||||||
b, err := ioutil.ReadFile(certFile)
|
b, err := ioutil.ReadFile(certFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -208,7 +213,7 @@ func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCreden
|
|||||||
if !cp.AppendCertsFromPEM(b) {
|
if !cp.AppendCertsFromPEM(b) {
|
||||||
return nil, fmt.Errorf("credentials: failed to append certificates")
|
return nil, fmt.Errorf("credentials: failed to append certificates")
|
||||||
}
|
}
|
||||||
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil
|
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||||
|
61
credentials/credentials_test.go
Normal file
61
credentials/credentials_test.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTLSOverrideServerName(t *testing.T) {
|
||||||
|
expectedServerName := "server.name"
|
||||||
|
c := NewTLS(nil)
|
||||||
|
c.OverrideServerName(expectedServerName)
|
||||||
|
if c.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSClone(t *testing.T) {
|
||||||
|
expectedServerName := "server.name"
|
||||||
|
c := NewTLS(nil)
|
||||||
|
c.OverrideServerName(expectedServerName)
|
||||||
|
cc := c.Clone()
|
||||||
|
if cc.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
cc.OverrideServerName("")
|
||||||
|
if c.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
}
|
@ -2450,6 +2450,12 @@ func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, crede
|
|||||||
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
|
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
|
||||||
return credentials.ProtocolInfo{}
|
return credentials.ProtocolInfo{}
|
||||||
}
|
}
|
||||||
|
func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c clientAlwaysFailCred) OverrideServerName(s string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
||||||
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
|
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
|
||||||
@ -2520,6 +2526,12 @@ func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, creden
|
|||||||
func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
|
func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
|
||||||
return credentials.ProtocolInfo{}
|
return credentials.ProtocolInfo{}
|
||||||
}
|
}
|
||||||
|
func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *clientTimeoutCreds) OverrideServerName(s string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
|
func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
|
||||||
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false})
|
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false})
|
||||||
@ -2556,6 +2568,12 @@ func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, creden
|
|||||||
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
|
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
|
||||||
return credentials.ProtocolInfo{}
|
return credentials.ProtocolInfo{}
|
||||||
}
|
}
|
||||||
|
func (c *serverDispatchCred) Clone() credentials.TransportCredentials {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *serverDispatchCred) OverrideServerName(s string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (c *serverDispatchCred) getRawConn() net.Conn {
|
func (c *serverDispatchCred) getRawConn() net.Conn {
|
||||||
<-c.ready
|
<-c.ready
|
||||||
return c.rawConn
|
return c.rawConn
|
||||||
|
Reference in New Issue
Block a user