diff --git a/balancer.go b/balancer.go index 419e2146..17767768 100644 --- a/balancer.go +++ b/balancer.go @@ -206,6 +206,11 @@ func (rr *roundRobin) watchAddrUpdates() error { } func (rr *roundRobin) Start(target string) error { + rr.mu.Lock() + defer rr.mu.Unlock() + if rr.done { + return ErrClientConnClosing + } if rr.r == nil { // If there is no name resolver installed, it is not needed to // do name resolution. In this case, target is added into rr.addrs diff --git a/clientconn.go b/clientconn.go index 6428f8a2..bf561ac1 100644 --- a/clientconn.go +++ b/clientconn.go @@ -271,30 +271,30 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc.dopts.bs = DefaultBackoffConfig } - var ( - ok bool - addrs []Address - ) - if cc.dopts.balancer == nil { - // Connect to target directly if balancer is nil. - addrs = append(addrs, Address{Addr: target}) - } else { - if err := cc.dopts.balancer.Start(target); err != nil { - return nil, err - } - ch := cc.dopts.balancer.Notify() - if ch == nil { - // There is no name resolver installed. - addrs = append(addrs, Address{Addr: target}) - } else { - addrs, ok = <-ch - if !ok || len(addrs) == 0 { - return nil, errNoAddr - } - } - } + var ok bool waitC := make(chan error, 1) go func() { + var addrs []Address + if cc.dopts.balancer == nil { + // Connect to target directly if balancer is nil. + addrs = append(addrs, Address{Addr: target}) + } else { + if err := cc.dopts.balancer.Start(target); err != nil { + waitC <- err + return + } + ch := cc.dopts.balancer.Notify() + if ch == nil { + // There is no name resolver installed. + addrs = append(addrs, Address{Addr: target}) + } else { + addrs, ok = <-ch + if !ok || len(addrs) == 0 { + waitC <- errNoAddr + return + } + } + } for _, a := range addrs { if err := cc.resetAddrConn(a, false, nil); err != nil { waitC <- err diff --git a/clientconn_test.go b/clientconn_test.go index 3d635c73..ec40c405 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -75,7 +75,7 @@ func TestTLSServerNameOverwrite(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond)) + conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds)) if err != nil { t.Fatalf("Dial(_, _) = _, %v, want _, ", err) } @@ -93,6 +93,43 @@ func TestDialContextCancel(t *testing.T) { } } +// blockingBalancer mimics the behavior of balancers whose initialization takes a long time. +// In this test, reading from blockingBalancer.Notify() blocks forever. +type blockingBalancer struct { + ch chan []Address +} + +func newBlockingBalancer() Balancer { + return &blockingBalancer{ch: make(chan []Address)} +} +func (b *blockingBalancer) Start(target string) error { + return nil +} +func (b *blockingBalancer) Up(addr Address) func(error) { + return nil +} +func (b *blockingBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { + return Address{}, nil, nil +} +func (b *blockingBalancer) Notify() <-chan []Address { + return b.ch +} +func (b *blockingBalancer) Close() error { + close(b.ch) + return nil +} + +func TestDialWithBlockingBalancer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + dialDone := make(chan struct{}) + go func() { + DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure(), WithBalancer(newBlockingBalancer())) + close(dialDone) + }() + cancel() + <-dialDone +} + func TestCredentialsMisuse(t *testing.T) { tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil {