diff --git a/clientconn.go b/clientconn.go index 00a3b8d8..b2ff4292 100644 --- a/clientconn.go +++ b/clientconn.go @@ -37,6 +37,7 @@ import ( "errors" "log" "net" + "strings" "sync" "time" @@ -96,11 +97,11 @@ func WithTimeout(d time.Duration) DialOption { } } -// WithNetwork returns a DialOption that specifies the network on which -// the connection will be established. -func WithNetwork(network string) DialOption { +// WithDialer returns a DialOption that defines a function which takes an +// address and turns it into a net.Conn. +func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { return func(o *dialOptions) { - o.copts.Network = network + o.copts.Dialer = f } } @@ -117,24 +118,11 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { for _, opt := range opts { opt(&cc.dopts) } - // Validate the network type - switch cc.dopts.copts.Network { - case "": - cc.dopts.copts.Network = "tcp" // Set the default - case "tcp", "tcp4", "tcp6", "unix": - default: - return nil, net.UnknownNetworkError(cc.dopts.copts.Network) - } - cc.authority = target - // Format target for tcp. - if cc.dopts.copts.Network != "unix" { - // format target for tcp. - var err error - cc.authority, _, err = net.SplitHostPort(target) - if err != nil { - return nil, err - } + colonPos := strings.LastIndex(target, ":") + if colonPos == -1 { + colonPos = len(target) } + cc.authority = target[:colonPos] if cc.dopts.codec == nil { // Set the default codec. cc.dopts.codec = protoCodec{} diff --git a/credentials/credentials.go b/credentials/credentials.go index fae0a302..aeab6737 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -43,6 +43,8 @@ import ( "fmt" "io/ioutil" "net" + "strings" + "time" "golang.org/x/net/context" "golang.org/x/oauth2" @@ -71,15 +73,10 @@ type Credentials interface { // TransportAuthenticator defines the common interface all supported transport // authentication protocols (e.g., TLS, SSL) must implement. type TransportAuthenticator interface { - // Dial connects to the given network address using net.Dial and then + // Dial connects to the given network address using dialer and then // does the authentication handshake specified by the corresponding // authentication protocol. - Dial(network, addr string) (net.Conn, error) - // DialWithDialer connects to the given network address using - // dialer.Dial does the authentication handshake specified by the - // corresponding authentication protocol. Any timeout or deadline - // given in the dialer apply to connection and handshake as a whole. - DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error) + Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) // NewListener creates a listener which accepts connections with requested // authentication handshake. NewListener(lis net.Listener) net.Listener @@ -98,19 +95,46 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e return nil, nil } -func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) { - if c.config.ServerName == "" { - c.config.ServerName, _, err = net.SplitHostPort(addr) - if err != nil { - return nil, fmt.Errorf("credentials: failed to parse server address %v", err) - } - } - return tls.DialWithDialer(dialer, network, addr, &c.config) -} +type timeoutError struct{} -// Dial connects to addr and performs TLS handshake. -func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) { - return c.DialWithDialer(new(net.Dialer), network, addr) +func (timeoutError) Error() string { return "credentials: Dial timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) { + // borrow some code from tls.DialWithDialer + var errChannel chan error + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + rawConn, err := dialer(addr, timeout) + if err != nil { + return nil, err + } + if c.config.ServerName == "" { + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + c.config.ServerName = addr[:colonPos] + } + conn := tls.Client(rawConn, &c.config) + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + err = <-errChannel + } + if err != nil { + rawConn.Close() + return nil, err + } + return conn, nil } // NewListener creates a net.Listener using the information in tlsCreds. diff --git a/test/end2end_test.go b/test/end2end_test.go index 9e8abf5b..3d41c610 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -266,16 +266,21 @@ func TestReconnectTimeout(t *testing.T) { } } +func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) +} + type env struct { network string // The type of network such as tcp, unix, etc. + dialer func(addr string, timeout time.Duration) (net.Conn, error) security string // The security protocol such as TLS, SSH, etc. } func listTestEnv() []env { if runtime.GOOS == "windows" { - return []env{env{"tcp", ""}, env{"tcp", "tls"}} + return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}} } - return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}} + return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}, env{"unix", unixDialer, ""}, env{"unix", unixDialer, "tls"}} } func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) { @@ -315,9 +320,9 @@ func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) { if err != nil { log.Fatalf("Failed to create credentials %v", err) } - cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network)) + cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer)) } else { - cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network)) + cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer)) } if err != nil { log.Fatalf("Dial(%q) = %v", addr, err) diff --git a/transport/http2_client.go b/transport/http2_client.go index 136debfb..710908dd 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -98,6 +98,12 @@ type http2Client struct { // and starts to receive messages on it. Non-nil error returns if construction // fails. func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { + if opts.Dialer == nil { + // Set the default Dialer. + opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("tcp", addr, timeout) + } + } var ( connErr error conn net.Conn @@ -110,12 +116,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e // multiple ones provided. Revisit this if it is not appropriate. Probably // place the ClientTransport construction into a separate function to make // things clear. - conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr) + conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout) break } } if scheme == "http" { - conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout) + conn, connErr = opts.Dialer(addr, opts.Timeout) } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) diff --git a/transport/transport.go b/transport/transport.go index ebd36291..5dfd89f0 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -315,9 +315,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv // ConnectOptions covers all relevant options for dialing a server. type ConnectOptions struct { - // Network indicates the type of network where the connection is established. - // Known networks are "tcp", "tcp4", "tcp6", "unix" - Network string + Dialer func(string, time.Duration) (net.Conn, error) AuthOptions []credentials.Credentials Timeout time.Duration }