add Notify API; move the name resolving into Balancer

This commit is contained in:
iamqizhao
2016-05-23 19:25:01 -07:00
parent fda7cb3cdf
commit 5b484e4099
4 changed files with 220 additions and 131 deletions

View File

@ -37,6 +37,8 @@ import (
"sync" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
@ -53,6 +55,10 @@ type Address struct {
// Balancer chooses network addresses for RPCs. // Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future. // This is the EXPERIMENTAL API and may be changed or extended in the future.
type Balancer interface { type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the service discovery and watch the name resolution
// updates.
Start(target string) error
// Up informs the balancer that gRPC has a connection to the server at // Up informs the balancer that gRPC has a connection to the server at
// addr. It returns down which is called once the connection to addr gets // addr. It returns down which is called once the connection to addr gets
// lost or closed. Once down is called, addr may no longer be returned // lost or closed. Once down is called, addr may no longer be returned
@ -64,21 +70,101 @@ type Balancer interface {
// is called once the rpc has completed or failed. put can collect and // is called once the rpc has completed or failed. put can collect and
// report rpc stats to remote load balancer. // report rpc stats to remote load balancer.
Get(ctx context.Context) (addr Address, put func(), err error) Get(ctx context.Context) (addr Address, put func(), err error)
// Notify gRPC internals the list of Address which should be connected. gRPC
// internals will compare it with the exisiting connected addresses. If the
// address Balancer notified is not in the list of the connected addresses,
// gRPC starts to connect the address. If an address in the connected
// addresses is not in the notification list, the corresponding connect will be
// shutdown gracefully. Otherwise, there are no operations. Note that this
// function must return the full list of the Addrresses which should be connected.
// It is NOT delta.
Notify() <-chan []Address
// Close shuts down the balancer. // Close shuts down the balancer.
Close() error Close() error
} }
// RoundRobin returns a Balancer that selects addresses round-robin. // RoundRobin returns a Balancer that selects addresses round-robin. It starts to watch
func RoundRobin() Balancer { // the name resolution updates.
return &roundRobin{} func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r}
} }
type roundRobin struct { type roundRobin struct {
mu sync.Mutex r naming.Resolver
addrs []Address open []Address // all the known addresses the client can potentially connect
next int // index of the next address to return for Get() mu sync.Mutex
waitCh chan struct{} // channel to block when there is no address available addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to.
done bool // The Balancer is closed. connected []Address // all the connected addresses
next int // index of the next address to return for Get()
waitCh chan struct{} // the channel to block when there is no connected address available
done bool // The Balancer is closed.
}
func (rr *roundRobin) watchAddrUpdates(w naming.Watcher) error {
updates, err := w.Next()
if err != nil {
return err
}
for _, update := range updates {
addr := Address{
Addr: update.Addr,
Metadata: update.Metadata,
}
switch update.Op {
case naming.Add:
var exisit bool
for _, v := range rr.open {
if addr == v {
exisit = true
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
break
}
}
if exisit {
continue
}
rr.open = append(rr.open, addr)
case naming.Delete:
for i, v := range rr.open {
if v == addr {
copy(rr.open[i:], rr.open[i+1:])
rr.open = rr.open[:len(rr.open)-1]
break
}
}
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified.
open := make([]Address, len(rr.open))
rr.mu.Lock()
defer rr.mu.Unlock()
copy(open, rr.open)
if rr.done {
return ErrClientConnClosing
}
rr.addrCh <- open
return nil
}
func (rr *roundRobin) Start(target string) error {
if rr.r == nil {
return nil
}
w, err := rr.r.Resolve(target)
if err != nil {
return err
}
rr.addrCh = make(chan []Address)
go func() {
for {
if err := rr.watchAddrUpdates(w); err != nil {
return
}
}
}()
return nil
} }
// Up appends addr to the end of rr.addrs and sends notification if there // Up appends addr to the end of rr.addrs and sends notification if there
@ -86,13 +172,13 @@ type roundRobin struct {
func (rr *roundRobin) Up(addr Address) func(error) { func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock() rr.mu.Lock()
defer rr.mu.Unlock() defer rr.mu.Unlock()
for _, a := range rr.addrs { for _, a := range rr.connected {
if a == addr { if a == addr {
return nil return nil
} }
} }
rr.addrs = append(rr.addrs, addr) rr.connected = append(rr.connected, addr)
if len(rr.addrs) == 1 { if len(rr.connected) == 1 {
// addr is only one available. Notify the Get() callers who are blocking. // addr is only one available. Notify the Get() callers who are blocking.
if rr.waitCh != nil { if rr.waitCh != nil {
close(rr.waitCh) close(rr.waitCh)
@ -108,10 +194,10 @@ func (rr *roundRobin) Up(addr Address) func(error) {
func (rr *roundRobin) down(addr Address, err error) { func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock() rr.mu.Lock()
defer rr.mu.Unlock() defer rr.mu.Unlock()
for i, a := range rr.addrs { for i, a := range rr.connected {
if a == addr { if a == addr {
copy(rr.addrs[i:], rr.addrs[i+1:]) copy(rr.connected[i:], rr.connected[i+1:])
rr.addrs = rr.addrs[:len(rr.addrs)-1] rr.connected = rr.connected[:len(rr.connected)-1]
return return
} }
} }
@ -126,16 +212,13 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er
err = ErrClientConnClosing err = ErrClientConnClosing
return return
} }
if rr.next >= len(rr.addrs) { if rr.next >= len(rr.connected) {
rr.next = 0 rr.next = 0
} }
if len(rr.addrs) > 0 { if len(rr.connected) > 0 {
addr = rr.addrs[rr.next] addr = rr.connected[rr.next]
rr.next++ rr.next++
rr.mu.Unlock() rr.mu.Unlock()
put = func() {
rr.put(ctx, addr)
}
return return
} }
// There is no address available. Wait on rr.waitCh. // There is no address available. Wait on rr.waitCh.
@ -158,26 +241,24 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er
err = ErrClientConnClosing err = ErrClientConnClosing
return return
} }
if len(rr.addrs) == 0 { if len(rr.connected) == 0 {
// The newly added addr got removed by Down() again. // The newly added addr got removed by Down() again.
rr.mu.Unlock() rr.mu.Unlock()
continue continue
} }
if rr.next >= len(rr.addrs) { if rr.next >= len(rr.connected) {
rr.next = 0 rr.next = 0
} }
addr = rr.addrs[rr.next] addr = rr.connected[rr.next]
rr.next++ rr.next++
rr.mu.Unlock() rr.mu.Unlock()
put = func() {
rr.put(ctx, addr)
}
return return
} }
} }
} }
func (rr *roundRobin) put(ctx context.Context, addr Address) { func (rr *roundRobin) Notify() <-chan []Address {
return rr.addrCh
} }
func (rr *roundRobin) Close() error { func (rr *roundRobin) Close() error {
@ -188,5 +269,8 @@ func (rr *roundRobin) Close() error {
close(rr.waitCh) close(rr.waitCh)
rr.waitCh = nil rr.waitCh = nil
} }
if rr.addrCh != nil {
close(rr.addrCh)
}
return nil return nil
} }

