Merge pull request #751 from tamird/dialer-cancel

cancel outgoing net.Dial when ClientConn is closed
This commit is contained in:
Qi Zhao
2016-07-28 11:19:15 -07:00
committed by GitHub
7 changed files with 134 additions and 47 deletions

View File

@ -196,7 +196,7 @@ func WithTimeout(d time.Duration) DialOption {
} }
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = f
} }
@ -364,8 +364,8 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
shutdownChan: make(chan struct{}),
} }
ac.dopts.copts.Cancel = make(chan struct{})
if EnableTracing { if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
@ -471,7 +471,6 @@ type addrConn struct {
cc *ClientConn cc *ClientConn
addr Address addr Address
dopts dialOptions dopts dialOptions
shutdownChan chan struct{}
events trace.EventLog events trace.EventLog
mu sync.Mutex mu sync.Mutex
@ -558,12 +557,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close() t.Close()
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime copts := ac.dopts.copts
copts.Timeout = sleepTime
if sleepTime < minConnectTimeout { if sleepTime < minConnectTimeout {
ac.dopts.copts.Timeout = minConnectTimeout copts.Timeout = minConnectTimeout
} }
connectTime := time.Now() connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts)
if err != nil { if err != nil {
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
@ -586,7 +586,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
closeTransport = false closeTransport = false
select { select {
case <-time.After(sleepTime): case <-time.After(sleepTime):
case <-ac.shutdownChan: case <-ac.dopts.copts.Cancel:
} }
retries++ retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
@ -621,9 +621,9 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // Cancel is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.dopts.copts.Cancel:
return return
case <-t.GoAway(): case <-t.GoAway():
ac.tearDown(errConnDrain) ac.tearDown(errConnDrain)
@ -724,8 +724,8 @@ func (ac *addrConn) tearDown(err error) {
if ac.transport != nil && err != errConnDrain { if ac.transport != nil && err != errConnDrain {
ac.transport.Close() ac.transport.Close()
} }
if ac.shutdownChan != nil { if ac.dopts.copts.Cancel != nil {
close(ac.shutdownChan) close(ac.dopts.copts.Cancel)
} }
return return
} }

View File

@ -300,39 +300,35 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/" const tlsDir = "testdata/"
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}
type env struct { type env struct {
name string name string
network string // The type of network such as tcp, unix, etc. network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc. security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
} }
func (e env) runnable() bool { func (e env) runnable() bool {
if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") { if runtime.GOOS == "windows" && e.network == "unix" {
return false return false
} }
return true return true
} }
func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
if e.dialer != nil { // NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this
return e.dialer // to be written as
} //
return func(addr string, timeout time.Duration) (net.Conn, error) { // `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)`
return net.DialTimeout("tcp", addr, timeout) //
} // but that would break compatibility with earlier Go versions.
return net.DialTimeout(e.network, addr, timeout)
} }
var ( var (
tcpClearEnv = env{name: "tcp-clear", network: "tcp"} tcpClearEnv = env{name: "tcp-clear", network: "tcp"}
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} unixClearEnv = env{name: "unix-clear", network: "unix"}
unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
) )
@ -515,9 +511,7 @@ func (te *test) declareLogNoise(phrases ...string) {
} }
func (te *test) withServerTester(fn func(st *serverTester)) { func (te *test) withServerTester(fn func(st *serverTester)) {
var c net.Conn c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil)
var err error
c, err = te.e.getDialer()(te.srvAddr, 10*time.Second)
if err != nil { if err != nil {
te.t.Fatal(err) te.t.Fatal(err)
} }

45
transport/go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build go1.6
/*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer {
return &net.Dialer{Cancel: cancel, Timeout: timeout}
}

View File

@ -107,20 +107,21 @@ type http2Client struct {
prevGoAwayID uint32 prevGoAwayID uint32
} }
func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
if fn != nil {
return fn(addr, timeout, cancel)
}
return newDialer(timeout, cancel).Dial("tcp", addr)
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
scheme := "http" scheme := "http"
startT := time.Now() startT := time.Now()
timeout := opts.Timeout timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout) conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel)
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }

45
transport/pre_go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer {
return &net.Dialer{Timeout: timeout}
}

View File

@ -353,8 +353,10 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI
type ConnectOptions struct { type ConnectOptions struct {
// UserAgent is the application user agent. // UserAgent is the application user agent.
UserAgent string UserAgent string
// Cancel is closed to indicate that dialing should be cancelled.
Cancel chan struct{}
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
@ -365,7 +367,7 @@ type ConnectOptions struct {
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts) return newHTTP2Client(target, opts)
} }

View File

@ -220,7 +220,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
ct ClientTransport ct ClientTransport
connErr error connErr error
) )
ct, connErr = NewClientTransport(addr, &ConnectOptions{}) ct, connErr = NewClientTransport(addr, ConnectOptions{})
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)
} }