diff --git a/balancer.go b/balancer.go index 1a9c7501..279474b9 100644 --- a/balancer.go +++ b/balancer.go @@ -37,6 +37,8 @@ import ( "sync" "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/naming" "google.golang.org/grpc/transport" ) @@ -53,6 +55,10 @@ type Address struct { // Balancer chooses network addresses for RPCs. // This is the EXPERIMENTAL API and may be changed or extended in the future. type Balancer interface { + // Start does the initialization work to bootstrap a Balancer. For example, + // this function may start the service discovery and watch the name resolution + // updates. + Start(target string) error // Up informs the balancer that gRPC has a connection to the server at // addr. It returns down which is called once the connection to addr gets // lost or closed. Once down is called, addr may no longer be returned @@ -64,21 +70,101 @@ type Balancer interface { // is called once the rpc has completed or failed. put can collect and // report rpc stats to remote load balancer. Get(ctx context.Context) (addr Address, put func(), err error) + // Notify gRPC internals the list of Address which should be connected. gRPC + // internals will compare it with the exisiting connected addresses. If the + // address Balancer notified is not in the list of the connected addresses, + // gRPC starts to connect the address. If an address in the connected + // addresses is not in the notification list, the corresponding connect will be + // shutdown gracefully. Otherwise, there are no operations. Note that this + // function must return the full list of the Addrresses which should be connected. + // It is NOT delta. + Notify() <-chan []Address // Close shuts down the balancer. Close() error } -// RoundRobin returns a Balancer that selects addresses round-robin. -func RoundRobin() Balancer { - return &roundRobin{} +// RoundRobin returns a Balancer that selects addresses round-robin. It starts to watch +// the name resolution updates. +func RoundRobin(r naming.Resolver) Balancer { + return &roundRobin{r: r} } type roundRobin struct { - mu sync.Mutex - addrs []Address - next int // index of the next address to return for Get() - waitCh chan struct{} // channel to block when there is no address available - done bool // The Balancer is closed. + r naming.Resolver + open []Address // all the known addresses the client can potentially connect + mu sync.Mutex + addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to. + connected []Address // all the connected addresses + next int // index of the next address to return for Get() + waitCh chan struct{} // the channel to block when there is no connected address available + done bool // The Balancer is closed. +} + +func (rr *roundRobin) watchAddrUpdates(w naming.Watcher) error { + updates, err := w.Next() + if err != nil { + return err + } + for _, update := range updates { + addr := Address{ + Addr: update.Addr, + Metadata: update.Metadata, + } + switch update.Op { + case naming.Add: + var exisit bool + for _, v := range rr.open { + if addr == v { + exisit = true + grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) + break + } + } + if exisit { + continue + } + rr.open = append(rr.open, addr) + case naming.Delete: + for i, v := range rr.open { + if v == addr { + copy(rr.open[i:], rr.open[i+1:]) + rr.open = rr.open[:len(rr.open)-1] + break + } + } + default: + grpclog.Println("Unknown update.Op ", update.Op) + } + } + // Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified. + open := make([]Address, len(rr.open)) + rr.mu.Lock() + defer rr.mu.Unlock() + copy(open, rr.open) + if rr.done { + return ErrClientConnClosing + } + rr.addrCh <- open + return nil +} + +func (rr *roundRobin) Start(target string) error { + if rr.r == nil { + return nil + } + w, err := rr.r.Resolve(target) + if err != nil { + return err + } + rr.addrCh = make(chan []Address) + go func() { + for { + if err := rr.watchAddrUpdates(w); err != nil { + return + } + } + }() + return nil } // Up appends addr to the end of rr.addrs and sends notification if there @@ -86,13 +172,13 @@ type roundRobin struct { func (rr *roundRobin) Up(addr Address) func(error) { rr.mu.Lock() defer rr.mu.Unlock() - for _, a := range rr.addrs { + for _, a := range rr.connected { if a == addr { return nil } } - rr.addrs = append(rr.addrs, addr) - if len(rr.addrs) == 1 { + rr.connected = append(rr.connected, addr) + if len(rr.connected) == 1 { // addr is only one available. Notify the Get() callers who are blocking. if rr.waitCh != nil { close(rr.waitCh) @@ -108,10 +194,10 @@ func (rr *roundRobin) Up(addr Address) func(error) { func (rr *roundRobin) down(addr Address, err error) { rr.mu.Lock() defer rr.mu.Unlock() - for i, a := range rr.addrs { + for i, a := range rr.connected { if a == addr { - copy(rr.addrs[i:], rr.addrs[i+1:]) - rr.addrs = rr.addrs[:len(rr.addrs)-1] + copy(rr.connected[i:], rr.connected[i+1:]) + rr.connected = rr.connected[:len(rr.connected)-1] return } } @@ -126,16 +212,13 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er err = ErrClientConnClosing return } - if rr.next >= len(rr.addrs) { + if rr.next >= len(rr.connected) { rr.next = 0 } - if len(rr.addrs) > 0 { - addr = rr.addrs[rr.next] + if len(rr.connected) > 0 { + addr = rr.connected[rr.next] rr.next++ rr.mu.Unlock() - put = func() { - rr.put(ctx, addr) - } return } // There is no address available. Wait on rr.waitCh. @@ -158,26 +241,24 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er err = ErrClientConnClosing return } - if len(rr.addrs) == 0 { + if len(rr.connected) == 0 { // The newly added addr got removed by Down() again. rr.mu.Unlock() continue } - if rr.next >= len(rr.addrs) { + if rr.next >= len(rr.connected) { rr.next = 0 } - addr = rr.addrs[rr.next] + addr = rr.connected[rr.next] rr.next++ rr.mu.Unlock() - put = func() { - rr.put(ctx, addr) - } return } } } -func (rr *roundRobin) put(ctx context.Context, addr Address) { +func (rr *roundRobin) Notify() <-chan []Address { + return rr.addrCh } func (rr *roundRobin) Close() error { @@ -188,5 +269,8 @@ func (rr *roundRobin) Close() error { close(rr.waitCh) rr.waitCh = nil } + if rr.addrCh != nil { + close(rr.addrCh) + } return nil } diff --git a/balancer_test.go b/balancer_test.go index d0379497..976432f4 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -122,7 +122,7 @@ func TestNameDiscovery(t *testing.T) { // Start 2 servers on 2 ports. numServers := 2 servers, r := startServers(t, numServers, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -157,7 +157,7 @@ func TestNameDiscovery(t *testing.T) { func TestEmptyAddrs(t *testing.T) { servers, r := startServers(t, 1, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -167,12 +167,11 @@ func TestEmptyAddrs(t *testing.T) { } // Inject name resolution change to remove the server so that there is no address // available after that. - var updates []*naming.Update - updates = append(updates, &naming.Update{ + u := &naming.Update{ Op: naming.Delete, Addr: "127.0.0.1:" + servers[0].port, - }) - r.w.inject(updates) + } + r.w.inject([]*naming.Update{u}) // Loop until the above updates apply. for { time.Sleep(10 * time.Millisecond) @@ -189,24 +188,32 @@ func TestRoundRobin(t *testing.T) { // Start 3 servers on 3 ports. numServers := 3 servers, r := startServers(t, numServers, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } - // Add servers[1] and servers[2] to the service discovery. - var updates []*naming.Update - updates = append(updates, &naming.Update{ + // Add servers[1] to the service discovery. + u := &naming.Update{ Op: naming.Add, Addr: "127.0.0.1:" + servers[1].port, - }) - updates = append(updates, &naming.Update{ - Op: naming.Add, - Addr: "127.0.0.1:" + servers[2].port, - }) - r.w.inject(updates) + } + r.w.inject([]*naming.Update{u}) req := "port" var reply string - // Loop until an RPC is completed by servers[2]. + // Loop until servers[1] is up + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + // Add server2[2] to the service discovery. + u = &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[2].port, + } + r.w.inject([]*naming.Update{u}) + // Loop until both servers[2] are up. for { if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { break @@ -216,7 +223,7 @@ func TestRoundRobin(t *testing.T) { // Check the incoming RPCs served in a round-robin manner. for i := 0; i < 10; i++ { if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port { - t.Fatalf("Invoke(_, _, _, _, _) = %v, want %s", err, servers[i%numServers].port) + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port) } } cc.Close() @@ -227,7 +234,7 @@ func TestRoundRobin(t *testing.T) { func TestCloseWithPendingRPC(t *testing.T) { servers, r := startServers(t, 1, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -275,7 +282,7 @@ func TestCloseWithPendingRPC(t *testing.T) { func TestGetOnWaitChannel(t *testing.T) { servers, r := startServers(t, 1, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } diff --git a/call.go b/call.go index 5fe5af26..98b8e2b1 100644 --- a/call.go +++ b/call.go @@ -169,7 +169,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) if err != nil { - put() + if put != nil { + put() + put = nil + } if _, ok := err.(transport.ConnectionError); ok { if c.failFast { return toRPCErr(err) @@ -181,7 +184,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli // Receive the response err = recvResponse(cc.dopts, t, &c, stream, reply) if err != nil { - put() + if put != nil { + put() + } if _, ok := err.(transport.ConnectionError); ok { if c.failFast { return toRPCErr(err) @@ -195,7 +200,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) - put() + if put != nil { + put() + put = nil + } return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) } } diff --git a/clientconn.go b/clientconn.go index 81c2c731..141cbbf0 100644 --- a/clientconn.go +++ b/clientconn.go @@ -65,12 +65,12 @@ var ( // ErrClientConnTimeout indicates that the connection could not be // established or re-established within the specified timeout. ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") - - // errNetworkIP indicates that the connection is down due to some network I/O error. - errNetworkIO = errors.New("grpc: failed with network I/O error") - // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. - errConnDrain = errors.New("grpc: the connection is drained") - errConnClosing = errors.New("grpc: the addrConn is closing") + // ErrNetworkIP indicates that the connection is down due to some network I/O error. + ErrNetworkIO = errors.New("grpc: failed with network I/O error") + // ErrConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. + ErrConnDrain = errors.New("grpc: the connection is drained") + // ErrConnClosing + ErrConnClosing = errors.New("grpc: the addrConn is closing") // minimum time to give a connection to complete minConnectTimeout = 20 * time.Second ) @@ -82,7 +82,6 @@ type dialOptions struct { cp Compressor dc Decompressor bs backoffStrategy - resolver naming.Resolver balancer Balancer block bool insecure bool @@ -115,13 +114,6 @@ func WithDecompressor(dc Decompressor) DialOption { } } -// WithNameResolver returns a DialOption which sets a name resolver for service discovery. -func WithNameResolver(r naming.Resolver) DialOption { - return func(o *dialOptions) { - o.resolver = r - } -} - // WithBalancer returns a DialOption which sets a load balancer. func WithBalancer(b Balancer) DialOption { return func(o *dialOptions) { @@ -231,34 +223,29 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { cc.balancer = cc.dopts.balancer if cc.balancer == nil { - cc.balancer = RoundRobin() + cc.balancer = RoundRobin(nil) } - - if cc.dopts.resolver == nil { - addr := Address{ - Addr: cc.target, - } - if err := cc.newAddrConn(addr); err != nil { + if err := cc.balancer.Start(target); err != nil { + return nil, err + } + ch := cc.balancer.Notify() + if ch == nil { + // There is no name resolver installed. + addr := Address{Addr: target} + if err := cc.newAddrConn(addr, false); err != nil { return nil, err } } else { - w, err := cc.dopts.resolver.Resolve(cc.target) - if err != nil { - return nil, err + addrs, ok := <-ch + if !ok || len(addrs) == 0 { + return nil, fmt.Errorf("grpc: there is no address available to dial") } - cc.watcher = w - // Get the initial name resolution and dial the first connection. - if err := cc.watchAddrUpdates(); err != nil { - return nil, err - } - // Start a goroutine to watch for the future name resolution changes. - go func() { - for { - if err := cc.watchAddrUpdates(); err != nil { - return - } + for _, a := range addrs { + if err := cc.newAddrConn(a, false); err != nil { + return nil, err } - }() + } + go cc.controller() } colonPos := strings.LastIndex(target, ":") @@ -314,50 +301,48 @@ type ClientConn struct { conns map[Address]*addrConn } -func (cc *ClientConn) watchAddrUpdates() error { - updates, err := cc.watcher.Next() - if err != nil { - return err - } - for _, update := range updates { - switch update.Op { - case naming.Add: - cc.mu.RLock() - addr := Address{ - Addr: update.Addr, - Metadata: update.Metadata, +func (cc *ClientConn) controller() { + for { + addrs, ok := <-cc.balancer.Notify() + if !ok { + // cc has been closed. + return + } + var ( + add []Address // Addresses need to setup connections. + del []*addrConn // Connections need to tear down. + ) + cc.mu.Lock() + for _, a := range addrs { + if _, ok := cc.conns[a]; !ok { + add = append(add, a) } - if _, ok := cc.conns[addr]; ok { - cc.mu.RUnlock() - grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) - continue + } + for k, c := range cc.conns { + var keep bool + for _, a := range addrs { + if k == a { + keep = true + break + } } - cc.mu.RUnlock() - if err := cc.newAddrConn(addr); err != nil { - return err + if !keep { + del = append(del, c) } - case naming.Delete: - cc.mu.RLock() - addr := Address{ - Addr: update.Addr, - Metadata: update.Metadata, + } + cc.mu.Unlock() + for _, a := range addrs { + if err := cc.newAddrConn(a, true); err != nil { + } - ac, ok := cc.conns[addr] - if !ok { - cc.mu.RUnlock() - grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr) - continue - } - cc.mu.RUnlock() - ac.tearDown(errConnDrain) - default: - grpclog.Println("Unknown update.Op ", update.Op) + } + for _, c := range del { + c.tearDown(ErrConnDrain) } } - return nil } -func (cc *ClientConn) newAddrConn(addr Address) error { +func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac := &addrConn{ cc: cc, addr: addr, @@ -394,7 +379,8 @@ func (cc *ClientConn) newAddrConn(addr Address) error { ac.cc.mu.Unlock() ac.stateCV = sync.NewCond(&ac.mu) - if ac.dopts.block { + // skipWait may overwrite the decision in ac.dopts.block. + if ac.dopts.block && !skipWait { if err := ac.resetTransport(false); err != nil { ac.tearDown(err) return err @@ -428,12 +414,16 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo ac, ok := cc.conns[addr] cc.mu.RUnlock() if !ok { - put() + if put != nil { + put() + } return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } t, err := ac.wait(ctx) if err != nil { - put() + if put != nil { + put() + } return nil, nil, err } return t, put, nil @@ -538,10 +528,10 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if ac.state == Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() - return errConnClosing + return ErrConnClosing } if ac.down != nil { - ac.down(errNetworkIO) + ac.down(ErrNetworkIO) ac.down = nil } ac.state = Connecting @@ -579,7 +569,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if ac.state == Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() - return errConnClosing + return ErrConnClosing } ac.errorf("transient failure: %v", err) ac.state = TransientFailure @@ -616,7 +606,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { // ac.tearDown(...) has been invoked. ac.mu.Unlock() newTransport.Close() - return errConnClosing + return ErrConnClosing } ac.state = Ready ac.stateCV.Broadcast() @@ -671,7 +661,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) switch { case ac.state == Shutdown: ac.mu.Unlock() - return nil, errConnClosing + return nil, ErrConnClosing case ac.state == Ready: ct := ac.transport ac.mu.Unlock() @@ -725,7 +715,7 @@ func (ac *addrConn) tearDown(err error) { ac.ready = nil } if ac.transport != nil { - if err == errConnDrain { + if err == ErrConnDrain { ac.transport.GracefulClose() } else { ac.transport.Close()