remove WithNetwork and add WithDialer to have more flexibility on dialing

This commit is contained in:
iamqizhao
2015-04-21 16:19:29 -07:00
parent c0ee2e6ba1
commit 2cf2d0871b
5 changed files with 70 additions and 49 deletions

View File

@ -37,6 +37,7 @@ import (
"errors"
"log"
"net"
"strings"
"sync"
"time"
@ -96,11 +97,11 @@ func WithTimeout(d time.Duration) DialOption {
}
}
// WithNetwork returns a DialOption that specifies the network on which
// the connection will be established.
func WithNetwork(network string) DialOption {
// WithDialer returns a DialOption that defines a function which takes an
// address and turns it into a net.Conn.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Network = network
o.copts.Dialer = f
}
}
@ -117,24 +118,11 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
for _, opt := range opts {
opt(&cc.dopts)
}
// Validate the network type
switch cc.dopts.copts.Network {
case "":
cc.dopts.copts.Network = "tcp" // Set the default
case "tcp", "tcp4", "tcp6", "unix":
default:
return nil, net.UnknownNetworkError(cc.dopts.copts.Network)
}
cc.authority = target
// Format target for tcp.
if cc.dopts.copts.Network != "unix" {
// format target for tcp.
var err error
cc.authority, _, err = net.SplitHostPort(target)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
if cc.dopts.codec == nil {
// Set the default codec.
cc.dopts.codec = protoCodec{}

View File

@ -43,6 +43,8 @@ import (
"fmt"
"io/ioutil"
"net"
"strings"
"time"
"golang.org/x/net/context"
"golang.org/x/oauth2"
@ -71,15 +73,10 @@ type Credentials interface {
// TransportAuthenticator defines the common interface all supported transport
// authentication protocols (e.g., TLS, SSL) must implement.
type TransportAuthenticator interface {
// Dial connects to the given network address using net.Dial and then
// Dial connects to the given network address using dialer and then
// does the authentication handshake specified by the corresponding
// authentication protocol.
Dial(network, addr string) (net.Conn, error)
// DialWithDialer connects to the given network address using
// dialer.Dial does the authentication handshake specified by the
// corresponding authentication protocol. Any timeout or deadline
// given in the dialer apply to connection and handshake as a whole.
DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error)
Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error)
// NewListener creates a listener which accepts connections with requested
// authentication handshake.
NewListener(lis net.Listener) net.Listener
@ -98,19 +95,46 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
return nil, nil
}
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
if c.config.ServerName == "" {
c.config.ServerName, _, err = net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
}
}
return tls.DialWithDialer(dialer, network, addr, &c.config)
}
type timeoutError struct{}
// Dial connects to addr and performs TLS handshake.
func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
return c.DialWithDialer(new(net.Dialer), network, addr)
func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
rawConn, err := dialer(addr, timeout)
if err != nil {
return nil, err
}
if c.config.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
c.config.ServerName = addr[:colonPos]
}
conn := tls.Client(rawConn, &c.config)
if timeout == 0 {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// NewListener creates a net.Listener using the information in tlsCreds.

View File

@ -266,16 +266,21 @@ func TestReconnectTimeout(t *testing.T) {
}
}
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}
type env struct {
network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc.
}
func listTestEnv() []env {
if runtime.GOOS == "windows" {
return []env{env{"tcp", ""}, env{"tcp", "tls"}}
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}}
}
return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}}
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}, env{"unix", unixDialer, ""}, env{"unix", unixDialer, "tls"}}
}
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
@ -315,9 +320,9 @@ func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
if err != nil {
log.Fatalf("Failed to create credentials %v", err)
}
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network))
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer))
} else {
cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network))
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer))
}
if err != nil {
log.Fatalf("Dial(%q) = %v", addr, err)

View File

@ -98,6 +98,12 @@ type http2Client struct {
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
var (
connErr error
conn net.Conn
@ -110,12 +116,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
// multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make
// things clear.
conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr)
conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout)
break
}
}
if scheme == "http" {
conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout)
conn, connErr = opts.Dialer(addr, opts.Timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)

View File

@ -315,9 +315,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
// ConnectOptions covers all relevant options for dialing a server.
type ConnectOptions struct {
// Network indicates the type of network where the connection is established.
// Known networks are "tcp", "tcp4", "tcp6", "unix"
Network string
Dialer func(string, time.Duration) (net.Conn, error)
AuthOptions []credentials.Credentials
Timeout time.Duration
}