Merge pull request #167 from iamqizhao/master
Refactor tlsCreds to allow users pass in a tls.Config.
This commit is contained in:
@ -86,21 +86,10 @@ type TransportAuthenticator interface {
|
||||
Credentials
|
||||
}
|
||||
|
||||
// tlsCreds is the credentials required for authenticating a connection.
|
||||
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
||||
type tlsCreds struct {
|
||||
// serverName is used to verify the hostname on the returned
|
||||
// certificates. It is also included in the client's handshake
|
||||
// to support virtual hosting. This is optional. If it is not
|
||||
// set gRPC internals will use the dialing address instead.
|
||||
serverName string
|
||||
// rootCAs defines the set of root certificate authorities
|
||||
// that clients use when verifying server certificates.
|
||||
// If rootCAs is nil, tls uses the host's root CA set.
|
||||
rootCAs *x509.CertPool
|
||||
// certificates contains one or more certificate chains
|
||||
// to present to the other side of the connection.
|
||||
// Server configurations must include at least one certificate.
|
||||
certificates []tls.Certificate
|
||||
// TLS configuration
|
||||
config tls.Config
|
||||
}
|
||||
|
||||
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
||||
@ -110,18 +99,13 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
|
||||
}
|
||||
|
||||
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
||||
name := c.serverName
|
||||
if name == "" {
|
||||
name, _, err = net.SplitHostPort(addr)
|
||||
if c.config.ServerName == "" {
|
||||
c.config.ServerName, _, err = net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
||||
}
|
||||
}
|
||||
return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
|
||||
RootCAs: c.rootCAs,
|
||||
NextProtos: alpnProtoStr,
|
||||
ServerName: name,
|
||||
})
|
||||
return tls.DialWithDialer(dialer, "tcp", addr, &c.config)
|
||||
}
|
||||
|
||||
// Dial connects to addr and performs TLS handshake.
|
||||
@ -129,21 +113,21 @@ func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
|
||||
return c.DialWithDialer(new(net.Dialer), network, addr)
|
||||
}
|
||||
|
||||
// NewListener creates a net.Listener with a TLS configuration constructed
|
||||
// from the information in tlsCreds.
|
||||
// NewListener creates a net.Listener using the information in tlsCreds.
|
||||
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
|
||||
return tls.NewListener(lis, &tls.Config{
|
||||
Certificates: c.certificates,
|
||||
NextProtos: alpnProtoStr,
|
||||
})
|
||||
return tls.NewListener(lis, &c.config)
|
||||
}
|
||||
|
||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
||||
func NewTLS(c *tls.Config) TransportAuthenticator {
|
||||
tc := &tlsCreds{*c}
|
||||
tc.config.NextProtos = alpnProtoStr
|
||||
return tc
|
||||
}
|
||||
|
||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
|
||||
return &tlsCreds{
|
||||
serverName: serverName,
|
||||
rootCAs: cp,
|
||||
}
|
||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
||||
}
|
||||
|
||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||
@ -156,17 +140,12 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
|
||||
if !cp.AppendCertsFromPEM(b) {
|
||||
return nil, fmt.Errorf("credentials: failed to append certificates")
|
||||
}
|
||||
return &tlsCreds{
|
||||
serverName: serverName,
|
||||
rootCAs: cp,
|
||||
}, nil
|
||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
|
||||
}
|
||||
|
||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
|
||||
return &tlsCreds{
|
||||
certificates: []tls.Certificate{*cert},
|
||||
}
|
||||
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||
}
|
||||
|
||||
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
||||
@ -176,9 +155,7 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tlsCreds{
|
||||
certificates: []tls.Certificate{cert},
|
||||
}, nil
|
||||
return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
|
||||
}
|
||||
|
||||
// TokenSource supplies credentials from an oauth2.TokenSource.
|
||||
|
Reference in New Issue
Block a user