Merge pull request #93 from iamqizhao/master

Support (re)connect time-out
This commit is contained in:
Qi Zhao
2015-03-05 15:53:32 -08:00
7 changed files with 169 additions and 44 deletions

View File

@ -137,15 +137,15 @@ func Invoke(ctx context.Context, method string, args, reply proto.Message, cc *C
) )
// TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs. // TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs.
if lastErr != nil && c.failFast { if lastErr != nil && c.failFast {
return lastErr return toRPCErr(lastErr)
} }
t, ts, err = cc.wait(ctx, ts) t, ts, err = cc.wait(ctx, ts)
if err != nil { if err != nil {
if lastErr != nil { if lastErr != nil {
// This was a retry; return the error from the last attempt. // This was a retry; return the error from the last attempt.
return lastErr return toRPCErr(lastErr)
} }
return err return Errorf(codes.Internal, "%v", err)
} }
stream, err = sendRPC(ctx, callHdr, t, args, topts) stream, err = sendRPC(ctx, callHdr, t, args, topts)
if err != nil { if err != nil {

View File

@ -50,30 +50,34 @@ var (
// ErrClientConnClosing indicates that the operation is illegal because // ErrClientConnClosing indicates that the operation is illegal because
// the session is closing. // the session is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing") ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the connection could not be
// established or re-established within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
) )
type dialOptions struct { // DialOption configures how we set up the connection.
protocol string type DialOption func(*transport.DialOptions)
authOptions []credentials.Credentials
}
// DialOption configures how we set up the connection including auth
// credentials.
type DialOption func(*dialOptions)
// WithTransportCredentials returns a DialOption which configures a // WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL). // connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
return func(o *dialOptions) { return func(o *transport.DialOptions) {
o.authOptions = append(o.authOptions, creds) o.AuthOptions = append(o.AuthOptions, creds)
} }
} }
// WithPerRPCCredentials returns a DialOption which sets // WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC. // credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption { func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
return func(o *dialOptions) { return func(o *transport.DialOptions) {
o.authOptions = append(o.authOptions, creds) o.AuthOptions = append(o.AuthOptions, creds)
}
}
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
func WithTimeout(d time.Duration) DialOption {
return func(o *transport.DialOptions) {
o.Timeout = d
} }
} }
@ -102,11 +106,12 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// ClientConn represents a client connection to an RPC service. // ClientConn represents a client connection to an RPC service.
type ClientConn struct { type ClientConn struct {
target string target string
dopts dialOptions dopts transport.DialOptions
shutdownChan chan struct{} shutdownChan chan struct{}
mu sync.Mutex mu sync.Mutex
// Is closed and becomes nil when a new transport is up. // ready is closed and becomes nil when a new transport is up or failed
// due to timeout.
ready chan struct{} ready chan struct{}
// Indicates the ClientConn is under destruction. // Indicates the ClientConn is under destruction.
closing bool closing bool
@ -119,6 +124,7 @@ type ClientConn struct {
func (cc *ClientConn) resetTransport(closeTransport bool) error { func (cc *ClientConn) resetTransport(closeTransport bool) error {
var retries int var retries int
start := time.Now()
for { for {
cc.mu.Lock() cc.mu.Lock()
t := cc.transport t := cc.transport
@ -133,16 +139,41 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
if closeTransport { if closeTransport {
t.Close() t.Close()
} }
newTransport, err := transport.NewClientTransport(cc.dopts.protocol, cc.target, cc.dopts.authOptions) // Adjust timeout for the current try.
dopts := cc.dopts
if dopts.Timeout < 0 {
cc.Close()
return ErrClientConnTimeout
}
if dopts.Timeout > 0 {
dopts.Timeout -= time.Since(start)
if dopts.Timeout <= 0 {
cc.Close()
return ErrClientConnTimeout
}
}
newTransport, err := transport.NewClientTransport(cc.target, &dopts)
if err != nil { if err != nil {
// TODO(zhaoq): Record the error with glog.V. sleepTime := backoff(retries)
// Fail early before falling into sleep.
if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime + time.Since(start) {
cc.Close()
return ErrClientConnTimeout
}
closeTransport = false closeTransport = false
time.Sleep(backoff(retries)) time.Sleep(sleepTime)
retries++ retries++
// TODO(zhaoq): Record the error with glog.V.
log.Printf("grpc: ClientConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target) log.Printf("grpc: ClientConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
continue continue
} }
cc.mu.Lock() cc.mu.Lock()
if cc.closing {
// cc.Close() has been invoked.
cc.mu.Unlock()
newTransport.Close()
return ErrClientConnClosing
}
cc.transport = newTransport cc.transport = newTransport
cc.transportSeq = ts + 1 cc.transportSeq = ts + 1
if cc.ready != nil { if cc.ready != nil {
@ -166,6 +197,8 @@ func (cc *ClientConn) transportMonitor() {
case <-cc.transport.Error(): case <-cc.transport.Error():
if err := cc.resetTransport(true); err != nil { if err := cc.resetTransport(true); err != nil {
// The channel is closing. // The channel is closing.
// TODO(zhaoq): Record the error with glog.V.
log.Printf("grpc: ClientConn.transportMonitor exits due to: %v", err)
return return
} }
continue continue
@ -197,24 +230,34 @@ func (cc *ClientConn) wait(ctx context.Context, ts int) (transport.ClientTranspo
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, 0, transport.ContextErr(ctx.Err()) return nil, 0, transport.ContextErr(ctx.Err())
// Wait until the new transport is ready. // Wait until the new transport is ready or failed.
case <-ready: case <-ready:
} }
} }
} }
} }
// Close starts to tear down the ClientConn. // Close starts to tear down the ClientConn. Returns ErrClientConnClosing if
// it has been closed (mostly due to dial time-out).
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many ClientConn's in a // some edge cases (e.g., the caller opens and closes many ClientConn's in a
// tight loop. // tight loop.
func (cc *ClientConn) Close() { func (cc *ClientConn) Close() error {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if cc.closing { if cc.closing {
return return ErrClientConnClosing
} }
cc.closing = true cc.closing = true
if cc.ready != nil {
close(cc.ready)
cc.ready = nil
}
if cc.transport != nil {
cc.transport.Close() cc.transport.Close()
}
if cc.shutdownChan != nil {
close(cc.shutdownChan) close(cc.shutdownChan)
} }
return nil
}

View File

@ -71,9 +71,15 @@ type Credentials interface {
// TransportAuthenticator defines the common interface all supported transport // TransportAuthenticator defines the common interface all supported transport
// authentication protocols (e.g., TLS, SSL) must implement. // authentication protocols (e.g., TLS, SSL) must implement.
type TransportAuthenticator interface { type TransportAuthenticator interface {
// Dial connects to the given network address and does the authentication // Dial connects to the given network address using net.Dial and then
// handshake specified by the corresponding authentication protocol. // does the authentication handshake specified by the corresponding
Dial(addr string) (net.Conn, error) // authentication protocol.
Dial(network, addr string) (net.Conn, error)
// DialWithDialer connects to the given network address using
// dialer.Dial does the authentication handshake specified by the
// corresponding authentication protocol. Any timeout or deadline
// given in the dialer apply to connection and handshake as a whole.
DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error)
// NewListener creates a listener which accepts connections with requested // NewListener creates a listener which accepts connections with requested
// authentication handshake. // authentication handshake.
NewListener(lis net.Listener) net.Listener NewListener(lis net.Listener) net.Listener
@ -103,8 +109,7 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
return nil, nil return nil, nil
} }
// Dial connects to addr and performs TLS handshake. func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) {
name := c.serverName name := c.serverName
if name == "" { if name == "" {
name, _, err = net.SplitHostPort(addr) name, _, err = net.SplitHostPort(addr)
@ -112,13 +117,18 @@ func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) {
return nil, fmt.Errorf("credentials: failed to parse server address %v", err) return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
} }
} }
return tls.Dial("tcp", addr, &tls.Config{ return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
RootCAs: c.rootCAs, RootCAs: c.rootCAs,
NextProtos: alpnProtoStr, NextProtos: alpnProtoStr,
ServerName: name, ServerName: name,
}) })
} }
// Dial connects to addr and performs TLS handshake.
func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
return c.DialWithDialer(new(net.Dialer), network, addr)
}
// NewListener creates a net.Listener with a TLS configuration constructed // NewListener creates a net.Listener with a TLS configuration constructed
// from the information in tlsCreds. // from the information in tlsCreds.
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {

View File

@ -193,6 +193,68 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/" const tlsDir = "testdata/"
func TestDialTimeout(t *testing.T) {
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond))
if err == nil {
conn.Close()
}
if err != grpc.ErrClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, grpc.ErrClientConnTimeout)
}
}
func TestTLSDialTimeout(t *testing.T) {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond))
if err == nil {
conn.Close()
}
if err != grpc.ErrClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, grpc.ErrClientConnTimeout)
}
}
func TestReconnectTimeout(t *testing.T) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
_, port, err := net.SplitHostPort(lis.Addr().String())
if err != nil {
t.Fatalf("Failed to parse listener address: %v", err)
}
addr := "localhost:" + port
conn, err := grpc.Dial(addr, grpc.WithTimeout(time.Second))
if err != nil {
t.Fatalf("Failed to dial to the server %q: %v", addr, err)
}
lis.Close()
tc := testpb.NewTestServiceClient(conn)
waitC := make(chan struct{})
go func() {
defer close(waitC)
argSize := 271828
respSize := 314159
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(int32(respSize)),
Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)),
}
_, err := tc.UnaryCall(context.Background(), req)
if err != grpc.Errorf(codes.Internal, "%v", grpc.ErrClientConnClosing) {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.Internal, "%v", grpc.ErrClientConnClosing))
}
}()
// Block untill reconnect times out.
<-waitC
if err := conn.Close(); err != grpc.ErrClientConnClosing {
t.Fatalf("%v.Close() = %v, want %v", conn, err, grpc.ErrClientConnClosing)
}
}
func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, tc testpb.TestServiceClient) { func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, tc testpb.TestServiceClient) {
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
@ -331,7 +393,7 @@ func TestRetry(t *testing.T) {
} }
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. // TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func TestTimeout(t *testing.T) { func TestRPCTimeout(t *testing.T) {
s, tc := setUp(true, math.MaxUint32) s, tc := setUp(true, math.MaxUint32)
defer s.Stop() defer s.Stop()
argSize := 2718 argSize := 2718

