diff --git a/clientconn.go b/clientconn.go index 25c9282e..e7bb9453 100644 --- a/clientconn.go +++ b/clientconn.go @@ -61,6 +61,7 @@ var ( // values passed to Dial. type dialOptions struct { codec Codec + block bool copts transport.ConnectOptions } @@ -74,6 +75,15 @@ func WithCodec(c Codec) DialOption { } } +// WithBlock returns a DialOption which makes caller of Dial blocks until the underlying +// connection is up. Without this, Dial returns immediately and connecting the server +// happens in background. +func WithBlock() DialOption { + return func(o *dialOptions) { + o.block = true + } +} + // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { @@ -112,7 +122,8 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { return nil, ErrUnspecTarget } cc := &ClientConn{ - target: target, + target: target, + shutdownChan: make(chan struct{}), } for _, opt := range opts { opt(&cc.dopts) @@ -126,12 +137,22 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { // Set the default codec. cc.dopts.codec = protoCodec{} } - if err := cc.resetTransport(false); err != nil { - return nil, err + if cc.dopts.block { + if err := cc.resetTransport(false); err != nil { + return nil, err + } + // Start to monitor the error status of transport. + go cc.transportMonitor() + } else { + // Start a goroutine connecting to the server asynchronously. + go func() { + if err := cc.resetTransport(false); err != nil { + grpclog.Printf("Failed to dial %s: %v; please retry.", target, err) + return + } + go cc.transportMonitor() + }() } - cc.shutdownChan = make(chan struct{}) - // Start to monitor the error status of transport. - go cc.transportMonitor() return cc, nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index 73e6cb65..6f573afe 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -205,7 +205,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" func TestDialTimeout(t *testing.T) { - conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond)) + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) if err == nil { conn.Close() } @@ -219,7 +219,7 @@ func TestTLSDialTimeout(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond)) + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) if err == nil { conn.Close() } @@ -238,7 +238,7 @@ func TestReconnectTimeout(t *testing.T) { t.Fatalf("Failed to parse listener address: %v", err) } addr := "localhost:" + port - conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second)) + conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock()) if err != nil { t.Fatalf("Failed to dial to the server %q: %v", addr, err) }