remove WithNetwork and add WithDialer to have more flexibility on dialing
This commit is contained in:
@ -37,6 +37,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -96,11 +97,11 @@ func WithTimeout(d time.Duration) DialOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithNetwork returns a DialOption that specifies the network on which
|
// WithDialer returns a DialOption that defines a function which takes an
|
||||||
// the connection will be established.
|
// address and turns it into a net.Conn.
|
||||||
func WithNetwork(network string) DialOption {
|
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.Network = network
|
o.copts.Dialer = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,24 +118,11 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&cc.dopts)
|
opt(&cc.dopts)
|
||||||
}
|
}
|
||||||
// Validate the network type
|
colonPos := strings.LastIndex(target, ":")
|
||||||
switch cc.dopts.copts.Network {
|
if colonPos == -1 {
|
||||||
case "":
|
colonPos = len(target)
|
||||||
cc.dopts.copts.Network = "tcp" // Set the default
|
|
||||||
case "tcp", "tcp4", "tcp6", "unix":
|
|
||||||
default:
|
|
||||||
return nil, net.UnknownNetworkError(cc.dopts.copts.Network)
|
|
||||||
}
|
|
||||||
cc.authority = target
|
|
||||||
// Format target for tcp.
|
|
||||||
if cc.dopts.copts.Network != "unix" {
|
|
||||||
// format target for tcp.
|
|
||||||
var err error
|
|
||||||
cc.authority, _, err = net.SplitHostPort(target)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
cc.authority = target[:colonPos]
|
||||||
if cc.dopts.codec == nil {
|
if cc.dopts.codec == nil {
|
||||||
// Set the default codec.
|
// Set the default codec.
|
||||||
cc.dopts.codec = protoCodec{}
|
cc.dopts.codec = protoCodec{}
|
||||||
|
@ -43,6 +43,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@ -71,15 +73,10 @@ type Credentials interface {
|
|||||||
// TransportAuthenticator defines the common interface all supported transport
|
// TransportAuthenticator defines the common interface all supported transport
|
||||||
// authentication protocols (e.g., TLS, SSL) must implement.
|
// authentication protocols (e.g., TLS, SSL) must implement.
|
||||||
type TransportAuthenticator interface {
|
type TransportAuthenticator interface {
|
||||||
// Dial connects to the given network address using net.Dial and then
|
// Dial connects to the given network address using dialer and then
|
||||||
// does the authentication handshake specified by the corresponding
|
// does the authentication handshake specified by the corresponding
|
||||||
// authentication protocol.
|
// authentication protocol.
|
||||||
Dial(network, addr string) (net.Conn, error)
|
Dial(dialer func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error)
|
||||||
// DialWithDialer connects to the given network address using
|
|
||||||
// dialer.Dial does the authentication handshake specified by the
|
|
||||||
// corresponding authentication protocol. Any timeout or deadline
|
|
||||||
// given in the dialer apply to connection and handshake as a whole.
|
|
||||||
DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error)
|
|
||||||
// NewListener creates a listener which accepts connections with requested
|
// NewListener creates a listener which accepts connections with requested
|
||||||
// authentication handshake.
|
// authentication handshake.
|
||||||
NewListener(lis net.Listener) net.Listener
|
NewListener(lis net.Listener) net.Listener
|
||||||
@ -98,19 +95,46 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
type timeoutError struct{}
|
||||||
if c.config.ServerName == "" {
|
|
||||||
c.config.ServerName, _, err = net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tls.DialWithDialer(dialer, network, addr, &c.config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to addr and performs TLS handshake.
|
func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
||||||
func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
|
func (timeoutError) Timeout() bool { return true }
|
||||||
return c.DialWithDialer(new(net.Dialer), network, addr)
|
func (timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
func (c *tlsCreds) Dial(dialer func(addr string, timeout time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
// borrow some code from tls.DialWithDialer
|
||||||
|
var errChannel chan error
|
||||||
|
if timeout != 0 {
|
||||||
|
errChannel = make(chan error, 2)
|
||||||
|
time.AfterFunc(timeout, func() {
|
||||||
|
errChannel <- timeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
rawConn, err := dialer(addr, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if c.config.ServerName == "" {
|
||||||
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
|
if colonPos == -1 {
|
||||||
|
colonPos = len(addr)
|
||||||
|
}
|
||||||
|
c.config.ServerName = addr[:colonPos]
|
||||||
|
}
|
||||||
|
conn := tls.Client(rawConn, &c.config)
|
||||||
|
if timeout == 0 {
|
||||||
|
err = conn.Handshake()
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
errChannel <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
err = <-errChannel
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
rawConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewListener creates a net.Listener using the information in tlsCreds.
|
// NewListener creates a net.Listener using the information in tlsCreds.
|
||||||
|
@ -266,16 +266,21 @@ func TestReconnectTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return net.DialTimeout("unix", addr, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
type env struct {
|
type env struct {
|
||||||
network string // The type of network such as tcp, unix, etc.
|
network string // The type of network such as tcp, unix, etc.
|
||||||
|
dialer func(addr string, timeout time.Duration) (net.Conn, error)
|
||||||
security string // The security protocol such as TLS, SSH, etc.
|
security string // The security protocol such as TLS, SSH, etc.
|
||||||
}
|
}
|
||||||
|
|
||||||
func listTestEnv() []env {
|
func listTestEnv() []env {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return []env{env{"tcp", ""}, env{"tcp", "tls"}}
|
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}}
|
||||||
}
|
}
|
||||||
return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}}
|
return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}, env{"unix", unixDialer, ""}, env{"unix", unixDialer, "tls"}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
||||||
@ -315,9 +320,9 @@ func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to create credentials %v", err)
|
log.Fatalf("Failed to create credentials %v", err)
|
||||||
}
|
}
|
||||||
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network))
|
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer))
|
||||||
} else {
|
} else {
|
||||||
cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network))
|
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer))
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Dial(%q) = %v", addr, err)
|
log.Fatalf("Dial(%q) = %v", addr, err)
|
||||||
|
@ -98,6 +98,12 @@ type http2Client struct {
|
|||||||
// and starts to receive messages on it. Non-nil error returns if construction
|
// and starts to receive messages on it. Non-nil error returns if construction
|
||||||
// fails.
|
// fails.
|
||||||
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
||||||
|
if opts.Dialer == nil {
|
||||||
|
// Set the default Dialer.
|
||||||
|
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return net.DialTimeout("tcp", addr, timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
var (
|
var (
|
||||||
connErr error
|
connErr error
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
@ -110,12 +116,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||||
// place the ClientTransport construction into a separate function to make
|
// place the ClientTransport construction into a separate function to make
|
||||||
// things clear.
|
// things clear.
|
||||||
conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr)
|
conn, connErr = ccreds.Dial(opts.Dialer, addr, opts.Timeout)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if scheme == "http" {
|
if scheme == "http" {
|
||||||
conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout)
|
conn, connErr = opts.Dialer(addr, opts.Timeout)
|
||||||
}
|
}
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
|
@ -315,9 +315,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
|||||||
|
|
||||||
// ConnectOptions covers all relevant options for dialing a server.
|
// ConnectOptions covers all relevant options for dialing a server.
|
||||||
type ConnectOptions struct {
|
type ConnectOptions struct {
|
||||||
// Network indicates the type of network where the connection is established.
|
Dialer func(string, time.Duration) (net.Conn, error)
|
||||||
// Known networks are "tcp", "tcp4", "tcp6", "unix"
|
|
||||||
Network string
|
|
||||||
AuthOptions []credentials.Credentials
|
AuthOptions []credentials.Credentials
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user