View File

@ -96,26 +96,25 @@ type http2Client struct {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, authOpts []credentials.Credentials) (_ ClientTransport, err error) { func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) {
var ( var (
connErr error connErr error
conn net.Conn conn net.Conn
) )
scheme := "http" scheme := "http"
// TODO(zhaoq): Use DialTimeout instead. for _, c := range opts.AuthOptions {
for _, c := range authOpts {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok { if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https" scheme = "https"
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are // TODO(zhaoq): Now the first TransportAuthenticator is used if there are
// multiple ones provided. Revisit this if it is not appropriate. Probably // multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make // place the ClientTransport construction into a separate function to make
// things clear. // things clear.
conn, connErr = ccreds.Dial(addr) conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, "tcp", addr)
break break
} }
} }
if scheme == "http" { if scheme == "http" {
conn, connErr = net.Dial("tcp", addr) conn, connErr = net.DialTimeout("tcp", addr, opts.Timeout)
} }
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
@ -155,7 +154,7 @@ func newHTTP2Client(addr string, authOpts []credentials.Credentials) (_ ClientTr
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
maxStreams: math.MaxUint32, maxStreams: math.MaxUint32,
authCreds: authOpts, authCreds: opts.AuthOptions,
} }
go t.controller() go t.controller()
t.writableChan <- 0 t.writableChan <- 0

View File

@ -44,6 +44,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -310,10 +311,17 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
return newHTTP2Server(conn, maxStreams) return newHTTP2Server(conn, maxStreams)
} }
// NewClientTransport establishes the transport with the required protocol // DialOptions covers all relevant options for dialing a server.
type DialOptions struct {
Protocol string
AuthOptions []credentials.Credentials
Timeout time.Duration
}
// NewClientTransport establishes the transport with the required DialOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(protocol, target string, authOpts []credentials.Credentials) (ClientTransport, error) { func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) {
return newHTTP2Client(target, authOpts) return newHTTP2Client(target, opts)
} }
// Options provides additional hints and information for message // Options provides additional hints and information for message

View File

@ -181,9 +181,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
ct, connErr = NewClientTransport("http2", addr, []credentials.Credentials{creds}) dopts := DialOptions{
AuthOptions: []credentials.Credentials{creds},
}
ct, connErr = NewClientTransport(addr, &dopts)
} else { } else {
ct, connErr = NewClientTransport("http2", addr, nil) ct, connErr = NewClientTransport(addr, &DialOptions{})
} }
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)