Merge pull request #167 from iamqizhao/master

Refactor tlsCreds to allow users pass in a tls.Config.
This commit is contained in:
Qi Zhao
2015-04-15 15:49:56 -07:00
2 changed files with 26 additions and 49 deletions

View File

@ -86,21 +86,10 @@ type TransportAuthenticator interface {
Credentials Credentials
} }
// tlsCreds is the credentials required for authenticating a connection. // tlsCreds is the credentials required for authenticating a connection using TLS.
type tlsCreds struct { type tlsCreds struct {
// serverName is used to verify the hostname on the returned // TLS configuration
// certificates. It is also included in the client's handshake config tls.Config
// to support virtual hosting. This is optional. If it is not
// set gRPC internals will use the dialing address instead.
serverName string
// rootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If rootCAs is nil, tls uses the host's root CA set.
rootCAs *x509.CertPool
// certificates contains one or more certificate chains
// to present to the other side of the connection.
// Server configurations must include at least one certificate.
certificates []tls.Certificate
} }
// GetRequestMetadata returns nil, nil since TLS credentials does not have // GetRequestMetadata returns nil, nil since TLS credentials does not have
@ -110,18 +99,13 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
} }
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) { func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
name := c.serverName if c.config.ServerName == "" {
if name == "" { c.config.ServerName, _, err = net.SplitHostPort(addr)
name, _, err = net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("credentials: failed to parse server address %v", err) return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
} }
} }
return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ return tls.DialWithDialer(dialer, "tcp", addr, &c.config)
RootCAs: c.rootCAs,
NextProtos: alpnProtoStr,
ServerName: name,
})
} }
// Dial connects to addr and performs TLS handshake. // Dial connects to addr and performs TLS handshake.
@ -129,21 +113,21 @@ func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
return c.DialWithDialer(new(net.Dialer), network, addr) return c.DialWithDialer(new(net.Dialer), network, addr)
} }
// NewListener creates a net.Listener with a TLS configuration constructed // NewListener creates a net.Listener using the information in tlsCreds.
// from the information in tlsCreds.
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
return tls.NewListener(lis, &tls.Config{ return tls.NewListener(lis, &c.config)
Certificates: c.certificates, }
NextProtos: alpnProtoStr,
}) // NewTLS uses c to construct a TransportAuthenticator based on TLS.
func NewTLS(c *tls.Config) TransportAuthenticator {
tc := &tlsCreds{*c}
tc.config.NextProtos = alpnProtoStr
return tc
} }
// NewClientTLSFromCert constructs a TLS from the input certificate for client. // NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
return &tlsCreds{ return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
serverName: serverName,
rootCAs: cp,
}
} }
// NewClientTLSFromFile constructs a TLS from the input certificate file for client. // NewClientTLSFromFile constructs a TLS from the input certificate file for client.
@ -156,17 +140,12 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
if !cp.AppendCertsFromPEM(b) { if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("credentials: failed to append certificates") return nil, fmt.Errorf("credentials: failed to append certificates")
} }
return &tlsCreds{ return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
serverName: serverName,
rootCAs: cp,
}, nil
} }
// NewServerTLSFromCert constructs a TLS from the input certificate for server. // NewServerTLSFromCert constructs a TLS from the input certificate for server.
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
return &tlsCreds{ return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
certificates: []tls.Certificate{*cert},
}
} }
// NewServerTLSFromFile constructs a TLS from the input certificate file and key // NewServerTLSFromFile constructs a TLS from the input certificate file and key
@ -176,9 +155,7 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tlsCreds{ return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
certificates: []tls.Certificate{cert},
}, nil
} }
// TokenSource supplies credentials from an oauth2.TokenSource. // TokenSource supplies credentials from an oauth2.TokenSource.