diff --git a/credentials/credentials.go b/credentials/credentials.go index 576cf62e..0d3e3e72 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -86,21 +86,10 @@ type TransportAuthenticator interface { Credentials } -// tlsCreds is the credentials required for authenticating a connection. +// tlsCreds is the credentials required for authenticating a connection using TLS. type tlsCreds struct { - // serverName is used to verify the hostname on the returned - // certificates. It is also included in the client's handshake - // to support virtual hosting. This is optional. If it is not - // set gRPC internals will use the dialing address instead. - serverName string - // rootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If rootCAs is nil, tls uses the host's root CA set. - rootCAs *x509.CertPool - // certificates contains one or more certificate chains - // to present to the other side of the connection. - // Server configurations must include at least one certificate. - certificates []tls.Certificate + // TLS configuration + config tls.Config } // GetRequestMetadata returns nil, nil since TLS credentials does not have @@ -110,18 +99,13 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e } func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) { - name := c.serverName - if name == "" { - name, _, err = net.SplitHostPort(addr) + if c.config.ServerName == "" { + c.config.ServerName, _, err = net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("credentials: failed to parse server address %v", err) } } - return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ - RootCAs: c.rootCAs, - NextProtos: alpnProtoStr, - ServerName: name, - }) + return tls.DialWithDialer(dialer, "tcp", addr, &c.config) } // Dial connects to addr and performs TLS handshake. @@ -129,21 +113,21 @@ func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) { return c.DialWithDialer(new(net.Dialer), network, addr) } -// NewListener creates a net.Listener with a TLS configuration constructed -// from the information in tlsCreds. +// NewListener creates a net.Listener using the information in tlsCreds. func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { - return tls.NewListener(lis, &tls.Config{ - Certificates: c.certificates, - NextProtos: alpnProtoStr, - }) + return tls.NewListener(lis, &c.config) +} + +// NewTLS uses c to construct a TransportAuthenticator based on TLS. +func NewTLS(c *tls.Config) TransportAuthenticator { + tc := &tlsCreds{*c} + tc.config.NextProtos = alpnProtoStr + return tc } // NewClientTLSFromCert constructs a TLS from the input certificate for client. func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { - return &tlsCreds{ - serverName: serverName, - rootCAs: cp, - } + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. @@ -156,17 +140,12 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, if !cp.AppendCertsFromPEM(b) { return nil, fmt.Errorf("credentials: failed to append certificates") } - return &tlsCreds{ - serverName: serverName, - rootCAs: cp, - }, nil + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil } // NewServerTLSFromCert constructs a TLS from the input certificate for server. func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { - return &tlsCreds{ - certificates: []tls.Certificate{*cert}, - } + return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}}) } // NewServerTLSFromFile constructs a TLS from the input certificate file and key @@ -176,9 +155,7 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err if err != nil { return nil, err } - return &tlsCreds{ - certificates: []tls.Certificate{cert}, - }, nil + return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil } // TokenSource supplies credentials from an oauth2.TokenSource. diff --git a/transport/http2_client.go b/transport/http2_client.go index d01fb892..baeee8f6 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -479,13 +479,13 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { // Window updates will deliver to the controller for sending when // the cumulative quota exceeds the corresponding threshold. func (t *http2Client) updateWindow(s *Stream, n uint32) { - swu, cwu := s.fc.onRead(n) - if swu > 0 { - t.controlBuf.put(&windowUpdate{s.id, swu}) - } - if cwu > 0 { - t.controlBuf.put(&windowUpdate{0, cwu}) - } + swu, cwu := s.fc.onRead(n) + if swu > 0 { + t.controlBuf.put(&windowUpdate{s.id, swu}) + } + if cwu > 0 { + t.controlBuf.put(&windowUpdate{0, cwu}) + } } func (t *http2Client) handleData(f *http2.DataFrame) {