add Notify API; move the name resolving into Balancer

This commit is contained in:
iamqizhao
2016-05-23 19:25:01 -07:00
parent fda7cb3cdf
commit 5b484e4099
4 changed files with 220 additions and 131 deletions

View File

@ -37,6 +37,8 @@ import (
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
@ -53,6 +55,10 @@ type Address struct {
// Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the service discovery and watch the name resolution
// updates.
Start(target string) error
// Up informs the balancer that gRPC has a connection to the server at
// addr. It returns down which is called once the connection to addr gets
// lost or closed. Once down is called, addr may no longer be returned
@ -64,21 +70,101 @@ type Balancer interface {
// is called once the rpc has completed or failed. put can collect and
// report rpc stats to remote load balancer.
Get(ctx context.Context) (addr Address, put func(), err error)
// Notify returns a channel on which the Balancer notifies gRPC internals of
// the list of Addresses which should be connected. gRPC internals will compare
// each list with the existing connected addresses. If an address the Balancer
// notified is not in the list of the connected addresses, gRPC starts to
// connect to that address. If an address in the connected addresses is not in
// the notification list, the corresponding connection will be shut down
// gracefully. Otherwise, there are no operations. Note that every notification
// must carry the full list of the Addresses which should be connected.
// It is NOT a delta.
Notify() <-chan []Address
// Close shuts down the balancer.
Close() error
}
// RoundRobin returns a Balancer that selects addresses round-robin.
func RoundRobin() Balancer {
return &roundRobin{}
// RoundRobin returns a Balancer that selects addresses round-robin. The
// resolver r is only stored here; the name resolution watch is begun later by
// Start, which does nothing when r is nil.
func RoundRobin(r naming.Resolver) Balancer {
	rr := new(roundRobin)
	rr.r = r
	return rr
}
type roundRobin struct {
mu sync.Mutex
addrs []Address
next int // index of the next address to return for Get()
waitCh chan struct{} // channel to block when there is no address available
done bool // The Balancer is closed.
r naming.Resolver
open []Address // all the known addresses the client can potentially connect
mu sync.Mutex
addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to.
connected []Address // all the connected addresses
next int // index of the next address to return for Get()
waitCh chan struct{} // the channel to block when there is no connected address available
done bool // The Balancer is closed.
}
// watchAddrUpdates blocks on w.Next() for one batch of name resolution
// updates, folds the batch into rr.open, and publishes a snapshot of the
// resulting full address list on rr.addrCh.
//
// It returns the error from w.Next() unchanged, or ErrClientConnClosing when
// rr.done is set. The watch loop in Start stops on any non-nil error.
func (rr *roundRobin) watchAddrUpdates(w naming.Watcher) error {
	updates, err := w.Next()
	if err != nil {
		return err
	}
	// NOTE(review): rr.open is mutated without holding rr.mu. In the visible
	// code it is only used by this function — confirm there is no other
	// accessor.
	for _, update := range updates {
		addr := Address{
			Addr:     update.Addr,
			Metadata: update.Metadata,
		}
		switch update.Op {
		case naming.Add:
			var exist bool
			for _, v := range rr.open {
				if addr == v {
					exist = true
					grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
					break
				}
			}
			if exist {
				continue
			}
			rr.open = append(rr.open, addr)
		case naming.Delete:
			for i, v := range rr.open {
				if v == addr {
					copy(rr.open[i:], rr.open[i+1:])
					rr.open = rr.open[:len(rr.open)-1]
					break
				}
			}
		default:
			grpclog.Println("Unknown update.Op ", update.Op)
		}
	}
	rr.mu.Lock()
	defer rr.mu.Unlock()
	if rr.done {
		return ErrClientConnClosing
	}
	// Hand out a copy so that later edits of rr.open cannot alias the slice
	// the receiver is working with.
	open := make([]Address, len(rr.open))
	copy(open, rr.open)
	// rr.addrCh has capacity 1 (see Start). Discard a list the receiver has not
	// picked up yet: every notification is the full list, so only the newest
	// one matters. After the drain the buffer is empty, so the send below
	// cannot block while rr.mu is held.
	select {
	case <-rr.addrCh:
	default:
	}
	rr.addrCh <- open
	return nil
}

// Start resolves target with the configured resolver and launches a goroutine
// that keeps feeding name resolution updates into the balancer. When no
// resolver was configured it does nothing and returns nil, leaving rr.addrCh
// nil. An error from Resolve is returned to the caller.
func (rr *roundRobin) Start(target string) error {
	if rr.r == nil {
		return nil
	}
	w, err := rr.r.Resolve(target)
	if err != nil {
		return err
	}
	// Capacity 1 lets watchAddrUpdates publish the latest list without waiting
	// for a receiver while it holds rr.mu.
	rr.addrCh = make(chan []Address, 1)
	go func() {
		// NOTE(review): nothing in the visible code closes w, so this goroutine
		// appears to stay blocked in w.Next after the balancer is closed —
		// confirm, and consider retaining w so that Close can release it.
		for {
			if err := rr.watchAddrUpdates(w); err != nil {
				// ErrClientConnClosing is the regular shutdown path; anything
				// else means the watcher broke and updates stop arriving.
				if err != ErrClientConnClosing {
					grpclog.Println("grpc: the naming watcher stops working due to", err)
				}
				return
			}
		}
	}()
	return nil
}
// Up appends addr to the end of rr.connected and sends notification if there
@ -86,13 +172,13 @@ type roundRobin struct {
func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for _, a := range rr.addrs {
for _, a := range rr.connected {
if a == addr {
return nil
}
}
rr.addrs = append(rr.addrs, addr)
if len(rr.addrs) == 1 {
rr.connected = append(rr.connected, addr)
if len(rr.connected) == 1 {
// addr is only one available. Notify the Get() callers who are blocking.
if rr.waitCh != nil {
close(rr.waitCh)
@ -108,10 +194,10 @@ func (rr *roundRobin) Up(addr Address) func(error) {
func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for i, a := range rr.addrs {
for i, a := range rr.connected {
if a == addr {
copy(rr.addrs[i:], rr.addrs[i+1:])
rr.addrs = rr.addrs[:len(rr.addrs)-1]
copy(rr.connected[i:], rr.connected[i+1:])
rr.connected = rr.connected[:len(rr.connected)-1]
return
}
}
@ -126,16 +212,13 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er
err = ErrClientConnClosing
return
}
if rr.next >= len(rr.addrs) {
if rr.next >= len(rr.connected) {
rr.next = 0
}
if len(rr.addrs) > 0 {
addr = rr.addrs[rr.next]
if len(rr.connected) > 0 {
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
put = func() {
rr.put(ctx, addr)
}
return
}
// There is no address available. Wait on rr.waitCh.
@ -158,26 +241,24 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err er
err = ErrClientConnClosing
return
}
if len(rr.addrs) == 0 {
if len(rr.connected) == 0 {
// The newly added addr got removed by Down() again.
rr.mu.Unlock()
continue
}
if rr.next >= len(rr.addrs) {
if rr.next >= len(rr.connected) {
rr.next = 0
}
addr = rr.addrs[rr.next]
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
put = func() {
rr.put(ctx, addr)
}
return
}
}
}
func (rr *roundRobin) put(ctx context.Context, addr Address) {
// Notify exposes the receive side of rr.addrCh, the channel on which address
// lists are published. The result is nil while rr.addrCh has not been created,
// which is the case when Start ran without a resolver.
func (rr *roundRobin) Notify() <-chan []Address {
	var ch <-chan []Address = rr.addrCh
	return ch
}
func (rr *roundRobin) Close() error {
@ -188,5 +269,8 @@ func (rr *roundRobin) Close() error {
close(rr.waitCh)
rr.waitCh = nil
}
if rr.addrCh != nil {
close(rr.addrCh)
}
return nil
}

View File

@ -122,7 +122,7 @@ func TestNameDiscovery(t *testing.T) {
// Start 2 servers on 2 ports.
numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
@ -157,7 +157,7 @@ func TestNameDiscovery(t *testing.T) {
func TestEmptyAddrs(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
@ -167,12 +167,11 @@ func TestEmptyAddrs(t *testing.T) {
}
// Inject name resolution change to remove the server so that there is no address
// available after that.
var updates []*naming.Update
updates = append(updates, &naming.Update{
u := &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
})
r.w.inject(updates)
}
r.w.inject([]*naming.Update{u})
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
@ -189,24 +188,32 @@ func TestRoundRobin(t *testing.T) {
// Start 3 servers on 3 ports.
numServers := 3
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Add servers[1] and servers[2] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
// Add servers[1] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
})
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
})
r.w.inject(updates)
}
r.w.inject([]*naming.Update{u})
req := "port"
var reply string
// Loop until an RPC is completed by servers[2].
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Add servers[2] to the service discovery.
u = &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until servers[2] is up.
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port {
break
@ -216,7 +223,7 @@ func TestRoundRobin(t *testing.T) {
// Check the incoming RPCs served in a round-robin manner.
for i := 0; i < 10; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port {
t.Fatalf("Invoke(_, _, _, _, _) = %v, want %s", err, servers[i%numServers].port)
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
}
}
cc.Close()
@ -227,7 +234,7 @@ func TestRoundRobin(t *testing.T) {
func TestCloseWithPendingRPC(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
@ -275,7 +282,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
func TestGetOnWaitChannel(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}

14
call.go
View File

@ -169,7 +169,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil {
put()
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
@ -181,7 +184,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil {
put()
if put != nil {
put()
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
@ -195,7 +200,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
}
t.CloseStream(stream, nil)
put()
if put != nil {
put()
put = nil
}
return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
}
}

View File

@ -65,12 +65,12 @@ var (
// ErrClientConnTimeout indicates that the connection could not be
// established or re-established within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
// errNetworkIP indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
errConnClosing = errors.New("grpc: the addrConn is closing")
// ErrNetworkIO indicates that the connection is down due to some network I/O error.
ErrNetworkIO = errors.New("grpc: failed with network I/O error")
// ErrConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
ErrConnDrain = errors.New("grpc: the connection is drained")
// ErrConnClosing indicates that the addrConn is closing.
ErrConnClosing = errors.New("grpc: the addrConn is closing")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
)
@ -82,7 +82,6 @@ type dialOptions struct {
cp Compressor
dc Decompressor
bs backoffStrategy
resolver naming.Resolver
balancer Balancer
block bool
insecure bool
@ -115,13 +114,6 @@ func WithDecompressor(dc Decompressor) DialOption {
}
}
// WithNameResolver returns a DialOption which sets a name resolver for service discovery.
func WithNameResolver(r naming.Resolver) DialOption {
return func(o *dialOptions) {
o.resolver = r
}
}
// WithBalancer returns a DialOption which sets a load balancer.
func WithBalancer(b Balancer) DialOption {
return func(o *dialOptions) {
@ -231,34 +223,29 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.balancer = cc.dopts.balancer
if cc.balancer == nil {
cc.balancer = RoundRobin()
cc.balancer = RoundRobin(nil)
}
if cc.dopts.resolver == nil {
addr := Address{
Addr: cc.target,
}
if err := cc.newAddrConn(addr); err != nil {
if err := cc.balancer.Start(target); err != nil {
return nil, err
}
ch := cc.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addr := Address{Addr: target}
if err := cc.newAddrConn(addr, false); err != nil {
return nil, err
}
} else {
w, err := cc.dopts.resolver.Resolve(cc.target)
if err != nil {
return nil, err
addrs, ok := <-ch
if !ok || len(addrs) == 0 {
return nil, fmt.Errorf("grpc: there is no address available to dial")
}
cc.watcher = w
// Get the initial name resolution and dial the first connection.
if err := cc.watchAddrUpdates(); err != nil {
return nil, err
}
// Start a goroutine to watch for the future name resolution changes.
go func() {
for {
if err := cc.watchAddrUpdates(); err != nil {
return
}
for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil {
return nil, err
}
}()
}
go cc.controller()
}
colonPos := strings.LastIndex(target, ":")
@ -314,50 +301,48 @@ type ClientConn struct {
conns map[Address]*addrConn
}
func (cc *ClientConn) watchAddrUpdates() error {
updates, err := cc.watcher.Next()
if err != nil {
return err
}
for _, update := range updates {
switch update.Op {
case naming.Add:
cc.mu.RLock()
addr := Address{
Addr: update.Addr,
Metadata: update.Metadata,
func (cc *ClientConn) controller() {
for {
addrs, ok := <-cc.balancer.Notify()
if !ok {
// cc has been closed.
return
}
var (
add []Address // Addresses need to setup connections.
del []*addrConn // Connections need to tear down.
)
cc.mu.Lock()
for _, a := range addrs {
if _, ok := cc.conns[a]; !ok {
add = append(add, a)
}
if _, ok := cc.conns[addr]; ok {
cc.mu.RUnlock()
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
continue
}
for k, c := range cc.conns {
var keep bool
for _, a := range addrs {
if k == a {
keep = true
break
}
}
cc.mu.RUnlock()
if err := cc.newAddrConn(addr); err != nil {
return err
if !keep {
del = append(del, c)
}
case naming.Delete:
cc.mu.RLock()
addr := Address{
Addr: update.Addr,
Metadata: update.Metadata,
}
cc.mu.Unlock()
for _, a := range addrs {
if err := cc.newAddrConn(a, true); err != nil {
}
ac, ok := cc.conns[addr]
if !ok {
cc.mu.RUnlock()
grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
continue
}
cc.mu.RUnlock()
ac.tearDown(errConnDrain)
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
for _, c := range del {
c.tearDown(ErrConnDrain)
}
}
return nil
}
func (cc *ClientConn) newAddrConn(addr Address) error {
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac := &addrConn{
cc: cc,
addr: addr,
@ -394,7 +379,8 @@ func (cc *ClientConn) newAddrConn(addr Address) error {
ac.cc.mu.Unlock()
ac.stateCV = sync.NewCond(&ac.mu)
if ac.dopts.block {
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
ac.tearDown(err)
return err
@ -428,12 +414,16 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
put()
if put != nil {
put()
}
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
}
t, err := ac.wait(ctx)
if err != nil {
put()
if put != nil {
put()
}
return nil, nil, err
}
return t, put, nil
@ -538,10 +528,10 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return errConnClosing
return ErrConnClosing
}
if ac.down != nil {
ac.down(errNetworkIO)
ac.down(ErrNetworkIO)
ac.down = nil
}
ac.state = Connecting
@ -579,7 +569,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return errConnClosing
return ErrConnClosing
}
ac.errorf("transient failure: %v", err)
ac.state = TransientFailure
@ -616,7 +606,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
newTransport.Close()
return errConnClosing
return ErrConnClosing
}
ac.state = Ready
ac.stateCV.Broadcast()
@ -671,7 +661,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
switch {
case ac.state == Shutdown:
ac.mu.Unlock()
return nil, errConnClosing
return nil, ErrConnClosing
case ac.state == Ready:
ct := ac.transport
ac.mu.Unlock()
@ -725,7 +715,7 @@ func (ac *addrConn) tearDown(err error) {
ac.ready = nil
}
if ac.transport != nil {
if err == errConnDrain {
if err == ErrConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close()