diff --git a/clientconn.go b/clientconn.go index 3206d674..214fb900 100644 --- a/clientconn.go +++ b/clientconn.go @@ -196,7 +196,7 @@ func WithTimeout(d time.Duration) DialOption { } // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. -func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { +func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption { return func(o *dialOptions) { o.copts.Dialer = f } @@ -361,11 +361,11 @@ func (cc *ClientConn) lbWatcher() { func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac := &addrConn{ - cc: cc, - addr: addr, - dopts: cc.dopts, - shutdownChan: make(chan struct{}), + cc: cc, + addr: addr, + dopts: cc.dopts, } + ac.dopts.copts.Cancel = make(chan struct{}) if EnableTracing { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } @@ -468,11 +468,10 @@ func (cc *ClientConn) Close() error { // addrConn is a network connection to a given address. type addrConn struct { - cc *ClientConn - addr Address - dopts dialOptions - shutdownChan chan struct{} - events trace.EventLog + cc *ClientConn + addr Address + dopts dialOptions + events trace.EventLog mu sync.Mutex state ConnectivityState @@ -558,12 +557,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { t.Close() } sleepTime := ac.dopts.bs.backoff(retries) - ac.dopts.copts.Timeout = sleepTime + copts := ac.dopts.copts + copts.Timeout = sleepTime if sleepTime < minConnectTimeout { - ac.dopts.copts.Timeout = minConnectTimeout + copts.Timeout = minConnectTimeout } connectTime := time.Now() - newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) + newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts) if err != nil { ac.mu.Lock() if ac.state == Shutdown { @@ -586,7 +586,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { closeTransport = false select { case <-time.After(sleepTime): - case <-ac.shutdownChan: + case <-ac.dopts.copts.Cancel: } retries++ grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) @@ -621,9 +621,9 @@ func (ac *addrConn) transportMonitor() { t := ac.transport ac.mu.Unlock() select { - // shutdownChan is needed to detect the teardown when + // Cancel is needed to detect the teardown when // the addrConn is idle (i.e., no RPC in flight). - case <-ac.shutdownChan: + case <-ac.dopts.copts.Cancel: return case <-t.GoAway(): ac.tearDown(errConnDrain) @@ -724,8 +724,8 @@ func (ac *addrConn) tearDown(err error) { if ac.transport != nil && err != errConnDrain { ac.transport.Close() } - if ac.shutdownChan != nil { - close(ac.shutdownChan) + if ac.dopts.copts.Cancel != nil { + close(ac.dopts.copts.Cancel) } return } diff --git a/test/end2end_test.go b/test/end2end_test.go index cdbc4c55..81bc1559 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -300,39 +300,35 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" -func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("unix", addr, timeout) -} - type env struct { name string network string // The type of network such as tcp, unix, etc. - dialer func(addr string, timeout time.Duration) (net.Conn, error) security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS } func (e env) runnable() bool { - if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") { + if runtime.GOOS == "windows" && e.network == "unix" { return false } return true } -func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { - if e.dialer != nil { - return e.dialer - } - return func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } +func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) { + // NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this + // to be written as + // + // `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)` + // + // but that would break compatibility with earlier Go versions. + return net.DialTimeout(e.network, addr, timeout) } var ( tcpClearEnv = env{name: "tcp-clear", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} - unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} - unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} + unixClearEnv = env{name: "unix-clear", network: "unix"} + unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} ) @@ -515,9 +511,7 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - var c net.Conn - var err error - c, err = te.e.getDialer()(te.srvAddr, 10*time.Second) + c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil) if err != nil { te.t.Fatal(err) } diff --git a/transport/go16.go b/transport/go16.go new file mode 100644 index 00000000..c0d051ef --- /dev/null +++ b/transport/go16.go @@ -0,0 +1,45 @@ +// +build go1.6 + +/* + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + "time" +) + +// newDialer constructs a net.Dialer. +func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer { + return &net.Dialer{Cancel: cancel, Timeout: timeout} +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 41c0acf0..bcfcdf0a 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -107,20 +107,21 @@ type http2Client struct { prevGoAwayID uint32 } +func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) { + if fn != nil { + return fn(addr, timeout, cancel) + } + return newDialer(timeout, cancel).Dial("tcp", addr) +} + // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { - if opts.Dialer == nil { - // Set the default Dialer. - opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } - } +func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" startT := time.Now() timeout := opts.Timeout - conn, connErr := opts.Dialer(addr, timeout) + conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel) if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } diff --git a/transport/pre_go16.go b/transport/pre_go16.go new file mode 100644 index 00000000..126bfbd8 --- /dev/null +++ b/transport/pre_go16.go @@ -0,0 +1,45 @@ +// +build !go1.6 + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + "time" +) + +// newDialer constructs a net.Dialer. +func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer { + return &net.Dialer{Timeout: timeout} +} diff --git a/transport/transport.go b/transport/transport.go index 9dade654..86c8fcd6 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -353,8 +353,10 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI type ConnectOptions struct { // UserAgent is the application user agent. UserAgent string + // Cancel is closed to indicate that dialing should be cancelled. + Cancel chan struct{} // Dialer specifies how to dial a network address. - Dialer func(string, time.Duration) (net.Conn, error) + Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error) // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection. @@ -365,7 +367,7 @@ type ConnectOptions struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { +func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) { return newHTTP2Client(target, opts) } diff --git a/transport/transport_test.go b/transport/transport_test.go index 047e6543..6f9cc50c 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -220,7 +220,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client ct ClientTransport connErr error ) - ct, connErr = NewClientTransport(addr, &ConnectOptions{}) + ct, connErr = NewClientTransport(addr, ConnectOptions{}) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) }