From dfe197d91f04678d0db17dc87bc2643ffff0ccc0 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 21 Apr 2015 17:22:15 -0700 Subject: [PATCH] remove dialing work from TransportAuthenticator --- credentials/credentials.go | 13 ++++--------- transport/http2_client.go | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index aeab6737..7310bec8 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -73,10 +73,9 @@ type Credentials interface { // TransportAuthenticator defines the common interface all supported transport // authentication protocols (e.g., TLS, SSL) must implement. type TransportAuthenticator interface { - // Dial connects to the given network address using dialer and then - // does the authentication handshake specified by the corresponding - // authentication protocol. - Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) + // Handshake does the authentication handshake specified by the corresponding + // authentication protocol on the given rawConn. + Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) // NewListener creates a listener which accepts connections with requested // authentication handshake. NewListener(lis net.Listener) net.Listener @@ -101,7 +100,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" } func (timeoutError) Timeout() bool { return true } func (timeoutError) Temporary() bool { return true } -func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) { +func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { // borrow some code from tls.DialWithDialer var errChannel chan error if timeout != 0 { @@ -110,10 +109,6 @@ func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Con errChannel <- timeoutError{} }) } - rawConn, err := dialer(addr, timeout) - if err != nil { - return nil, err - } if c.config.ServerName == "" { colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { diff --git a/transport/http2_client.go b/transport/http2_client.go index 710908dd..301c7091 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -104,11 +104,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return net.DialTimeout("tcp", addr, timeout) } } - var ( - connErr error - conn net.Conn - ) scheme := "http" + startT := time.Now() + timeout := opts.Timeout + conn, connErr := opts.Dialer(addr, timeout) + if connErr != nil { + return nil, ConnectionErrorf("transport: %v", connErr) + } for _, c := range opts.AuthOptions { if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https" @@ -116,13 +118,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e // multiple ones provided. Revisit this if it is not appropriate. Probably // place the ClientTransport construction into a separate function to make // things clear. - conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout) + if timeout > 0 { + timeout = opts.Timeout - time.Since(startT) + } + conn, connErr = ccreds.Handshake(addr, conn, timeout) break } } - if scheme == "http" { - conn, connErr = opts.Dialer(addr, opts.Timeout) - } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) }