New grpclb implementation (#1558)

The new grpclb supports fallback to backends if the remote balancer is unavailable.
Menghan Li
2017-11-27 11:16:26 -08:00
committed by GitHub
parent 10873b30bf
commit 2ef021f78d
13 changed files with 1014 additions and 881 deletions
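For orientation, the new grpclb is selected through a balancer.Builder instead of the old Balancer interface. A minimal sketch of dialing with it, patterned on the tests in this commit; the target and insecure credentials are placeholders, and a resolver that returns addresses of type resolver.GRPCLB is assumed:

package main

import (
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
)

func dial(target string) (*grpc.ClientConn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// NewLBBuilderWithFallbackTimeout is added by this commit: if no server list
	// arrives from the remote balancer within 1s, the balancer falls back to the
	// backend addresses reported by the resolver.
	return grpc.DialContext(ctx, target,
		grpc.WithBalancerBuilder(grpc.NewLBBuilderWithFallbackTimeout(time.Second)),
		grpc.WithInsecure())
}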


@@ -128,6 +128,10 @@ type PickOptions struct{}
 type DoneInfo struct {
 	// Err is the rpc error the RPC finished with. It could be nil.
 	Err error
+	// BytesSent indicates if any bytes have been sent to the server.
+	BytesSent bool
+	// BytesReceived indicates if any byte has been received from the server.
+	BytesReceived bool
 }

 var (
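The two new DoneInfo fields let a balancer tell apart RPCs that never reached the wire from RPCs the server definitely received. A hypothetical helper (not part of this commit) illustrating that split:

package main

import "google.golang.org/grpc/balancer"

// classify buckets a finished RPC the way the gRPC LB protocol's client stats
// distinguish them, using only the fields added to DoneInfo in this commit.
func classify(info balancer.DoneInfo) string {
	switch {
	case !info.BytesSent:
		return "client failed to send" // nothing was written to the wire
	case info.BytesReceived:
		return "known received" // the server definitely saw the request
	default:
		return "sent, reply unknown"
	}
}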

call.go (27 changed lines)

@@ -277,11 +277,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts)
 		if err != nil {
 			if done != nil {
-				updateRPCInfoInContext(ctx, rpcInfo{
-					bytesSent:     true,
-					bytesReceived: stream.BytesReceived(),
+				done(balancer.DoneInfo{
+					Err:           err,
+					BytesSent:     true,
+					BytesReceived: stream.BytesReceived(),
 				})
-				done(balancer.DoneInfo{Err: err})
 			}
 			// Retry a non-failfast RPC when
 			// i) the server started to drain before this RPC was initiated.
@@ -301,11 +301,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		err = recvResponse(ctx, cc.dopts, t, c, stream, reply)
 		if err != nil {
 			if done != nil {
-				updateRPCInfoInContext(ctx, rpcInfo{
-					bytesSent:     true,
-					bytesReceived: stream.BytesReceived(),
+				done(balancer.DoneInfo{
+					Err:           err,
+					BytesSent:     true,
+					BytesReceived: stream.BytesReceived(),
 				})
-				done(balancer.DoneInfo{Err: err})
 			}
 			if !c.failFast && stream.Unprocessed() {
 				// In these cases, the server did not receive the data, but we still
@@ -323,12 +323,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 			c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
 		}
 		t.CloseStream(stream, nil)
+		err = stream.Status().Err()
 		if done != nil {
-			updateRPCInfoInContext(ctx, rpcInfo{
-				bytesSent:     true,
-				bytesReceived: stream.BytesReceived(),
+			done(balancer.DoneInfo{
+				Err:           err,
+				BytesSent:     true,
+				BytesReceived: stream.BytesReceived(),
 			})
-			done(balancer.DoneInfo{Err: err})
 		}
 		if !c.failFast && stream.Unprocessed() {
 			// In these cases, the server did not receive the data, but we still
@@ -339,6 +340,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 				continue
 			}
 		}
-		return stream.Status().Err()
+		return err
 	}
 }


@@ -97,6 +97,8 @@ type dialOptions struct {
 	callOptions []CallOption
 	// This is to support v1 balancer.
 	balancerBuilder balancer.Builder
+	// This is to support grpclb.
+	resolverBuilder resolver.Builder
 }

 const (
@@ -204,6 +206,13 @@ func WithBalancerBuilder(b balancer.Builder) DialOption {
 	}
 }

+// withResolverBuilder is only for grpclb.
+func withResolverBuilder(b resolver.Builder) DialOption {
+	return func(o *dialOptions) {
+		o.resolverBuilder = b
+	}
+}
+
 // WithServiceConfig returns a DialOption which has a channel to read the service configuration.
 // DEPRECATED: service config should be received through name resolver, as specified here.
 // https://github.com/grpc/grpc/blob/master/doc/service_config.md
@@ -283,18 +292,23 @@ func WithTimeout(d time.Duration) DialOption {
 	}
 }

+func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
+	return func(o *dialOptions) {
+		o.copts.Dialer = f
+	}
+}
+
 // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
 // If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's
 // Temporary() method to decide if it should try to reconnect to the network address.
 func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
-	return func(o *dialOptions) {
-		o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
+	return withContextDialer(
+		func(ctx context.Context, addr string) (net.Conn, error) {
 			if deadline, ok := ctx.Deadline(); ok {
 				return f(addr, deadline.Sub(time.Now()))
 			}
 			return f(addr, 0)
-		}
-	}
+		})
 }

 // WithStatsHandler returns a DialOption that specifies the stats handler
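WithDialer keeps its old (addr, timeout) signature; it is now just an adapter over the unexported, context-based withContextDialer, with the deadline-to-timeout conversion shown above. A hedged usage sketch (the address is a placeholder):

package main

import (
	"net"
	"time"

	"google.golang.org/grpc"
)

func dialWithCustomDialer() (*grpc.ClientConn, error) {
	return grpc.Dial("localhost:50051",
		grpc.WithInsecure(),
		// The custom dialer still sees a plain timeout, even though gRPC now
		// tracks the deadline through a context internally.
		grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
			return net.DialTimeout("tcp", addr, timeout)
		}))
}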

grpclb.go (857 changed lines)

@@ -19,21 +19,32 @@
 package grpc

 import (
-	"errors"
-	"fmt"
-	"math/rand"
-	"net"
+	"strconv"
+	"strings"
 	"sync"
 	"time"

 	"golang.org/x/net/context"
-	"google.golang.org/grpc/codes"
-	lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
 	"google.golang.org/grpc/grpclog"
-	"google.golang.org/grpc/metadata"
-	"google.golang.org/grpc/naming"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
 )

+const (
+	lbTokeyKey             = "lb-token"
+	defaultFallbackTimeout = 10 * time.Second
+)
+
+func convertDuration(d *lbpb.Duration) time.Duration {
+	if d == nil {
+		return 0
+	}
+	return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
+}
+
 // Client API for LoadBalancer service.
 // Mostly copied from generated pb.go file.
 // To avoid circular dependency.
@@ -59,646 +70,270 @@ type balanceLoadClientStream struct {
 	ClientStream
 }

-func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error {
+func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
 	return x.ClientStream.SendMsg(m)
 }

-func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) {
-	m := new(lbmpb.LoadBalanceResponse)
+func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) {
+	m := new(lbpb.LoadBalanceResponse)
 	if err := x.ClientStream.RecvMsg(m); err != nil {
 		return nil, err
 	}
 	return m, nil
 }

-// NewGRPCLBBalancer creates a grpclb load balancer.
-func NewGRPCLBBalancer(r naming.Resolver) Balancer {
-	return &grpclbBalancer{
-		r: r,
-	}
-}
+// NewLBBuilder creates a builder for grpclb. For testing only.
+func NewLBBuilder() balancer.Builder {
+	// TODO(bar grpclb) this function is exported for testing only, remove it when resolver supports selecting grpclb.
+	return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
+}
+
+// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given
+// fallbackTimeout. If no response is received from the remote balancer within
+// fallbackTimeout, the backend addresses from the resolved address list will be
+// used.
+//
+// Only call this function when a non-default fallback timeout is needed.
+func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
+	return &lbBuilder{
+		fallbackTimeout: fallbackTimeout,
+	}
+}

-type remoteBalancerInfo struct {
-	addr string
-	// the server name used for authentication with the remote LB server.
-	name string
-}
+type lbBuilder struct {
+	fallbackTimeout time.Duration
+}

-// grpclbAddrInfo consists of the information of a backend server.
-type grpclbAddrInfo struct {
-	addr      Address
-	connected bool
-	// dropForRateLimiting indicates whether this particular request should be
-	// dropped by the client for rate limiting.
-	dropForRateLimiting bool
-	// dropForLoadBalancing indicates whether this particular request should be
-	// dropped by the client for load balancing.
-	dropForLoadBalancing bool
-}
+func (b *lbBuilder) Name() string {
+	return "grpclb"
+}

-type grpclbBalancer struct {
-	r      naming.Resolver
-	target string
-	mu     sync.Mutex
-	seq    int // a sequence number to make sure addrCh does not get stale addresses.
-	w      naming.Watcher
-	addrCh chan []Address
-	rbs    []remoteBalancerInfo
-	addrs  []*grpclbAddrInfo
-	next   int
-	waitCh chan struct{}
-	done   bool
-	rand   *rand.Rand
-
-	clientStats lbmpb.ClientStats
-}
+func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
+	// This generates a manual resolver builder with a random scheme. This
+	// scheme will be used to dial to remote LB, so we can send filtered address
+	// updates to remote LB ClientConn using this manual resolver.
+	scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36)
+	r := manual.NewBuilderWithScheme(scheme)
+
+	var target string
+	targetSplitted := strings.Split(cc.Target(), ":///")
+	if len(targetSplitted) < 2 {
+		target = cc.Target()
+	} else {
+		target = targetSplitted[1]
+	}
+
+	lb := &lbBalancer{
+		cc:              cc,
+		target:          target,
+		opt:             opt,
+		fallbackTimeout: b.fallbackTimeout,
+		doneCh:          make(chan struct{}),
+
+		manualResolver: r,
+		csEvltr:        &connectivityStateEvaluator{},
+		subConns:       make(map[resolver.Address]balancer.SubConn),
+		scStates:       make(map[balancer.SubConn]connectivity.State),
+		picker:         &errPicker{err: balancer.ErrNoSubConnAvailable},
+		clientStats:    &rpcStats{},
+	}
+
+	return lb
+}
-func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
-	updates, err := w.Next()
-	if err != nil {
-		grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err)
-		return err
-	}
-	b.mu.Lock()
-	defer b.mu.Unlock()
-	if b.done {
-		return ErrClientConnClosing
-	}
-	for _, update := range updates {
-		switch update.Op {
-		case naming.Add:
-			var exist bool
-			for _, v := range b.rbs {
-				// TODO: Is the same addr with different server name a different balancer?
-				if update.Addr == v.addr {
-					exist = true
-					break
-				}
-			}
-			if exist {
-				continue
-			}
-			md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB)
-			if !ok {
-				// TODO: Revisit the handling here and may introduce some fallback mechanism.
-				grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata)
-				continue
-			}
-			switch md.AddrType {
-			case naming.Backend:
-				// TODO: Revisit the handling here and may introduce some fallback mechanism.
-				grpclog.Errorf("The name resolution does not give grpclb addresses")
-				continue
-			case naming.GRPCLB:
-				b.rbs = append(b.rbs, remoteBalancerInfo{
-					addr: update.Addr,
-					name: md.ServerName,
-				})
-			default:
-				grpclog.Errorf("Received unknow address type %d", md.AddrType)
-				continue
-			}
-		case naming.Delete:
-			for i, v := range b.rbs {
-				if update.Addr == v.addr {
-					copy(b.rbs[i:], b.rbs[i+1:])
-					b.rbs = b.rbs[:len(b.rbs)-1]
-					break
-				}
-			}
-		default:
-			grpclog.Errorf("Unknown update.Op %v", update.Op)
-		}
-	}
-	// TODO: Fall back to the basic round-robin load balancing if the resulting address is
-	// not a load balancer.
-	select {
-	case <-ch:
-	default:
-	}
-	ch <- b.rbs
-	return nil
-}
+type lbBalancer struct {
+	cc              balancer.ClientConn
+	target          string
+	opt             balancer.BuildOptions
+	fallbackTimeout time.Duration
+	doneCh          chan struct{}
+
+	// manualResolver is used in the remote LB ClientConn inside grpclb. When
+	// resolved address updates are received by grpclb, filtered updates will be
+	// send to remote LB ClientConn through this resolver.
+	manualResolver *manual.Resolver
+	// The ClientConn to talk to the remote balancer.
+	ccRemoteLB *ClientConn
+
+	// Support client side load reporting. Each picker gets a reference to this,
+	// and will update its content.
+	clientStats *rpcStats
+
+	mu sync.Mutex // guards everything following.
+	// The full server list including drops, used to check if the newly received
+	// serverList contains anything new. Each generate picker will also have
+	// reference to this list to do the first layer pick.
+	fullServerList []*lbpb.Server
+	// All backends addresses, with metadata set to nil. This list contains all
+	// backend addresses in the same order and with the same duplicates as in
+	// serverlist. When generating picker, a SubConn slice with the same order
+	// but with only READY SCs will be gerenated.
+	backendAddrs []resolver.Address
+	// Roundrobin functionalities.
+	csEvltr  *connectivityStateEvaluator
+	state    connectivity.State
+	subConns map[resolver.Address]balancer.SubConn   // Used to new/remove SubConn.
+	scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
+	picker   balancer.Picker
+	// Support fallback to resolved backend addresses if there's no response
+	// from remote balancer within fallbackTimeout.
+	fallbackTimerExpired bool
+	serverListReceived   bool
+	// resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set
+	// when resolved address updates are received, and read in the goroutine
+	// handling fallback.
+	resolvedBackendAddrs []resolver.Address
+}
-func convertDuration(d *lbmpb.Duration) time.Duration {
-	if d == nil {
-		return 0
-	}
-	return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
-}
-
-func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) {
-	if l == nil {
-		return
-	}
-	servers := l.GetServers()
-	var (
-		sl    []*grpclbAddrInfo
-		addrs []Address
-	)
-	for _, s := range servers {
-		md := metadata.Pairs("lb-token", s.LoadBalanceToken)
-		ip := net.IP(s.IpAddress)
-		ipStr := ip.String()
-		if ip.To4() == nil {
-			// Add square brackets to ipv6 addresses, otherwise net.Dial() and
-			// net.SplitHostPort() will return too many colons error.
-			ipStr = fmt.Sprintf("[%s]", ipStr)
-		}
-		addr := Address{
-			Addr:     fmt.Sprintf("%s:%d", ipStr, s.Port),
-			Metadata: &md,
-		}
-		sl = append(sl, &grpclbAddrInfo{
-			addr:                 addr,
-			dropForRateLimiting:  s.DropForRateLimiting,
-			dropForLoadBalancing: s.DropForLoadBalancing,
-		})
-		addrs = append(addrs, addr)
-	}
-	b.mu.Lock()
-	defer b.mu.Unlock()
-	if b.done || seq < b.seq {
-		return
-	}
-	if len(sl) > 0 {
-		// reset b.next to 0 when replacing the server list.
-		b.next = 0
-		b.addrs = sl
-		b.addrCh <- addrs
-	}
-	return
-}
+// regeneratePicker takes a snapshot of the balancer, and generates a picker from
+// it. The picker
+//  - always returns ErrTransientFailure if the balancer is in TransientFailure,
+//  - does two layer roundrobin pick otherwise.
+// Caller must hold lb.mu.
+func (lb *lbBalancer) regeneratePicker() {
+	if lb.state == connectivity.TransientFailure {
+		lb.picker = &errPicker{err: balancer.ErrTransientFailure}
+		return
+	}
+	var readySCs []balancer.SubConn
+	for _, a := range lb.backendAddrs {
+		if sc, ok := lb.subConns[a]; ok {
+			if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready {
+				readySCs = append(readySCs, sc)
+			}
+		}
+	}
+
+	if len(lb.fullServerList) <= 0 {
+		if len(readySCs) <= 0 {
+			lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable}
+			return
+		}
+		lb.picker = &rrPicker{subConns: readySCs}
+		return
+	}
+	lb.picker = &lbPicker{
+		serverList: lb.fullServerList,
+		subConns:   readySCs,
+		stats:      lb.clientStats,
+	}
+	return
+}
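The picker generated here does the "two layer" pick named in the comment: layer one walks the full server list from the remote balancer, so drop entries take effect with the intended frequency; layer two round-robins over the READY SubConns for non-drop entries. A simplified, hypothetical sketch of that idea (the real lbPicker lives in grpclb_picker.go and differs in detail, e.g. it keeps its own indexes and locking):

func pickSketch(serverList []*lbpb.Server, readySCs []balancer.SubConn, next *int) (sc balancer.SubConn, drop bool) {
	if len(serverList) == 0 {
		return nil, false
	}
	// Layer 1: round-robin over the full server list, drops included.
	s := serverList[*next%len(serverList)]
	*next++
	if s.DropForLoadBalancing || s.DropForRateLimiting {
		return nil, true // the remote balancer asked for this RPC to be dropped
	}
	// Layer 2: round-robin over the READY SubConns only.
	if len(readySCs) == 0 {
		return nil, false
	}
	return readySCs[*next%len(readySCs)], false
}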
-func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
-	ticker := time.NewTicker(interval)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-ticker.C:
-		case <-done:
-			return
-		}
-		b.mu.Lock()
-		stats := b.clientStats
-		b.clientStats = lbmpb.ClientStats{} // Clear the stats.
-		b.mu.Unlock()
-		t := time.Now()
-		stats.Timestamp = &lbmpb.Timestamp{
-			Seconds: t.Unix(),
-			Nanos:   int32(t.Nanosecond()),
-		}
-		if err := s.Send(&lbmpb.LoadBalanceRequest{
-			LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{
-				ClientStats: &stats,
-			},
-		}); err != nil {
-			grpclog.Errorf("grpclb: failed to send load report: %v", err)
-			return
-		}
-	}
-}
+func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
+	grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s)
+	lb.mu.Lock()
+	defer lb.mu.Unlock()
+
+	oldS, ok := lb.scStates[sc]
+	if !ok {
+		grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
+		return
+	}
+	lb.scStates[sc] = s
+	switch s {
+	case connectivity.Idle:
+		sc.Connect()
+	case connectivity.Shutdown:
+		// When an address was removed by resolver, b called RemoveSubConn but
+		// kept the sc's state in scStates. Remove state for this sc here.
+		delete(lb.scStates, sc)
+	}
+
+	oldAggrState := lb.state
+	lb.state = lb.csEvltr.recordTransition(oldS, s)
+
+	// Regenerate picker when one of the following happens:
+	//  - this sc became ready from not-ready
+	//  - this sc became not-ready from ready
+	//  - the aggregated state of balancer became TransientFailure from non-TransientFailure
+	//  - the aggregated state of balancer became non-TransientFailure from TransientFailure
+	if (oldS == connectivity.Ready) != (s == connectivity.Ready) ||
+		(lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) {
+		lb.regeneratePicker()
+	}
+
+	lb.cc.UpdateBalancerState(lb.state, lb.picker)
+	return
+}
+
+// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use
+// resolved backends (backends received from resolver, not from remote balancer)
+// if no connection to remote balancers was successful.
+func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) {
+	timer := time.NewTimer(fallbackTimeout)
+	defer timer.Stop()
+	select {
+	case <-timer.C:
+	case <-lb.doneCh:
+		return
+	}
+	lb.mu.Lock()
+	if lb.serverListReceived {
+		lb.mu.Unlock()
+		return
+	}
+	lb.fallbackTimerExpired = true
+	lb.refreshSubConns(lb.resolvedBackendAddrs)
+	lb.mu.Unlock()
+}
+
+// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB
+// clientConn. The remoteLB clientConn will handle creating/removing remoteLB
+// connections.
+func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
+	grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs)
+	if len(addrs) <= 0 {
+		return
+	}
+
+	var remoteBalancerAddrs, backendAddrs []resolver.Address
+	for _, a := range addrs {
+		if a.Type == resolver.GRPCLB {
+			remoteBalancerAddrs = append(remoteBalancerAddrs, a)
+		} else {
+			backendAddrs = append(backendAddrs, a)
+		}
+	}
+
+	if lb.ccRemoteLB == nil {
+		if len(remoteBalancerAddrs) <= 0 {
+			grpclog.Errorf("grpclb: no remote balancer address is available, should never happen")
+			return
+		}
+		// First time receiving resolved addresses, create a cc to remote
+		// balancers.
+		lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName)
+		// Start the fallback goroutine.
+		go lb.fallbackToBackendsAfter(lb.fallbackTimeout)
+	}
+
+	// cc to remote balancers uses lb.manualResolver. Send the updated remote
+	// balancer addresses to it through manualResolver.
+	lb.manualResolver.NewAddress(remoteBalancerAddrs)
+
+	lb.mu.Lock()
+	lb.resolvedBackendAddrs = backendAddrs
+	// If serverListReceived is true, connection to remote balancer was
+	// successful and there's no need to do fallback anymore.
+	// If fallbackTimerExpired is false, fallback hasn't happened yet.
+	if !lb.serverListReceived && lb.fallbackTimerExpired {
+		// This means we received a new list of resolved backends, and we are
+		// still in fallback mode. Need to update the list of backends we are
+		// using to the new list of backends.
+		lb.refreshSubConns(lb.resolvedBackendAddrs)
+	}
+	lb.mu.Unlock()
+}
-func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	stream, err := lbc.BalanceLoad(ctx)
-	if err != nil {
-		grpclog.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
-		return
-	}
-	b.mu.Lock()
-	if b.done {
-		b.mu.Unlock()
-		return
-	}
+func (lb *lbBalancer) Close() {
+	select {
+	case <-lb.doneCh:
+		return
+	default:
+	}
+	close(lb.doneCh)
+	if lb.ccRemoteLB != nil {
+		lb.ccRemoteLB.Close()
+	}
+}
b.mu.Unlock()
initReq := &lbmpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{
InitialRequest: &lbmpb.InitialLoadBalanceRequest{
Name: b.target,
},
},
}
if err := stream.Send(initReq); err != nil {
grpclog.Errorf("grpclb: failed to send init request: %v", err)
// TODO: backoff on retry?
return true
}
reply, err := stream.Recv()
if err != nil {
grpclog.Errorf("grpclb: failed to recv init response: %v", err)
// TODO: backoff on retry?
return true
}
initResp := reply.GetInitialResponse()
if initResp == nil {
grpclog.Errorf("grpclb: reply from remote balancer did not include initial response.")
return
}
// TODO: Support delegation.
if initResp.LoadBalancerDelegate != "" {
// delegation
grpclog.Errorf("TODO: Delegation is not supported yet.")
return
}
streamDone := make(chan struct{})
defer close(streamDone)
b.mu.Lock()
b.clientStats = lbmpb.ClientStats{} // Clear client stats.
b.mu.Unlock()
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
go b.sendLoadReport(stream, d, streamDone)
}
// Retrieve the server list.
for {
reply, err := stream.Recv()
if err != nil {
grpclog.Errorf("grpclb: failed to recv server list: %v", err)
break
}
b.mu.Lock()
if b.done || seq < b.seq {
b.mu.Unlock()
return
}
b.seq++ // tick when receiving a new list of servers.
seq = b.seq
b.mu.Unlock()
if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList, seq)
}
}
return true
}
func (b *grpclbBalancer) Start(target string, config BalancerConfig) error {
b.rand = rand.New(rand.NewSource(time.Now().Unix()))
// TODO: Fall back to the basic direct connection if there is no name resolver.
if b.r == nil {
return errors.New("there is no name resolver installed")
}
b.target = target
b.mu.Lock()
if b.done {
b.mu.Unlock()
return ErrClientConnClosing
}
b.addrCh = make(chan []Address)
w, err := b.r.Resolve(target)
if err != nil {
b.mu.Unlock()
grpclog.Errorf("grpclb: failed to resolve address: %v, err: %v", target, err)
return err
}
b.w = w
b.mu.Unlock()
balancerAddrsCh := make(chan []remoteBalancerInfo, 1)
// Spawn a goroutine to monitor the name resolution of remote load balancer.
go func() {
for {
if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil {
grpclog.Warningf("grpclb: the naming watcher stops working due to %v.\n", err)
close(balancerAddrsCh)
return
}
}
}()
// Spawn a goroutine to talk to the remote load balancer.
go func() {
var (
cc *ClientConn
// ccError is closed when there is an error in the current cc.
// A new rb should be picked from rbs and connected.
ccError chan struct{}
rb *remoteBalancerInfo
rbs []remoteBalancerInfo
rbIdx int
)
defer func() {
if ccError != nil {
select {
case <-ccError:
default:
close(ccError)
}
}
if cc != nil {
cc.Close()
}
}()
for {
var ok bool
select {
case rbs, ok = <-balancerAddrsCh:
if !ok {
return
}
foundIdx := -1
if rb != nil {
for i, trb := range rbs {
if trb == *rb {
foundIdx = i
break
}
}
}
if foundIdx >= 0 {
if foundIdx >= 1 {
// Move the address in use to the beginning of the list.
b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0]
rbIdx = 0
}
continue // If found, don't dial new cc.
} else if len(rbs) > 0 {
// Pick a random one from the list, instead of always using the first one.
if l := len(rbs); l > 1 && rb != nil {
tmpIdx := b.rand.Intn(l - 1)
b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
}
rbIdx = 0
rb = &rbs[0]
} else {
// foundIdx < 0 && len(rbs) <= 0.
rb = nil
}
case <-ccError:
ccError = nil
if rbIdx < len(rbs)-1 {
rbIdx++
rb = &rbs[rbIdx]
} else {
rb = nil
}
}
if rb == nil {
continue
}
if cc != nil {
cc.Close()
}
// Talk to the remote load balancer to get the server list.
var (
err error
dopts []DialOption
)
if creds := config.DialCreds; creds != nil {
if rb.name != "" {
if err := creds.OverrideServerName(rb.name); err != nil {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v", err)
continue
}
}
dopts = append(dopts, WithTransportCredentials(creds))
} else {
dopts = append(dopts, WithInsecure())
}
if dialer := config.Dialer; dialer != nil {
// WithDialer takes a different type of function, so we instead use a special DialOption here.
dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
}
dopts = append(dopts, WithBlock())
ccError = make(chan struct{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cc, err = DialContext(ctx, rb.addr, dopts...)
cancel()
if err != nil {
grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
close(ccError)
continue
}
b.mu.Lock()
b.seq++ // tick when getting a new balancer address
seq := b.seq
b.next = 0
b.mu.Unlock()
go func(cc *ClientConn, ccError chan struct{}) {
lbc := &loadBalancerClient{cc}
b.callRemoteBalancer(lbc, seq)
cc.Close()
select {
case <-ccError:
default:
close(ccError)
}
}(cc, ccError)
}
}()
return nil
}
func (b *grpclbBalancer) down(addr Address, err error) {
b.mu.Lock()
defer b.mu.Unlock()
for _, a := range b.addrs {
if addr == a.addr {
a.connected = false
break
}
}
} }
func (b *grpclbBalancer) Up(addr Address) func(error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.done {
return nil
}
var cnt int
for _, a := range b.addrs {
if a.addr == addr {
if a.connected {
return nil
}
a.connected = true
}
if a.connected && !a.dropForRateLimiting && !a.dropForLoadBalancing {
cnt++
}
}
// addr is the only one which is connected. Notify the Get() callers who are blocking.
if cnt == 1 && b.waitCh != nil {
close(b.waitCh)
b.waitCh = nil
}
return func(err error) {
b.down(addr, err)
}
}
func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
var ch chan struct{}
b.mu.Lock()
if b.done {
b.mu.Unlock()
err = ErrClientConnClosing
return
}
seq := b.seq
defer func() {
if err != nil {
return
}
put = func() {
s, ok := rpcInfoFromContext(ctx)
if !ok {
return
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done || seq < b.seq {
return
}
b.clientStats.NumCallsFinished++
if !s.bytesSent {
b.clientStats.NumCallsFinishedWithClientFailedToSend++
} else if s.bytesReceived {
b.clientStats.NumCallsFinishedKnownReceived++
}
}
}()
b.clientStats.NumCallsStarted++
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0
}
next := b.next
for {
a := b.addrs[next]
next = (next + 1) % len(b.addrs)
if a.connected {
if !a.dropForRateLimiting && !a.dropForLoadBalancing {
addr = a.addr
b.next = next
b.mu.Unlock()
return
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
return
}
}
if next == b.next {
// Has iterated all the possible address but none is connected.
break
}
}
}
if !opts.BlockingWait {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = Errorf(codes.Unavailable, "there is no address available")
return
}
// Wait on b.waitCh for non-failfast RPCs.
if b.waitCh == nil {
ch = make(chan struct{})
b.waitCh = ch
} else {
ch = b.waitCh
}
b.mu.Unlock()
for {
select {
case <-ctx.Done():
b.mu.Lock()
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ctx.Err()
return
case <-ch:
b.mu.Lock()
if b.done {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ErrClientConnClosing
return
}
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0
}
next := b.next
for {
a := b.addrs[next]
next = (next + 1) % len(b.addrs)
if a.connected {
if !a.dropForRateLimiting && !a.dropForLoadBalancing {
addr = a.addr
b.next = next
b.mu.Unlock()
return
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr)
return
}
}
if next == b.next {
// Has iterated all the possible address but none is connected.
break
}
}
}
// The newly added addr got removed by Down() again.
if b.waitCh == nil {
ch = make(chan struct{})
b.waitCh = ch
} else {
ch = b.waitCh
}
b.mu.Unlock()
}
}
}
func (b *grpclbBalancer) Notify() <-chan []Address {
return b.addrCh
}
func (b *grpclbBalancer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.done {
return errBalancerClosed
}
b.done = true
if b.waitCh != nil {
close(b.waitCh)
}
if b.addrCh != nil {
close(b.addrCh)
}
if b.w != nil {
b.w.Close()
}
return nil
}
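The fallback behavior in the new implementation is driven entirely by the resolved address types: addresses of type resolver.GRPCLB are routed to the remote-balancer ClientConn, while resolver.Backend addresses are held back for fallback. A sketch patterned on TestFallback later in this commit; the addresses and server names are placeholders:

package main

import (
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/resolver/manual"
)

// feedAddresses hands grpclb one remote-balancer address and one ordinary
// backend address through a manual resolver. If no server list arrives from
// lbAddr within the builder's fallback timeout, the balancer uses backendAddr
// directly.
func feedAddresses(r *manual.Resolver, lbAddr, backendAddr string) {
	r.NewAddress([]resolver.Address{
		{Addr: lbAddr, Type: resolver.GRPCLB, ServerName: "lb.example.com"},
		{Addr: backendAddr, Type: resolver.Backend, ServerName: "be.example.com"},
	})
}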


@@ -27,12 +27,14 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"strings"
 	"sync"
 	"testing"
 	"time"

 	"github.com/golang/protobuf/proto"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
@@ -41,16 +43,20 @@ import (
 	lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service"
 	_ "google.golang.org/grpc/grpclog/glogger"
 	"google.golang.org/grpc/metadata"
-	"google.golang.org/grpc/naming"
+	"google.golang.org/grpc/peer"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
 	"google.golang.org/grpc/status"
 	testpb "google.golang.org/grpc/test/grpc_testing"
 	"google.golang.org/grpc/test/leakcheck"
+
+	_ "google.golang.org/grpc/grpclog/glogger"
 )

 var (
-	lbsn    = "bar.com"
-	besn    = "foo.com"
-	lbToken = "iamatoken"
+	lbServerName = "bar.com"
+	beServerName = "foo.com"
+	lbToken      = "iamatoken"

 	// Resolver replaces localhost with fakeName in Next().
 	// Dialer replaces fakeName with localhost when dialing.
@@ -58,83 +64,6 @@ var (
 	fakeName = "fake.Name"
 )
type testWatcher struct {
// the channel to receives name resolution updates
update chan *naming.Update
// the side channel to get to know how many updates in a batch
side chan int
// the channel to notifiy update injector that the update reading is done
readDone chan int
}
func (w *testWatcher) Next() (updates []*naming.Update, err error) {
n, ok := <-w.side
if !ok {
return nil, fmt.Errorf("w.side is closed")
}
for i := 0; i < n; i++ {
u, ok := <-w.update
if !ok {
break
}
if u != nil {
// Resolver replaces localhost with fakeName in Next().
// Custom dialer will replace fakeName with localhost when dialing.
u.Addr = strings.Replace(u.Addr, "localhost", fakeName, 1)
updates = append(updates, u)
}
}
w.readDone <- 0
return
}
func (w *testWatcher) Close() {
close(w.side)
}
// Inject naming resolution updates to the testWatcher.
func (w *testWatcher) inject(updates []*naming.Update) {
w.side <- len(updates)
for _, u := range updates {
w.update <- u
}
<-w.readDone
}
type testNameResolver struct {
w *testWatcher
addrs []string
}
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
r.w = &testWatcher{
update: make(chan *naming.Update, len(r.addrs)),
side: make(chan int, 1),
readDone: make(chan int),
}
r.w.side <- len(r.addrs)
for _, addr := range r.addrs {
r.w.update <- &naming.Update{
Op: naming.Add,
Addr: addr,
Metadata: &naming.AddrMetadataGRPCLB{
AddrType: naming.GRPCLB,
ServerName: lbsn,
},
}
}
go func() {
<-r.w.readDone
}()
return r.w, nil
}
func (r *testNameResolver) inject(updates []*naming.Update) {
if r.w != nil {
r.w.inject(updates)
}
}
 type serverNameCheckCreds struct {
 	mu sync.Mutex
 	sn string
@@ -199,23 +128,22 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
 }

 type remoteBalancer struct {
-	sls       []*lbmpb.ServerList
-	intervals []time.Duration
+	sls       chan *lbmpb.ServerList
 	statsDura time.Duration
 	done      chan struct{}
 	mu        sync.Mutex
 	stats     lbmpb.ClientStats
 }

-func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer {
+func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
 	return &remoteBalancer{
-		sls:       sls,
-		intervals: intervals,
-		done:      make(chan struct{}),
+		sls:  make(chan *lbmpb.ServerList, 1),
+		done: make(chan struct{}),
 	}
 }

 func (b *remoteBalancer) stop() {
+	close(b.sls)
 	close(b.done)
 }
@@ -225,7 +153,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer
 		return err
 	}
 	initReq := req.GetInitialRequest()
-	if initReq.Name != besn {
+	if initReq.Name != beServerName {
 		return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
 	}
 	resp := &lbmpb.LoadBalanceResponse{
@@ -260,8 +188,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer
 			b.mu.Unlock()
 		}
 	}()
-	for k, v := range b.sls {
-		time.Sleep(b.intervals[k])
+	for v := range b.sls {
 		resp = &lbmpb.LoadBalanceResponse{
 			LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{
 				ServerList: v,
@@ -278,7 +205,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer
 type testServer struct {
 	testpb.TestServiceServer

 	addr string
+	fallback bool
 }

 const testmdkey = "testmd"
@@ -288,7 +216,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E
 	if !ok {
 		return nil, status.Error(codes.Internal, "failed to receive metadata")
 	}
-	if md == nil || md["lb-token"][0] != lbToken {
+	if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
 		return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
 	}
 	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
@@ -299,13 +227,13 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
 	return nil
 }

-func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
+func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
 	for _, l := range lis {
 		creds := &serverNameCheckCreds{
 			sn: sn,
 		}
 		s := grpc.NewServer(grpc.Creds(creds))
-		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
+		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
 		servers = append(servers, s)
 		go func(s *grpc.Server, l net.Listener) {
 			s.Serve(l)
@@ -348,7 +276,7 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er
 		beListeners = append(beListeners, beLis)
 	}
-	backends := startBackends(besn, beListeners...)
+	backends := startBackends(beServerName, false, beListeners...)

 	// Start a load balancer.
 	lbLis, err := net.Listen("tcp", "localhost:0")
@@ -357,21 +285,21 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er
 		return
 	}
 	lbCreds := &serverNameCheckCreds{
-		sn: lbsn,
+		sn: lbServerName,
 	}
 	lb = grpc.NewServer(grpc.Creds(lbCreds))
 	if err != nil {
 		err = fmt.Errorf("Failed to generate the port number %v", err)
 		return
 	}
-	ls = newRemoteBalancer(nil, nil)
+	ls = newRemoteBalancer(nil)
 	lbspb.RegisterLoadBalancerServer(lb, ls)
 	go func() {
 		lb.Serve(lbLis)
 	}()
 	tss = &testServers{
-		lbAddr: lbLis.Addr().String(),
+		lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port),
 		ls:     ls,
 		lb:     lb,
 		beIPs:  beIPs,
@@ -389,6 +317,10 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er
 func TestGRPCLB(t *testing.T) {
 	defer leakcheck.Check(t)

+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
 	tss, cleanup, err := newLoadBalancer(1)
 	if err != nil {
 		t.Fatalf("failed to create new load balancer: %v", err)
@ -405,136 +337,178 @@ func TestGRPCLB(t *testing.T) {
sl := &lbmpb.ServerList{ sl := &lbmpb.ServerList{
Servers: bes, Servers: bes,
} }
tss.ls.sls = []*lbmpb.ServerList{sl} tss.ls.sls <- sl
tss.ls.intervals = []time.Duration{0}
creds := serverNameCheckCreds{ creds := serverNameCheckCreds{
expected: besn, expected: beServerName,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, besn, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), grpc.WithBalancerBuilder(grpc.NewLBBuilder()),
grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
defer cc.Close() defer cc.Close()
testC := testpb.NewTestServiceClient(cc) testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} }
} }
func TestDropRequest(t *testing.T) { // The remote balancer sends response with duplicates to grpclb client.
func TestGRPCLBWeighted(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(2) tss, cleanup, err := newLoadBalancer(2)
if err != nil { if err != nil {
t.Fatalf("failed to create new load balancer: %v", err) t.Fatalf("failed to create new load balancer: %v", err)
} }
defer cleanup() defer cleanup()
tss.ls.sls = []*lbmpb.ServerList{{
Servers: []*lbmpb.Server{{ beServers := []*lbmpb.Server{{
IpAddress: tss.beIPs[0], IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]), Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken, LoadBalanceToken: lbToken,
DropForLoadBalancing: true, }, {
}, { IpAddress: tss.beIPs[1],
IpAddress: tss.beIPs[1], Port: int32(tss.bePorts[1]),
Port: int32(tss.bePorts[1]), LoadBalanceToken: lbToken,
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}},
}} }}
tss.ls.intervals = []time.Duration{0} portsToIndex := make(map[int]int)
for i := range beServers {
portsToIndex[tss.bePorts[i]] = i
}
creds := serverNameCheckCreds{ creds := serverNameCheckCreds{
expected: besn, expected: beServerName,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, besn, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), grpc.WithBalancerBuilder(grpc.NewLBBuilder()),
grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
defer cc.Close() defer cc.Close()
testC := testpb.NewTestServiceClient(cc) testC := testpb.NewTestServiceClient(cc)
// Wait until the first connection is up.
// The first one has Drop set to true, error should contain "drop requests". r.NewAddress([]resolver.Address{{
for { Addr: tss.lbAddr,
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { Type: resolver.GRPCLB,
if strings.Contains(err.Error(), "drops requests") { ServerName: lbServerName,
break }})
sequences := []string{"00101", "00011"}
for _, seq := range sequences {
var (
bes []*lbmpb.Server
p peer.Peer
result string
)
for _, s := range seq {
bes = append(bes, beServers[s-'0'])
}
tss.ls.sls <- &lbmpb.ServerList{Servers: bes}
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} }
result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
} }
} // The generated result will be in format of "0010100101".
// The 1st, non-fail-fast RPC should succeed. This ensures both server if !strings.Contains(result, strings.Repeat(seq, 2)) {
// connections are made, because the first one has DropForLoadBalancing set to true. t.Errorf("got result sequence %q, want patten %q", result, seq)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < 3; i++ {
// Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
// set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Even fail-fast RPCs should succeed since they choose the
// non-drop-request backend according to the round robin policy.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} }
} }
} }
func TestDropRequestFailedNonFailFast(t *testing.T) { func TestDropRequest(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1) tss, cleanup, err := newLoadBalancer(1)
if err != nil { if err != nil {
t.Fatalf("failed to create new load balancer: %v", err) t.Fatalf("failed to create new load balancer: %v", err)
} }
defer cleanup() defer cleanup()
be := &lbmpb.Server{ tss.ls.sls <- &lbmpb.ServerList{
IpAddress: tss.beIPs[0], Servers: []*lbmpb.Server{{
Port: int32(tss.bePorts[0]), IpAddress: tss.beIPs[0],
LoadBalanceToken: lbToken, Port: int32(tss.bePorts[0]),
DropForLoadBalancing: true, LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}, {
DropForLoadBalancing: true,
}},
} }
var bes []*lbmpb.Server
bes = append(bes, be)
sl := &lbmpb.ServerList{
Servers: bes,
}
tss.ls.sls = []*lbmpb.ServerList{sl}
tss.ls.intervals = []time.Duration{0}
creds := serverNameCheckCreds{ creds := serverNameCheckCreds{
expected: besn, expected: beServerName,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, besn, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), grpc.WithBalancerBuilder(grpc.NewLBBuilder()),
grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
defer cc.Close() defer cc.Close()
testC := testpb.NewTestServiceClient(cc) testC := testpb.NewTestServiceClient(cc)
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel() r.NewAddress([]resolver.Address{{
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { Addr: tss.lbAddr,
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
// The 1st, non-fail-fast RPC should succeed. This ensures both server
// connections are made, because the first one has DropForLoadBalancing set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
for _, failfast := range []bool{true, false} {
for i := 0; i < 3; i++ {
// Even RPCs should fail, because the 2st backend has
// DropForLoadBalancing set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); grpc.Code(err) != codes.Unavailable {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Odd RPCs should succeed since they choose the non-drop-request
// backend according to the round robin policy.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
} }
} }
// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
func TestBalancerDisconnects(t *testing.T) { func TestBalancerDisconnects(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
var ( var (
lbAddrs []string tests []*testServers
lbs []*grpc.Server lbs []*grpc.Server
) )
for i := 0; i < 3; i++ { for i := 0; i < 2; i++ {
tss, cleanup, err := newLoadBalancer(1) tss, cleanup, err := newLoadBalancer(1)
if err != nil { if err != nil {
t.Fatalf("failed to create new load balancer: %v", err) t.Fatalf("failed to create new load balancer: %v", err)
@ -551,78 +525,149 @@ func TestBalancerDisconnects(t *testing.T) {
sl := &lbmpb.ServerList{ sl := &lbmpb.ServerList{
Servers: bes, Servers: bes,
} }
tss.ls.sls = []*lbmpb.ServerList{sl} tss.ls.sls <- sl
tss.ls.intervals = []time.Duration{0}
lbAddrs = append(lbAddrs, tss.lbAddr) tests = append(tests, tss)
lbs = append(lbs, tss.lb) lbs = append(lbs, tss.lb)
} }
creds := serverNameCheckCreds{ creds := serverNameCheckCreds{
expected: besn, expected: beServerName,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
resolver := &testNameResolver{ cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
addrs: lbAddrs[:2], grpc.WithBalancerBuilder(grpc.NewLBBuilder()),
} grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
cc, err := grpc.DialContext(ctx, besn,
grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)),
grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
defer cc.Close() defer cc.Close()
testC := testpb.NewTestServiceClient(cc) testC := testpb.NewTestServiceClient(cc)
var previousTrailer string
trailer := metadata.MD{} r.NewAddress([]resolver.Address{{
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { Addr: tests[0].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: tests[1].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else {
previousTrailer = trailer[testmdkey][0]
} }
// The initial resolver update contains lbs[0] and lbs[1]. if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
// When lbs[0] is stopped, lbs[1] should be used. t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
}
lbs[0].Stop() lbs[0].Stop()
for { // Stop balancer[0], balancer[1] should be used by grpclb.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { // Check peer address to see if that happened.
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The trailer contains the backend address, so the trailer should be different from the previous one.
previousTrailer = trailer[testmdkey][0]
break
} }
time.Sleep(100 * time.Millisecond) if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
return
}
time.Sleep(time.Millisecond)
} }
// Inject a update to add lbs[2] to resolved addresses. t.Fatalf("No RPC sent to second backend after 1 second")
resolver.inject([]*naming.Update{ }
{Op: naming.Add,
Addr: lbAddrs[2], func TestFallback(t *testing.T) {
Metadata: &naming.AddrMetadataGRPCLB{ defer leakcheck.Check(t)
AddrType: naming.GRPCLB,
ServerName: lbsn, r, cleanup := manual.GenerateAndRegisterManualResolver()
}, defer cleanup()
},
}) tss, cleanup, err := newLoadBalancer(1)
// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used. if err != nil {
lbs[1].Stop() t.Fatalf("failed to create new load balancer: %v", err)
for { }
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { defer cleanup()
// Start a standalone backend.
beLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen %v", err)
}
defer beLis.Close()
standaloneBEs := startBackends(beServerName, true, beLis)
defer stopBackends(standaloneBEs)
be := &lbmpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbmpb.Server
bes = append(bes, be)
sl := &lbmpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancerBuilder(grpc.NewLBBuilderWithFallbackTimeout(100*time.Millisecond)),
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: "",
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.String() != beLis.Addr().String() {
t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
}
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The trailer contains the backend address, so the trailer should be different from the previous one.
break
} }
time.Sleep(100 * time.Millisecond) if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
return
}
time.Sleep(time.Millisecond)
} }
t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
} }
type failPreRPCCred struct{} type failPreRPCCred struct{}
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], "failtosend") { if strings.Contains(uri[0], failtosendURI) {
return nil, fmt.Errorf("rpc should fail to send") return nil, fmt.Errorf("rpc should fail to send")
} }
return nil, nil return nil, nil
@ -640,35 +685,46 @@ func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error {
} }
func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats { func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats {
tss, cleanup, err := newLoadBalancer(3) defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil { if err != nil {
t.Fatalf("failed to create new load balancer: %v", err) t.Fatalf("failed to create new load balancer: %v", err)
} }
defer cleanup() defer cleanup()
tss.ls.sls = []*lbmpb.ServerList{{ tss.ls.sls <- &lbmpb.ServerList{
Servers: []*lbmpb.Server{{ Servers: []*lbmpb.Server{{
IpAddress: tss.beIPs[2], IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[2]), Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken, LoadBalanceToken: lbToken,
DropForLoadBalancing: dropForLoadBalancing, DropForLoadBalancing: dropForLoadBalancing,
DropForRateLimiting: dropForRateLimiting, DropForRateLimiting: dropForRateLimiting,
}}, }},
}} }
tss.ls.intervals = []time.Duration{0}
tss.ls.statsDura = 100 * time.Millisecond tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{expected: besn} creds := serverNameCheckCreds{expected: beServerName}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, besn, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), grpc.WithBalancerBuilder(grpc.NewLBBuilder()),
grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}), grpc.WithTransportCredentials(&creds),
grpc.WithBlock(), grpc.WithDialer(fakeNameDialer)) grpc.WithPerRPCCredentials(failPreRPCCred{}),
grpc.WithDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
defer cc.Close() defer cc.Close()
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
runRPCs(cc) runRPCs(cc)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
tss.ls.mu.Lock() tss.ls.mu.Lock()
@ -677,7 +733,11 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool
return stats return stats
} }
const countRPC = 40 const (
countRPC = 40
failtosendURI = "failtosend"
dropErrDesc = "request dropped by grpclb"
)
func TestGRPCLBStatsUnarySuccess(t *testing.T) { func TestGRPCLBStatsUnarySuccess(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
@ -709,7 +769,7 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
for { for {
c++ c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), "drops requests") { if strings.Contains(err.Error(), dropErrDesc) {
break break
} }
} }
@ -737,7 +797,7 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
for { for {
c++ c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), "drops requests") { if strings.Contains(err.Error(), dropErrDesc) {
break break
} }
} }
@ -766,7 +826,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} }
for i := 0; i < countRPC-1; i++ { for i := 0; i < countRPC-1; i++ {
grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc) grpc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil, cc)
} }
}) })
@ -824,7 +884,7 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
for { for {
c++ c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), "drops requests") { if strings.Contains(err.Error(), dropErrDesc) {
break break
} }
} }
@ -852,7 +912,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
for { for {
c++ c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), "drops requests") { if strings.Contains(err.Error(), dropErrDesc) {
break break
} }
} }
@ -887,7 +947,7 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
} }
} }
for i := 0; i < countRPC-1; i++ { for i := 0; i < countRPC-1; i++ {
grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, failtosendURI)
} }
}) })

grpclb_picker.go Normal file

@ -0,0 +1,159 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/status"
)
type rpcStats struct {
NumCallsStarted int64
NumCallsFinished int64
NumCallsFinishedWithDropForRateLimiting int64
NumCallsFinishedWithDropForLoadBalancing int64
NumCallsFinishedWithClientFailedToSend int64
NumCallsFinishedKnownReceived int64
}
// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats.
func (s *rpcStats) toClientStats() *lbpb.ClientStats {
stats := &lbpb.ClientStats{
NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0),
NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0),
NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0),
}
return stats
}
func (s *rpcStats) dropForRateLimiting() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) dropForLoadBalancing() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) failedToSend() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
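The atomic.SwapInt64 calls in toClientStats both read and zero each counter, so every generated report covers only the interval since the previous one. A minimal standalone sketch of that swap-and-clear pattern, using placeholder counters instead of the lbpb message:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    // counters mirrors the idea behind rpcStats: values accumulate between
    // reports and are atomically reset whenever a report is taken.
    type counters struct {
        started  int64
        finished int64
    }

    func (c *counters) record() {
        atomic.AddInt64(&c.started, 1)
        atomic.AddInt64(&c.finished, 1)
    }

    // snapshot returns what accumulated since the last snapshot and clears it.
    func (c *counters) snapshot() (started, finished int64) {
        return atomic.SwapInt64(&c.started, 0), atomic.SwapInt64(&c.finished, 0)
    }

    func main() {
        var c counters
        c.record()
        c.record()
        fmt.Println(c.snapshot()) // 2 2
        fmt.Println(c.snapshot()) // 0 0: cleared by the previous snapshot
    }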
type errPicker struct {
// Pick always returns this err.
err error
}
func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
return nil, nil, p.err
}
// rrPicker does roundrobin on subConns. It's typically used when there's no
// response from the remote balancer, and grpclb falls back to the resolved
// backends.
//
// It is guaranteed that len(subConns) > 0.
type rrPicker struct {
mu sync.Mutex
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
}
func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
return sc, nil, nil
}
// lbPicker does two layers of picks:
//
// First layer: roundrobin on all servers in serverList, including drops and backends.
// - If it picks a drop, the RPC will fail as being dropped.
// - If it picks a backend, do a second layer pick to pick the real backend.
//
// Second layer: roundrobin on all READY backends.
//
// It's guaranteed that len(serverList) > 0.
type lbPicker struct {
mu sync.Mutex
serverList []*lbpb.Server
serverListNext int
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
stats *rpcStats
}
func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
// Layer one roundrobin on serverList.
s := p.serverList[p.serverListNext]
p.serverListNext = (p.serverListNext + 1) % len(p.serverList)
// If it's a drop, return an error and fail the RPC.
if s.DropForRateLimiting {
p.stats.dropForRateLimiting()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
if s.DropForLoadBalancing {
p.stats.dropForLoadBalancing()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
// If it's not a drop but there are no ready SubConns.
if len(p.subConns) <= 0 {
return nil, nil, balancer.ErrNoSubConnAvailable
}
// Return the next ready subConn in the list, also collect rpc stats.
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
done := func(info balancer.DoneInfo) {
if !info.BytesSent {
p.stats.failedToSend()
} else if info.BytesReceived {
p.stats.knownReceived()
}
}
return sc, done, nil
}
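Since lbPicker is unexported, here is a standalone sketch (server names and weights invented) of the pick order the two layers produce when the server list contains a duplicated backend and a drop entry:

    package main

    import "fmt"

    type server struct {
        name string
        drop bool
    }

    func main() {
        // Layer one walks the full list from the balancer; listing "a" twice
        // gives it twice the weight of "b", and the drop entry fails the RPC.
        serverList := []server{{name: "a"}, {name: "a"}, {drop: true}, {name: "b"}}
        // Layer two walks only the READY backends.
        readyBackends := []string{"a", "b"}
        var slNext, beNext int

        for i := 0; i < 8; i++ {
            s := serverList[slNext]
            slNext = (slNext + 1) % len(serverList)
            if s.drop {
                fmt.Printf("pick %d: dropped by grpclb\n", i)
                continue
            }
            be := readyBackends[beNext]
            beNext = (beNext + 1) % len(readyBackends)
            fmt.Printf("pick %d: %s\n", i, be)
        }
    }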

grpclb_remote_balancer.go Normal file

@ -0,0 +1,254 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"net"
"reflect"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
// processServerList updates the balancer's internal state, creates/removes
// SubConns, and regenerates the picker using the received serverList.
func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
grpclog.Infof("lbBalancer: processing server list: %+v", l)
lb.mu.Lock()
defer lb.mu.Unlock()
// Set serverListReceived to true so that fallback will not take effect when
// the fallback timeout fires later.
lb.serverListReceived = true
// If the new server list == old server list, do nothing.
if reflect.DeepEqual(lb.fullServerList, l.Servers) {
grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring")
return
}
lb.fullServerList = l.Servers
var backendAddrs []resolver.Address
for _, s := range l.Servers {
if s.DropForLoadBalancing || s.DropForRateLimiting {
continue
}
md := metadata.Pairs(lbTokeyKey, s.LoadBalanceToken)
ip := net.IP(s.IpAddress)
ipStr := ip.String()
if ip.To4() == nil {
// Add square brackets to IPv6 addresses; otherwise net.Dial() and
// net.SplitHostPort() will return a "too many colons" error.
ipStr = fmt.Sprintf("[%s]", ipStr)
}
addr := resolver.Address{
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
backendAddrs = append(backendAddrs, addr)
}
// Call refreshSubConns to create/remove SubConns.
backendsUpdated := lb.refreshSubConns(backendAddrs)
// If the backends are unchanged, no SubConn will be created or removed. But
// since the full serverList is different, there might be updates in drops or
// pick weights (a different number of duplicates). We still need to regenerate
// the picker with the full list.
if !backendsUpdated {
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
}
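The address formatting above is the one non-obvious conversion: the balancer sends raw IP bytes, and IPv6 literals need square brackets before the port is appended. A standalone sketch of that conversion (sample IPs invented):

    package main

    import (
        "fmt"
        "net"
    )

    // backendAddr mirrors the formatting in processServerList: bracket IPv6
    // literals so net.Dial and net.SplitHostPort don't report
    // "too many colons in address".
    func backendAddr(ipBytes []byte, port int32) string {
        ip := net.IP(ipBytes)
        ipStr := ip.String()
        if ip.To4() == nil {
            ipStr = fmt.Sprintf("[%s]", ipStr)
        }
        return fmt.Sprintf("%s:%d", ipStr, port)
    }

    func main() {
        fmt.Println(backendAddr(net.ParseIP("10.0.0.1").To4(), 8080))     // 10.0.0.1:8080
        fmt.Println(backendAddr(net.ParseIP("2001:db8::1").To16(), 8080)) // [2001:db8::1]:8080
    }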
// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
// indicating whether the backendAddrs are different from the cached
// backendAddrs (whether any SubConn was created or removed).
// Caller must hold lb.mu.
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
lb.backendAddrs = nil
var backendsUpdated bool
// addrsSet is the set converted from backendAddrs; it's used for quick
// lookup of an address.
addrsSet := make(map[resolver.Address]struct{})
// Create new SubConns.
for _, addr := range backendAddrs {
addrWithoutMD := addr
addrWithoutMD.Metadata = nil
addrsSet[addrWithoutMD] = struct{}{}
lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD)
if _, ok := lb.subConns[addrWithoutMD]; !ok {
backendsUpdated = true
// Use addrWithMD to create the SubConn.
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err)
continue
}
lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
lb.scStates[sc] = connectivity.Idle
sc.Connect()
}
}
for a, sc := range lb.subConns {
// a is no longer in the new backendAddrs; remove its SubConn.
if _, ok := addrsSet[a]; !ok {
backendsUpdated = true
lb.cc.RemoveSubConn(sc)
delete(lb.subConns, a)
// Keep the state of this sc in lb.scStates until sc's state becomes Shutdown.
// The entry will be deleted in HandleSubConnStateChange.
}
}
return backendsUpdated
}
func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
for {
reply, err := s.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv server list: %v", err)
}
if serverList := reply.GetServerList(); serverList != nil {
lb.processServerList(serverList)
}
}
}
func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-s.Context().Done():
return
}
stats := lb.clientStats.toClientStats()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: stats,
},
}); err != nil {
return
}
}
}
func (lb *lbBalancer) callRemoteBalancer() error {
lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, FailFast(false))
if err != nil {
return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
}
// grpclb handshake on the stream.
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
InitialRequest: &lbpb.InitialLoadBalanceRequest{
Name: lb.target,
},
},
}
if err := stream.Send(initReq); err != nil {
return fmt.Errorf("grpclb: failed to send init request: %v", err)
}
reply, err := stream.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv init response: %v", err)
}
initResp := reply.GetInitialResponse()
if initResp == nil {
return fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
}
if initResp.LoadBalancerDelegate != "" {
return fmt.Errorf("grpclb: Delegation is not supported")
}
go func() {
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
lb.sendLoadReport(stream, d)
}
}()
return lb.readServerList(stream)
}
func (lb *lbBalancer) watchRemoteBalancer() {
for {
err := lb.callRemoteBalancer()
select {
case <-lb.doneCh:
return
default:
if err != nil {
grpclog.Error(err)
}
}
}
}
func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
var dopts []DialOption
if creds := lb.opt.DialCreds; creds != nil {
if err := creds.OverrideServerName(remoteLBName); err == nil {
dopts = append(dopts, WithTransportCredentials(creds))
} else {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
dopts = append(dopts, WithInsecure())
}
} else {
dopts = append(dopts, WithInsecure())
}
if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a
// special DialOption here.
dopts = append(dopts, withContextDialer(lb.opt.Dialer))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, WithBalancerBuilder(newPickfirstBuilder()))
dopts = append(dopts, withResolverBuilder(lb.manualResolver))
// Dial using manualResolver.Scheme, which is a random scheme generated
// when grpclb is initialized. The target name is not important.
cc, err := Dial("grpclb:///grpclb.server", dopts...)
if err != nil {
grpclog.Fatalf("failed to dial: %v", err)
}
lb.ccRemoteLB = cc
go lb.watchRemoteBalancer()
}

View File

@ -97,7 +97,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
p = bp.picker p = bp.picker
bp.mu.Unlock() bp.mu.Unlock()
subConn, put, err := p.Pick(ctx, opts) subConn, done, err := p.Pick(ctx, opts)
if err != nil { if err != nil {
switch err { switch err {
@ -120,7 +120,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
continue continue
} }
if t, ok := acw.getAddrConn().getReadyTransport(); ok { if t, ok := acw.getAddrConn().getReadyTransport(); ok {
return t, put, nil return t, done, nil
} }
grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick") grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
// If ok == false, ac.state is not READY. // If ok == false, ac.state is not READY.

View File

@ -16,7 +16,8 @@
* *
*/ */
// Package manual contains a resolver for testing purpose only. // Package manual defines a resolver that can be used to manually send resolved
// addresses to ClientConn.
package manual package manual
import ( import (
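A minimal sketch of how the tests above drive this resolver: register one, dial with its scheme, then push addresses into the ClientConn by hand. The target name and backend address below are placeholders:

    package main

    import (
        "google.golang.org/grpc"
        "google.golang.org/grpc/resolver"
        "google.golang.org/grpc/resolver/manual"
    )

    func main() {
        // GenerateAndRegisterManualResolver registers the resolver under a
        // random scheme so it doesn't collide with real resolvers.
        r, cleanup := manual.GenerateAndRegisterManualResolver()
        defer cleanup()

        cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure())
        if err != nil {
            panic(err)
        }
        defer cc.Close()

        // Hand-feed the resolved addresses; the ClientConn reacts as if a
        // real resolver had returned them.
        r.NewAddress([]resolver.Address{{Addr: "127.0.0.1:8080"}})
    }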

View File

@ -78,7 +78,9 @@ type Address struct {
// Type is the type of this address. // Type is the type of this address.
Type AddressType Type AddressType
// ServerName is the name of this address. // ServerName is the name of this address.
// It's the name of the grpc load balancer, which will be used for authentication. //
// e.g. if Type is GRPCLB, ServerName should be the name of the remote load
// balancer, not the name of the backend.
ServerName string ServerName string
// Metadata is the information associated with Addr, which may be used // Metadata is the information associated with Addr, which may be used
// to make load balancing decision. // to make load balancing decision.
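A short sketch of the distinction this comment draws: a plain backend address leaves Type at its default, while a grpclb balancer address sets Type to GRPCLB and carries the balancer's name (not a backend's) in ServerName. Addresses and names below are placeholders:

    package main

    import (
        "fmt"

        "google.golang.org/grpc/resolver"
    )

    func main() {
        addrs := []resolver.Address{
            {
                // A regular backend; Type defaults to resolver.Backend.
                Addr: "10.0.0.2:8080",
            },
            {
                // A remote grpclb balancer; ServerName is the balancer's
                // name and is used for authentication.
                Addr:       "10.0.0.1:9999",
                Type:       resolver.GRPCLB,
                ServerName: "lb.test.example.com",
            },
        }
        fmt.Printf("%+v\n", addrs)
    }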

View File

@ -61,12 +61,18 @@ func parseTarget(target string) (ret resolver.Target) {
// newCCResolverWrapper parses cc.target for scheme and gets the resolver // newCCResolverWrapper parses cc.target for scheme and gets the resolver
// builder for this scheme. It then builds the resolver and starts the // builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it. // monitoring goroutine for it.
//
// If withResolverBuilder dial option is set, the specified resolver will be
// used instead.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme) grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme)
rb := resolver.Get(cc.parsedTarget.Scheme) rb := cc.dopts.resolverBuilder
if rb == nil { if rb == nil {
return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) rb = resolver.Get(cc.parsedTarget.Scheme)
if rb == nil {
return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
}
} }
ccr := &ccResolverWrapper{ ccr := &ccResolverWrapper{

View File

@ -441,9 +441,7 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
} }
type rpcInfo struct { type rpcInfo struct {
failfast bool failfast bool
bytesSent bool
bytesReceived bool
} }
type rpcInfoContextKey struct{} type rpcInfoContextKey struct{}
@ -457,14 +455,6 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
return return
} }
func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
if ss, ok := rpcInfoFromContext(ctx); ok {
ss.bytesReceived = s.bytesReceived
ss.bytesSent = s.bytesSent
}
return
}
// Code returns the error code for err if it was produced by the rpc system. // Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown. // Otherwise, it returns codes.Unknown.
// //

View File

@ -232,7 +232,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr) s, err = t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
if done != nil { if done != nil {
done(balancer.DoneInfo{Err: err}) doneInfo := balancer.DoneInfo{Err: err}
if _, ok := err.(transport.ConnectionError); ok {
// If the error is a connection error, the transport was sending data on the
// wire, and we are not sure whether anything actually went out.
// If the error is not a connection error, we are sure nothing has been sent.
doneInfo.BytesSent = true
}
done(doneInfo)
done = nil done = nil
} }
// In the event of any error from NewStream, we never attempted to write // In the event of any error from NewStream, we never attempted to write
@ -529,11 +536,11 @@ func (cs *clientStream) finish(err error) {
o.after(cs.c) o.after(cs.c)
} }
if cs.done != nil { if cs.done != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{ cs.done(balancer.DoneInfo{
bytesSent: true, Err: err,
bytesReceived: cs.s.BytesReceived(), BytesSent: true,
BytesReceived: cs.s.BytesReceived(),
}) })
cs.done(balancer.DoneInfo{Err: err})
cs.done = nil cs.done = nil
} }
if cs.statsHandler != nil { if cs.statsHandler != nil {
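With BytesSent and BytesReceived now delivered through DoneInfo (instead of the removed rpcInfo context round-trip), a picker's done callback can classify each RPC for load reporting. A standalone sketch of such a classification, mirroring the lbPicker callback above (classify is not a library function):

    package main

    import (
        "fmt"

        "google.golang.org/grpc/balancer"
    )

    // classify maps DoneInfo to the grpclb load-report buckets: an RPC that
    // never got bytes onto the wire counts as "client failed to send"; one
    // with a response from the server counts as "known received".
    func classify(info balancer.DoneInfo) string {
        switch {
        case !info.BytesSent:
            return "client failed to send"
        case info.BytesReceived:
            return "known received"
        default:
            return "finished without a response from the server"
        }
    }

    func main() {
        fmt.Println(classify(balancer.DoneInfo{}))
        fmt.Println(classify(balancer.DoneInfo{BytesSent: true, BytesReceived: true}))
        fmt.Println(classify(balancer.DoneInfo{BytesSent: true}))
    }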