diff --git a/balancer/grpclb/grpclb.go b/balancer/grpclb/grpclb.go
index b05db557..9c5b5f25 100644
--- a/balancer/grpclb/grpclb.go
+++ b/balancer/grpclb/grpclb.go
@@ -129,8 +129,19 @@ func newLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Bui
 	}
 }
 
+// newLBBuilderWithPickFirst creates a grpclb builder with pick-first.
+func newLBBuilderWithPickFirst() balancer.Builder {
+	return &lbBuilder{
+		usePickFirst: true,
+	}
+}
+
 type lbBuilder struct {
 	fallbackTimeout time.Duration
+
+	// TODO: delete this when balancer can handle service config. This should be
+	// updated by service config.
+	usePickFirst bool // Use roundrobin or pickfirst for backends.
 }
 
 func (b *lbBuilder) Name() string {
@@ -156,6 +167,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
 		cc:              newLBCacheClientConn(cc),
 		target:          target,
 		opt:             opt,
+		usePickFirst:    b.usePickFirst,
 		fallbackTimeout: b.fallbackTimeout,
 		doneCh:          make(chan struct{}),
 
@@ -188,6 +200,8 @@ type lbBalancer struct {
 	target string
 	opt    balancer.BuildOptions
 
+	usePickFirst bool
+
 	// grpclbClientConnCreds is the creds bundle to be used to connect to grpclb
 	// servers. If it's nil, use the TransportCredentials from BuildOptions
 	// instead.
@@ -249,11 +263,21 @@ func (lb *lbBalancer) regeneratePicker(resetDrop bool) {
 		lb.picker = &errPicker{err: balancer.ErrTransientFailure}
 		return
 	}
+
 	var readySCs []balancer.SubConn
-	for _, a := range lb.backendAddrs {
-		if sc, ok := lb.subConns[a]; ok {
-			if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready {
+	if lb.usePickFirst {
+		if lb.state == connectivity.Ready || lb.state == connectivity.Idle {
+			for _, sc := range lb.subConns {
 				readySCs = append(readySCs, sc)
+				break
+			}
+		}
+	} else {
+		for _, a := range lb.backendAddrs {
+			if sc, ok := lb.subConns[a]; ok {
+				if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready {
+					readySCs = append(readySCs, sc)
+				}
+			}
 			}
 		}
 	}
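Note: since usePickFirst cannot yet come from the service config (see the TODO above), the only way to enable it today is to re-register the builder under the same "grpclb" name, which is exactly what the new test at the bottom of this patch does. A minimal sketch, assuming it lives inside package grpclb (swapInPickFirst is a hypothetical helper, not part of this patch):

	// swapInPickFirst replaces the registered "grpclb" builder with the
	// pick-first variant. Both builders report the same Name(), so the most
	// recent balancer.Register call wins process-wide. The returned func
	// restores the default round-robin builder.
	func swapInPickFirst() (restore func()) {
		balancer.Register(newLBBuilderWithPickFirst())
		return func() { balancer.Register(newLBBuilder()) }
	}

Because registration is global, this is only suitable for tests or experiments done before any ClientConn is dialed.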
diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go
index 8d6cd0f1..19b2a436 100644
--- a/balancer/grpclb/grpclb_remote_balancer.go
+++ b/balancer/grpclb/grpclb_remote_balancer.go
@@ -95,15 +95,37 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
-// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
-// indicating whether the backendAddrs are different from the cached
-// backendAddrs (whether any SubConn was newed/removed).
+// refreshSubConns creates/removes SubConns with backendAddrs. In pick-first
+// mode it maintains at most one SubConn and updates its address list in place
+// instead of keeping one SubConn per backend address.
 // Caller must hold lb.mu.
-func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCLBServer bool) bool {
+func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCLBServer bool) {
 	opts := balancer.NewSubConnOptions{}
 	if fromGRPCLBServer {
 		opts.CredsBundle = lb.grpclbBackendCreds
 	}
 
 	lb.backendAddrs = nil
-	var backendsUpdated bool
+
+	if lb.usePickFirst {
+		var sc balancer.SubConn
+		for _, sc = range lb.subConns {
+			break
+		}
+		if sc != nil {
+			sc.UpdateAddresses(backendAddrs)
+			sc.Connect()
+			return
+		}
+		// This bypasses the cc wrapper with SubConn cache.
+		sc, err := lb.cc.cc.NewSubConn(backendAddrs, opts)
+		if err != nil {
+			grpclog.Warningf("grpclb: failed to create new SubConn: %v", err)
+			return
+		}
+		sc.Connect()
+		lb.subConns[backendAddrs[0]] = sc
+		lb.scStates[sc] = connectivity.Idle
+		return
+	}
+
 	// addrsSet is the set converted from backendAddrs, it's used to quick
 	// lookup for an address.
 	addrsSet := make(map[resolver.Address]struct{})
@@ -115,12 +137,10 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCL
 		lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD)
 
 		if _, ok := lb.subConns[addrWithoutMD]; !ok {
-			backendsUpdated = true
-
 			// Use addrWithMD to create the SubConn.
 			sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
 			if err != nil {
-				grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err)
+				grpclog.Warningf("grpclb: failed to create new SubConn: %v", err)
 				continue
 			}
 			lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
@@ -136,16 +156,12 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCL
 	for a, sc := range lb.subConns {
 		// a was removed by resolver.
 		if _, ok := addrsSet[a]; !ok {
-			backendsUpdated = true
-
 			lb.cc.RemoveSubConn(sc)
 			delete(lb.subConns, a)
 			// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
 			// The entry will be deleted in HandleSubConnStateChange.
 		}
 	}
-
-	return backendsUpdated
 }
 
 func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
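The core behavioral difference is how backend addresses map onto SubConns. A minimal standalone sketch of the two layouts (illustrative only, not part of the patch; cc and addrs stand in for lb.cc.cc and the addresses extracted in processServerList):

	package sketch

	import (
		"google.golang.org/grpc/balancer"
		"google.golang.org/grpc/resolver"
	)

	// connectPickFirst hands the whole address list to a single SubConn; the
	// transport tries the addresses in order and stays on the first one that
	// connects. A later UpdateAddresses swaps the list in place rather than
	// destroying the SubConn.
	func connectPickFirst(cc balancer.ClientConn, addrs []resolver.Address) {
		sc, err := cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
		if err != nil {
			return
		}
		sc.Connect()
	}

	// connectRoundRobin creates one SubConn per address; regeneratePicker then
	// rotates over whichever of them are READY.
	func connectRoundRobin(cc balancer.ClientConn, addrs []resolver.Address) {
		for _, a := range addrs {
			sc, err := cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
			if err != nil {
				continue
			}
			sc.Connect()
		}
	}

This layout also appears to be why the pick-first path calls lb.cc.cc directly: the caching wrapper tracks SubConns by individual address, which does not fit a single SubConn that carries the entire list.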
diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go
index 5e49d1fe..89590d65 100644
--- a/balancer/grpclb/grpclb_test.go
+++ b/balancer/grpclb/grpclb_test.go
@@ -751,6 +751,107 @@ func TestFallback(t *testing.T) {
 	t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
 }
 
+// TestGRPCLBPickFirst runs the grpclb balancer with pick-first for backends: RPCs must stick to a single backend, switching only when the server list changes.
+func TestGRPCLBPickFirst(t *testing.T) {
+	balancer.Register(newLBBuilderWithPickFirst())
+	defer balancer.Register(newLBBuilder())
+
+	defer leakcheck.Check(t)
+
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	tss, cleanup, err := newLoadBalancer(3)
+	if err != nil {
+		t.Fatalf("failed to create new load balancer: %v", err)
+	}
+	defer cleanup()
+
+	beServers := []*lbpb.Server{{
+		IpAddress:        tss.beIPs[0],
+		Port:             int32(tss.bePorts[0]),
+		LoadBalanceToken: lbToken,
+	}, {
+		IpAddress:        tss.beIPs[1],
+		Port:             int32(tss.bePorts[1]),
+		LoadBalanceToken: lbToken,
+	}, {
+		IpAddress:        tss.beIPs[2],
+		Port:             int32(tss.bePorts[2]),
+		LoadBalanceToken: lbToken,
+	}}
+	portsToIndex := make(map[int]int)
+	for i := range beServers {
+		portsToIndex[tss.bePorts[i]] = i
+	}
+
+	creds := serverNameCheckCreds{
+		expected: beServerName,
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
+		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
+	if err != nil {
+		t.Fatalf("Failed to dial the backend: %v", err)
+	}
+	defer cc.Close()
+	testC := testpb.NewTestServiceClient(cc)
+
+	r.NewAddress([]resolver.Address{{
+		Addr:       tss.lbAddr,
+		Type:       resolver.GRPCLB,
+		ServerName: lbServerName,
+	}})
+
+	var p peer.Peer
+
+	portPicked1 := 0
+	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:2]}
+	for i := 0; i < 1000; i++ {
+		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
+			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		if portPicked1 == 0 {
+			portPicked1 = p.Addr.(*net.TCPAddr).Port
+			continue
+		}
+		if portPicked1 != p.Addr.(*net.TCPAddr).Port {
+			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked1, p.Addr.(*net.TCPAddr).Port)
+		}
+	}
+
+	portPicked2 := portPicked1
+	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[:1]}
+	for i := 0; i < 1000; i++ {
+		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
+			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		if portPicked2 == portPicked1 {
+			portPicked2 = p.Addr.(*net.TCPAddr).Port
+			continue
+		}
+		if portPicked2 != p.Addr.(*net.TCPAddr).Port {
+			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked2, p.Addr.(*net.TCPAddr).Port)
+		}
+	}
+
+	portPicked := portPicked2
+	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:]}
+	for i := 0; i < 1000; i++ {
+		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
+			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		if portPicked == portPicked2 {
+			portPicked = p.Addr.(*net.TCPAddr).Port
+			continue
+		}
+		if portPicked != p.Addr.(*net.TCPAddr).Port {
+			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked, p.Addr.(*net.TCPAddr).Port)
+		}
+	}
+}
+
 type failPreRPCCred struct{}
 
 func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
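The three RPC loops in the test share one assertion pattern: after each server-list push, allow at most one switch to a new port, then require every subsequent RPC to land on that same backend. A sketch of the pattern factored into a helper (checkStickiness is hypothetical, not part of this patch; it would live next to the test and reuse its imports):

	// checkStickiness issues 1000 RPCs. The first port observed that differs
	// from prevPort is treated as the newly picked backend; any change after
	// that fails the test. It returns the picked port for the next round.
	func checkStickiness(t *testing.T, testC testpb.TestServiceClient, prevPort int) int {
		var p peer.Peer
		picked := prevPort
		for i := 0; i < 1000; i++ {
			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
				t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
			}
			if picked == prevPort {
				picked = p.Addr.(*net.TCPAddr).Port
				continue
			}
			if got := p.Addr.(*net.TCPAddr).Port; picked != got {
				t.Fatalf("Different backends are picked for RPCs: %v vs %v", picked, got)
			}
		}
		return picked
	}

With round-robin backends the same loops would fail fast, since consecutive RPCs would rotate across every READY backend instead of sticking to one.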