remove dialing work from TransportAuthenticator
This commit is contained in:
@ -73,10 +73,9 @@ type Credentials interface {
|
|||||||
// TransportAuthenticator defines the common interface all supported transport
|
// TransportAuthenticator defines the common interface all supported transport
|
||||||
// authentication protocols (e.g., TLS, SSL) must implement.
|
// authentication protocols (e.g., TLS, SSL) must implement.
|
||||||
type TransportAuthenticator interface {
|
type TransportAuthenticator interface {
|
||||||
// Dial connects to the given network address using dialer and then
|
// Handshake does the authentication handshake specified by the corresponding
|
||||||
// does the authentication handshake specified by the corresponding
|
// authentication protocol on the given rawConn.
|
||||||
// authentication protocol.
|
Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
|
||||||
Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error)
|
|
||||||
// NewListener creates a listener which accepts connections with requested
|
// NewListener creates a listener which accepts connections with requested
|
||||||
// authentication handshake.
|
// authentication handshake.
|
||||||
NewListener(lis net.Listener) net.Listener
|
NewListener(lis net.Listener) net.Listener
|
||||||
@ -101,7 +100,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
|||||||
func (timeoutError) Timeout() bool { return true }
|
func (timeoutError) Timeout() bool { return true }
|
||||||
func (timeoutError) Temporary() bool { return true }
|
func (timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) {
|
func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
|
||||||
// borrow some code from tls.DialWithDialer
|
// borrow some code from tls.DialWithDialer
|
||||||
var errChannel chan error
|
var errChannel chan error
|
||||||
if timeout != 0 {
|
if timeout != 0 {
|
||||||
@ -110,10 +109,6 @@ func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Con
|
|||||||
errChannel <- timeoutError{}
|
errChannel <- timeoutError{}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
rawConn, err := dialer(addr, timeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if c.config.ServerName == "" {
|
if c.config.ServerName == "" {
|
||||||
colonPos := strings.LastIndex(addr, ":")
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
if colonPos == -1 {
|
if colonPos == -1 {
|
||||||
|
@ -104,11 +104,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
return net.DialTimeout("tcp", addr, timeout)
|
return net.DialTimeout("tcp", addr, timeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var (
|
|
||||||
connErr error
|
|
||||||
conn net.Conn
|
|
||||||
)
|
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
|
startT := time.Now()
|
||||||
|
timeout := opts.Timeout
|
||||||
|
conn, connErr := opts.Dialer(addr, timeout)
|
||||||
|
if connErr != nil {
|
||||||
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
|
}
|
||||||
for _, c := range opts.AuthOptions {
|
for _, c := range opts.AuthOptions {
|
||||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
@ -116,13 +118,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||||
// place the ClientTransport construction into a separate function to make
|
// place the ClientTransport construction into a separate function to make
|
||||||
// things clear.
|
// things clear.
|
||||||
conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout)
|
if timeout > 0 {
|
||||||
|
timeout = opts.Timeout - time.Since(startT)
|
||||||
|
}
|
||||||
|
conn, connErr = ccreds.Handshake(addr, conn, timeout)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if scheme == "http" {
|
|
||||||
conn, connErr = opts.Dialer(addr, opts.Timeout)
|
|
||||||
}
|
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user