Merge pull request #209 from iamqizhao/master

Make Dial nonblocking by default. Add a DialOption to enable blocking operation if needed.
This commit is contained in:
Qi Zhao
2015-06-04 16:19:26 -07:00
2 changed files with 30 additions and 9 deletions

View File

@ -61,6 +61,7 @@ var (
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
codec Codec codec Codec
block bool
copts transport.ConnectOptions copts transport.ConnectOptions
} }
@ -74,6 +75,15 @@ func WithCodec(c Codec) DialOption {
} }
} }
// WithBlock returns a DialOption which makes caller of Dial blocks until the underlying
// connection is up. Without this, Dial returns immediately and connecting the server
// happens in background.
func WithBlock() DialOption {
return func(o *dialOptions) {
o.block = true
}
}
// WithTransportCredentials returns a DialOption which configures a // WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL). // connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
@ -112,7 +122,8 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return nil, ErrUnspecTarget return nil, ErrUnspecTarget
} }
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
shutdownChan: make(chan struct{}),
} }
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
@ -126,12 +137,22 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// Set the default codec. // Set the default codec.
cc.dopts.codec = protoCodec{} cc.dopts.codec = protoCodec{}
} }
if err := cc.resetTransport(false); err != nil { if cc.dopts.block {
return nil, err if err := cc.resetTransport(false); err != nil {
return nil, err
}
// Start to monitor the error status of transport.
go cc.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := cc.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", target, err)
return
}
go cc.transportMonitor()
}()
} }
cc.shutdownChan = make(chan struct{})
// Start to monitor the error status of transport.
go cc.transportMonitor()
return cc, nil return cc, nil
} }

View File

@ -205,7 +205,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/" const tlsDir = "testdata/"
func TestDialTimeout(t *testing.T) { func TestDialTimeout(t *testing.T) {
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond)) conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock())
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
@ -219,7 +219,7 @@ func TestTLSDialTimeout(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond)) conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock())
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
@ -238,7 +238,7 @@ func TestReconnectTimeout(t *testing.T) {
t.Fatalf("Failed to parse listener address: %v", err) t.Fatalf("Failed to parse listener address: %v", err)
} }
addr := "localhost:" + port addr := "localhost:" + port
conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second)) conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock())
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the server %q: %v", addr, err) t.Fatalf("Failed to dial to the server %q: %v", addr, err)
} }