remove WithNetwork and add WithDialer to have more flexibility on dialing
This commit is contained in:
@ -37,6 +37,7 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -96,11 +97,11 @@ func WithTimeout(d time.Duration) DialOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithNetwork returns a DialOption that specifies the network on which
|
||||
// the connection will be established.
|
||||
func WithNetwork(network string) DialOption {
|
||||
// WithDialer returns a DialOption that defines a function which takes an
|
||||
// address and turns it into a net.Conn.
|
||||
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.Network = network
|
||||
o.copts.Dialer = f
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,24 +118,11 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
for _, opt := range opts {
|
||||
opt(&cc.dopts)
|
||||
}
|
||||
// Validate the network type
|
||||
switch cc.dopts.copts.Network {
|
||||
case "":
|
||||
cc.dopts.copts.Network = "tcp" // Set the default
|
||||
case "tcp", "tcp4", "tcp6", "unix":
|
||||
default:
|
||||
return nil, net.UnknownNetworkError(cc.dopts.copts.Network)
|
||||
}
|
||||
cc.authority = target
|
||||
// Format target for tcp.
|
||||
if cc.dopts.copts.Network != "unix" {
|
||||
// format target for tcp.
|
||||
var err error
|
||||
cc.authority, _, err = net.SplitHostPort(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
colonPos := strings.LastIndex(target, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(target)
|
||||
}
|
||||
cc.authority = target[:colonPos]
|
||||
if cc.dopts.codec == nil {
|
||||
// Set the default codec.
|
||||
cc.dopts.codec = protoCodec{}
|
||||
|
@ -43,6 +43,8 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
@ -71,15 +73,10 @@ type Credentials interface {
|
||||
// TransportAuthenticator defines the common interface all supported transport
|
||||
// authentication protocols (e.g., TLS, SSL) must implement.
|
||||
type TransportAuthenticator interface {
|
||||
// Dial connects to the given network address using net.Dial and then
|
||||
// Dial connects to the given network address using dialer and then
|
||||
// does the authentication handshake specified by the corresponding
|
||||
// authentication protocol.
|
||||
Dial(network, addr string) (net.Conn, error)
|
||||
// DialWithDialer connects to the given network address using
|
||||
// dialer.Dial does the authentication handshake specified by the
|
||||
// corresponding authentication protocol. Any timeout or deadline
|
||||
// given in the dialer apply to connection and handshake as a whole.
|
||||
DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error)
|
||||
Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error)
|
||||
// NewListener creates a listener which accepts connections with requested
|
||||
// authentication handshake.
|
||||
NewListener(lis net.Listener) net.Listener
|
||||
@ -98,19 +95,46 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
||||
if c.config.ServerName == "" {
|
||||
c.config.ServerName, _, err = net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
||||
}
|
||||
}
|
||||
return tls.DialWithDialer(dialer, network, addr, &c.config)
|
||||
}
|
||||
type timeoutError struct{}
|
||||
|
||||
// Dial connects to addr and performs TLS handshake.
|
||||
func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
|
||||
return c.DialWithDialer(new(net.Dialer), network, addr)
|
||||
func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
||||
func (timeoutError) Timeout() bool { return true }
|
||||
func (timeoutError) Temporary() bool { return true }
|
||||
|
||||
func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) {
|
||||
// borrow some code from tls.DialWithDialer
|
||||
var errChannel chan error
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- timeoutError{}
|
||||
})
|
||||
}
|
||||
rawConn, err := dialer(addr, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.config.ServerName == "" {
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
c.config.ServerName = addr[:colonPos]
|
||||
}
|
||||
conn := tls.Client(rawConn, &c.config)
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
err = <-errChannel
|
||||
}
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewListener creates a net.Listener using the information in tlsCreds.
|
||||
|
@ -266,16 +266,21 @@ func TestReconnectTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
return net.DialTimeout("unix", addr, timeout)
|
||||
}
|
||||
|
||||
type env struct {
|
||||
network string // The type of network such as tcp, unix, etc.
|
||||
dialer func(addr string, timeout time.Duration) (net.Conn, error)
|
||||
security string // The security protocol such as TLS, SSH, etc.
|
||||
}
|
||||
|
||||
func listTestEnv() []env {
|
||||
if runtime.GOOS == "windows" {
|
||||
return []env{env{"tcp", ""}, env{"tcp", "tls"}}
|
||||
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}}
|
||||
}
|
||||
return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}}
|
||||
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}, env{"unix", unixDialer, ""}, env{"unix", unixDialer, "tls"}}
|
||||
}
|
||||
|
||||
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
||||
@ -315,9 +320,9 @@ func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network))
|
||||
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer))
|
||||
} else {
|
||||
cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network))
|
||||
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer))
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Dial(%q) = %v", addr, err)
|
||||
|
@ -98,6 +98,12 @@ type http2Client struct {
|
||||
// and starts to receive messages on it. Non-nil error returns if construction
|
||||
// fails.
|
||||
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
||||
if opts.Dialer == nil {
|
||||
// Set the default Dialer.
|
||||
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
return net.DialTimeout("tcp", addr, timeout)
|
||||
}
|
||||
}
|
||||
var (
|
||||
connErr error
|
||||
conn net.Conn
|
||||
@ -110,12 +116,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||
// place the ClientTransport construction into a separate function to make
|
||||
// things clear.
|
||||
conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr)
|
||||
conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout)
|
||||
break
|
||||
}
|
||||
}
|
||||
if scheme == "http" {
|
||||
conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout)
|
||||
conn, connErr = opts.Dialer(addr, opts.Timeout)
|
||||
}
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
|
@ -315,9 +315,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
||||
|
||||
// ConnectOptions covers all relevant options for dialing a server.
|
||||
type ConnectOptions struct {
|
||||
// Network indicates the type of network where the connection is established.
|
||||
// Known networks are "tcp", "tcp4", "tcp6", "unix"
|
||||
Network string
|
||||
Dialer func(string, time.Duration) (net.Conn, error)
|
||||
AuthOptions []credentials.Credentials
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
Reference in New Issue
Block a user