Merge pull request #209 from iamqizhao/master

Make Dial nonblocking by default. Add a DialOption to enable blocking operation if needed.
This commit is contained in:
Qi Zhao
2015-06-04 16:19:26 -07:00
2 changed files with 30 additions and 9 deletions

View File

@ -61,6 +61,7 @@ var (
// values passed to Dial.
type dialOptions struct {
codec Codec
block bool
copts transport.ConnectOptions
}
@ -74,6 +75,15 @@ func WithCodec(c Codec) DialOption {
}
}
// WithBlock returns a DialOption which makes caller of Dial blocks until the underlying
// connection is up. Without this, Dial returns immediately and connecting the server
// happens in background.
func WithBlock() DialOption {
return func(o *dialOptions) {
o.block = true
}
}
// WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
@ -112,7 +122,8 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return nil, ErrUnspecTarget
}
cc := &ClientConn{
target: target,
target: target,
shutdownChan: make(chan struct{}),
}
for _, opt := range opts {
opt(&cc.dopts)
@ -126,12 +137,22 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// Set the default codec.
cc.dopts.codec = protoCodec{}
}
if err := cc.resetTransport(false); err != nil {
return nil, err
if cc.dopts.block {
if err := cc.resetTransport(false); err != nil {
return nil, err
}
// Start to monitor the error status of transport.
go cc.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := cc.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", target, err)
return
}
go cc.transportMonitor()
}()
}
cc.shutdownChan = make(chan struct{})
// Start to monitor the error status of transport.
go cc.transportMonitor()
return cc, nil
}

View File

@ -205,7 +205,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/"
func TestDialTimeout(t *testing.T) {
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond))
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock())
if err == nil {
conn.Close()
}
@ -219,7 +219,7 @@ func TestTLSDialTimeout(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond))
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock())
if err == nil {
conn.Close()
}
@ -238,7 +238,7 @@ func TestReconnectTimeout(t *testing.T) {
t.Fatalf("Failed to parse listener address: %v", err)
}
addr := "localhost:" + port
conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second))
conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock())
if err != nil {
t.Fatalf("Failed to dial to the server %q: %v", addr, err)
}