View File

@ -122,7 +122,7 @@ func TestNameDiscovery(t *testing.T) {
// Start 2 servers on 2 ports. // Start 2 servers on 2 ports.
numServers := 2 numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32) servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil { if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
@ -157,7 +157,7 @@ func TestNameDiscovery(t *testing.T) {
func TestEmptyAddrs(t *testing.T) { func TestEmptyAddrs(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32) servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil { if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
@ -167,12 +167,11 @@ func TestEmptyAddrs(t *testing.T) {
} }
// Inject name resolution change to remove the server so that there is no address // Inject name resolution change to remove the server so that there is no address
// available after that. // available after that.
var updates []*naming.Update u := &naming.Update{
updates = append(updates, &naming.Update{
Op: naming.Delete, Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}) }
r.w.inject(updates) r.w.inject([]*naming.Update{u})
// Loop until the above updates apply. // Loop until the above updates apply.
for { for {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -189,24 +188,32 @@ func TestRoundRobin(t *testing.T) {
// Start 3 servers on 3 ports. // Start 3 servers on 3 ports.
numServers := 3 numServers := 3
servers, r := startServers(t, numServers, math.MaxUint32) servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil { if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
// Add servers[1] and servers[2] to the service discovery. // Add servers[1] to the service discovery.
var updates []*naming.Update u := &naming.Update{
updates = append(updates, &naming.Update{
Op: naming.Add, Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port, Addr: "127.0.0.1:" + servers[1].port,
}) }
updates = append(updates, &naming.Update{ r.w.inject([]*naming.Update{u})
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
})
r.w.inject(updates)
req := "port" req := "port"
var reply string var reply string
// Loop until an RPC is completed by servers[2]. // Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Add server2[2] to the service discovery.
u = &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until both servers[2] are up.
for { for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port {
break break
@ -216,7 +223,7 @@ func TestRoundRobin(t *testing.T) {
// Check the incoming RPCs served in a round-robin manner. // Check the incoming RPCs served in a round-robin manner.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port { if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port {
t.Fatalf("Invoke(_, _, _, _, _) = %v, want %s", err, servers[i%numServers].port) t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
} }
} }
cc.Close() cc.Close()
@ -227,7 +234,7 @@ func TestRoundRobin(t *testing.T) {
func TestCloseWithPendingRPC(t *testing.T) { func TestCloseWithPendingRPC(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32) servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil { if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
@ -275,7 +282,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
func TestGetOnWaitChannel(t *testing.T) { func TestGetOnWaitChannel(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32) servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil { if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }

14
call.go
View File

@ -169,7 +169,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil { if err != nil {
put() if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
@ -181,7 +184,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
// Receive the response // Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
put() if put != nil {
put()
}
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
@ -195,7 +200,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
} }
t.CloseStream(stream, nil) t.CloseStream(stream, nil)
put() if put != nil {
put()
put = nil
}
return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
} }
} }

View File

@ -65,12 +65,12 @@ var (
// ErrClientConnTimeout indicates that the connection could not be // ErrClientConnTimeout indicates that the connection could not be
// established or re-established within the specified timeout. // established or re-established within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
// ErrNetworkIP indicates that the connection is down due to some network I/O error.
// errNetworkIP indicates that the connection is down due to some network I/O error. ErrNetworkIO = errors.New("grpc: failed with network I/O error")
errNetworkIO = errors.New("grpc: failed with network I/O error") // ErrConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. ErrConnDrain = errors.New("grpc: the connection is drained")
errConnDrain = errors.New("grpc: the connection is drained") // ErrConnClosing
errConnClosing = errors.New("grpc: the addrConn is closing") ErrConnClosing = errors.New("grpc: the addrConn is closing")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
@ -82,7 +82,6 @@ type dialOptions struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoffStrategy bs backoffStrategy
resolver naming.Resolver
balancer Balancer balancer Balancer
block bool block bool
insecure bool insecure bool
@ -115,13 +114,6 @@ func WithDecompressor(dc Decompressor) DialOption {
} }
} }
// WithNameResolver returns a DialOption which sets a name resolver for service discovery.
func WithNameResolver(r naming.Resolver) DialOption {
return func(o *dialOptions) {
o.resolver = r
}
}
// WithBalancer returns a DialOption which sets a load balancer. // WithBalancer returns a DialOption which sets a load balancer.
func WithBalancer(b Balancer) DialOption { func WithBalancer(b Balancer) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
@ -231,34 +223,29 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.balancer = cc.dopts.balancer cc.balancer = cc.dopts.balancer
if cc.balancer == nil { if cc.balancer == nil {
cc.balancer = RoundRobin() cc.balancer = RoundRobin(nil)
} }
if err := cc.balancer.Start(target); err != nil {
if cc.dopts.resolver == nil { return nil, err
addr := Address{ }
Addr: cc.target, ch := cc.balancer.Notify()
} if ch == nil {
if err := cc.newAddrConn(addr); err != nil { // There is no name resolver installed.
addr := Address{Addr: target}
if err := cc.newAddrConn(addr, false); err != nil {
return nil, err return nil, err
} }
} else { } else {
w, err := cc.dopts.resolver.Resolve(cc.target) addrs, ok := <-ch
if err != nil { if !ok || len(addrs) == 0 {
return nil, err return nil, fmt.Errorf("grpc: there is no address available to dial")
} }
cc.watcher = w for _, a := range addrs {
// Get the initial name resolution and dial the first connection. if err := cc.newAddrConn(a, false); err != nil {
if err := cc.watchAddrUpdates(); err != nil { return nil, err
return nil, err
}
// Start a goroutine to watch for the future name resolution changes.
go func() {
for {
if err := cc.watchAddrUpdates(); err != nil {
return
}
} }
}() }
go cc.controller()
} }
colonPos := strings.LastIndex(target, ":") colonPos := strings.LastIndex(target, ":")
@ -314,50 +301,48 @@ type ClientConn struct {
conns map[Address]*addrConn conns map[Address]*addrConn
} }
func (cc *ClientConn) watchAddrUpdates() error { func (cc *ClientConn) controller() {
updates, err := cc.watcher.Next() for {
if err != nil { addrs, ok := <-cc.balancer.Notify()
return err if !ok {
} // cc has been closed.
for _, update := range updates { return
switch update.Op { }
case naming.Add: var (
cc.mu.RLock() add []Address // Addresses need to setup connections.
addr := Address{ del []*addrConn // Connections need to tear down.
Addr: update.Addr, )
Metadata: update.Metadata, cc.mu.Lock()
for _, a := range addrs {
if _, ok := cc.conns[a]; !ok {
add = append(add, a)
} }
if _, ok := cc.conns[addr]; ok { }
cc.mu.RUnlock() for k, c := range cc.conns {
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) var keep bool
continue for _, a := range addrs {
if k == a {
keep = true
break
}
} }
cc.mu.RUnlock() if !keep {
if err := cc.newAddrConn(addr); err != nil { del = append(del, c)
return err
} }
case naming.Delete: }
cc.mu.RLock() cc.mu.Unlock()
addr := Address{ for _, a := range addrs {
Addr: update.Addr, if err := cc.newAddrConn(a, true); err != nil {
Metadata: update.Metadata,
} }
ac, ok := cc.conns[addr] }
if !ok { for _, c := range del {
cc.mu.RUnlock() c.tearDown(ErrConnDrain)
grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
continue
}
cc.mu.RUnlock()
ac.tearDown(errConnDrain)
default:
grpclog.Println("Unknown update.Op ", update.Op)
} }
} }
return nil
} }
func (cc *ClientConn) newAddrConn(addr Address) error { func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac := &addrConn{ ac := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
@ -394,7 +379,8 @@ func (cc *ClientConn) newAddrConn(addr Address) error {
ac.cc.mu.Unlock() ac.cc.mu.Unlock()
ac.stateCV = sync.NewCond(&ac.mu) ac.stateCV = sync.NewCond(&ac.mu)
if ac.dopts.block { // skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
ac.tearDown(err) ac.tearDown(err)
return err return err
@ -428,12 +414,16 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
ac, ok := cc.conns[addr] ac, ok := cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
if !ok { if !ok {
put() if put != nil {
put()
}
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
} }
t, err := ac.wait(ctx) t, err := ac.wait(ctx)
if err != nil { if err != nil {
put() if put != nil {
put()
}
return nil, nil, err return nil, nil, err
} }
return t, put, nil return t, put, nil
@ -538,10 +528,10 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
return errConnClosing return ErrConnClosing
} }
if ac.down != nil { if ac.down != nil {
ac.down(errNetworkIO) ac.down(ErrNetworkIO)
ac.down = nil ac.down = nil
} }
ac.state = Connecting ac.state = Connecting
@ -579,7 +569,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
return errConnClosing return ErrConnClosing
} }
ac.errorf("transient failure: %v", err) ac.errorf("transient failure: %v", err)
ac.state = TransientFailure ac.state = TransientFailure
@ -616,7 +606,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
newTransport.Close() newTransport.Close()
return errConnClosing return ErrConnClosing
} }
ac.state = Ready ac.state = Ready
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
@ -671,7 +661,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
switch { switch {
case ac.state == Shutdown: case ac.state == Shutdown:
ac.mu.Unlock() ac.mu.Unlock()
return nil, errConnClosing return nil, ErrConnClosing
case ac.state == Ready: case ac.state == Ready:
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
@ -725,7 +715,7 @@ func (ac *addrConn) tearDown(err error) {
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil {
if err == errConnDrain { if err == ErrConnDrain {
ac.transport.GracefulClose() ac.transport.GracefulClose()
} else { } else {
ac.transport.Close() ac.transport.Close()