Merge pull request #167 from iamqizhao/master
Refactor tlsCreds to allow users pass in a tls.Config.
This commit is contained in:
@ -86,21 +86,10 @@ type TransportAuthenticator interface {
|
|||||||
Credentials
|
Credentials
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsCreds is the credentials required for authenticating a connection.
|
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
||||||
type tlsCreds struct {
|
type tlsCreds struct {
|
||||||
// serverName is used to verify the hostname on the returned
|
// TLS configuration
|
||||||
// certificates. It is also included in the client's handshake
|
config tls.Config
|
||||||
// to support virtual hosting. This is optional. If it is not
|
|
||||||
// set gRPC internals will use the dialing address instead.
|
|
||||||
serverName string
|
|
||||||
// rootCAs defines the set of root certificate authorities
|
|
||||||
// that clients use when verifying server certificates.
|
|
||||||
// If rootCAs is nil, tls uses the host's root CA set.
|
|
||||||
rootCAs *x509.CertPool
|
|
||||||
// certificates contains one or more certificate chains
|
|
||||||
// to present to the other side of the connection.
|
|
||||||
// Server configurations must include at least one certificate.
|
|
||||||
certificates []tls.Certificate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
||||||
@ -110,18 +99,13 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
||||||
name := c.serverName
|
if c.config.ServerName == "" {
|
||||||
if name == "" {
|
c.config.ServerName, _, err = net.SplitHostPort(addr)
|
||||||
name, _, err = net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
|
return tls.DialWithDialer(dialer, "tcp", addr, &c.config)
|
||||||
RootCAs: c.rootCAs,
|
|
||||||
NextProtos: alpnProtoStr,
|
|
||||||
ServerName: name,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to addr and performs TLS handshake.
|
// Dial connects to addr and performs TLS handshake.
|
||||||
@ -129,21 +113,21 @@ func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
|
|||||||
return c.DialWithDialer(new(net.Dialer), network, addr)
|
return c.DialWithDialer(new(net.Dialer), network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewListener creates a net.Listener with a TLS configuration constructed
|
// NewListener creates a net.Listener using the information in tlsCreds.
|
||||||
// from the information in tlsCreds.
|
|
||||||
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
|
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
|
||||||
return tls.NewListener(lis, &tls.Config{
|
return tls.NewListener(lis, &c.config)
|
||||||
Certificates: c.certificates,
|
}
|
||||||
NextProtos: alpnProtoStr,
|
|
||||||
})
|
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
||||||
|
func NewTLS(c *tls.Config) TransportAuthenticator {
|
||||||
|
tc := &tlsCreds{*c}
|
||||||
|
tc.config.NextProtos = alpnProtoStr
|
||||||
|
return tc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
|
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
|
||||||
return &tlsCreds{
|
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
||||||
serverName: serverName,
|
|
||||||
rootCAs: cp,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||||
@ -156,17 +140,12 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
|
|||||||
if !cp.AppendCertsFromPEM(b) {
|
if !cp.AppendCertsFromPEM(b) {
|
||||||
return nil, fmt.Errorf("credentials: failed to append certificates")
|
return nil, fmt.Errorf("credentials: failed to append certificates")
|
||||||
}
|
}
|
||||||
return &tlsCreds{
|
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
|
||||||
serverName: serverName,
|
|
||||||
rootCAs: cp,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
|
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
|
||||||
return &tlsCreds{
|
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||||
certificates: []tls.Certificate{*cert},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
||||||
@ -176,9 +155,7 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &tlsCreds{
|
return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
|
||||||
certificates: []tls.Certificate{cert},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenSource supplies credentials from an oauth2.TokenSource.
|
// TokenSource supplies credentials from an oauth2.TokenSource.
|
||||||
|
@ -479,13 +479,13 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
|
|||||||
// Window updates will deliver to the controller for sending when
|
// Window updates will deliver to the controller for sending when
|
||||||
// the cumulative quota exceeds the corresponding threshold.
|
// the cumulative quota exceeds the corresponding threshold.
|
||||||
func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
||||||
swu, cwu := s.fc.onRead(n)
|
swu, cwu := s.fc.onRead(n)
|
||||||
if swu > 0 {
|
if swu > 0 {
|
||||||
t.controlBuf.put(&windowUpdate{s.id, swu})
|
t.controlBuf.put(&windowUpdate{s.id, swu})
|
||||||
}
|
}
|
||||||
if cwu > 0 {
|
if cwu > 0 {
|
||||||
t.controlBuf.put(&windowUpdate{0, cwu})
|
t.controlBuf.put(&windowUpdate{0, cwu})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) handleData(f *http2.DataFrame) {
|
func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
|
Reference in New Issue
Block a user