diff --git a/balancer/xds/edsbalancer/balancergroup.go b/balancer/xds/edsbalancer/balancergroup.go index a921fab6..1ae931b8 100644 --- a/balancer/xds/edsbalancer/balancergroup.go +++ b/balancer/xds/edsbalancer/balancergroup.go @@ -23,6 +23,8 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/internal/wrr" + "google.golang.org/grpc/balancer/xds/internal" + "google.golang.org/grpc/balancer/xds/lrs" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" @@ -57,27 +59,29 @@ type balancerGroup struct { cc balancer.ClientConn mu sync.Mutex - idToBalancer map[string]balancer.Balancer - scToID map[balancer.SubConn]string + idToBalancer map[internal.Locality]balancer.Balancer + scToID map[balancer.SubConn]internal.Locality + loadStore lrs.Store pickerMu sync.Mutex // All balancer IDs exist as keys in this map. If an ID is not in map, it's // either removed or never added. - idToPickerState map[string]*pickerState + idToPickerState map[internal.Locality]*pickerState } -func newBalancerGroup(cc balancer.ClientConn) *balancerGroup { +func newBalancerGroup(cc balancer.ClientConn, loadStore lrs.Store) *balancerGroup { return &balancerGroup{ cc: cc, - scToID: make(map[balancer.SubConn]string), - idToBalancer: make(map[string]balancer.Balancer), - idToPickerState: make(map[string]*pickerState), + scToID: make(map[balancer.SubConn]internal.Locality), + idToBalancer: make(map[internal.Locality]balancer.Balancer), + idToPickerState: make(map[internal.Locality]*pickerState), + loadStore: loadStore, } } // add adds a balancer built by builder to the group, with given id and weight. -func (bg *balancerGroup) add(id string, weight uint32, builder balancer.Builder) { +func (bg *balancerGroup) add(id internal.Locality, weight uint32, builder balancer.Builder) { bg.mu.Lock() if _, ok := bg.idToBalancer[id]; ok { bg.mu.Unlock() @@ -109,7 +113,7 @@ func (bg *balancerGroup) add(id string, weight uint32, builder balancer.Builder) // // It also removes the picker generated from this balancer from the picker // group. It always results in a picker update. -func (bg *balancerGroup) remove(id string) { +func (bg *balancerGroup) remove(id internal.Locality) { bg.mu.Lock() // Close balancer. if b, ok := bg.idToBalancer[id]; ok { @@ -139,7 +143,7 @@ func (bg *balancerGroup) remove(id string) { // NOTE: It always results in a picker update now. This probably isn't // necessary. But it seems better to do the update because it's a change in the // picker (which is balancer's snapshot). -func (bg *balancerGroup) changeWeight(id string, newWeight uint32) { +func (bg *balancerGroup) changeWeight(id internal.Locality, newWeight uint32) { bg.pickerMu.Lock() defer bg.pickerMu.Unlock() pState, ok := bg.idToPickerState[id] @@ -181,7 +185,7 @@ func (bg *balancerGroup) handleSubConnStateChange(sc balancer.SubConn, state con } // Address change: forward to balancer. -func (bg *balancerGroup) handleResolvedAddrs(id string, addrs []resolver.Address) { +func (bg *balancerGroup) handleResolvedAddrs(id internal.Locality, addrs []resolver.Address) { bg.mu.Lock() b, ok := bg.idToBalancer[id] bg.mu.Unlock() @@ -210,7 +214,7 @@ func (bg *balancerGroup) handleResolvedAddrs(id string, addrs []resolver.Address // from map. Delete sc from the map only when state changes to Shutdown. Since // it's just forwarding the action, there's no need for a removeSubConn() // wrapper function. 
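Reviewer note (not part of the patch): the balancer-group maps above switch from composite string keys of the form "region-zone-subzone" to struct keys. Any Go struct whose fields are all comparable is itself a valid map key, so no formatting or parsing of a composite string is needed. A minimal standalone sketch of that property, with purely illustrative names:

package main

import "fmt"

// locality mirrors the shape of internal.Locality: all fields are comparable
// strings, so the struct value itself can be used as a map key.
type locality struct {
	Region, Zone, SubZone string
}

func main() {
	counts := make(map[locality]int)
	k := locality{Region: "us-central1", Zone: "us-central1-a", SubZone: "sz1"}
	counts[k]++ // no fmt.Sprintf("%s-%s-%s", ...) composite key needed
	// Two struct values with equal field values are the same key.
	fmt.Println(counts[locality{"us-central1", "us-central1-a", "sz1"}]) // prints 1
}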
-func (bg *balancerGroup) newSubConn(id string, addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { +func (bg *balancerGroup) newSubConn(id internal.Locality, addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { sc, err := bg.cc.NewSubConn(addrs, opts) if err != nil { return nil, err @@ -223,7 +227,7 @@ func (bg *balancerGroup) newSubConn(id string, addrs []resolver.Address, opts ba // updateBalancerState: create an aggregated picker and an aggregated // connectivity state, then forward to ClientConn. -func (bg *balancerGroup) updateBalancerState(id string, state connectivity.State, picker balancer.Picker) { +func (bg *balancerGroup) updateBalancerState(id internal.Locality, state connectivity.State, picker balancer.Picker) { grpclog.Infof("balancer group: update balancer state: %v, %v, %p", id, state, picker) bg.pickerMu.Lock() defer bg.pickerMu.Unlock() @@ -234,7 +238,7 @@ func (bg *balancerGroup) updateBalancerState(id string, state connectivity.State grpclog.Infof("balancer group: pickerState not found when update picker/state") return } - pickerSt.picker = picker + pickerSt.picker = newLoadReportPicker(picker, id, bg.loadStore) pickerSt.state = state bg.cc.UpdateBalancerState(buildPickerAndState(bg.idToPickerState)) } @@ -251,7 +255,7 @@ func (bg *balancerGroup) close() { bg.mu.Unlock() } -func buildPickerAndState(m map[string]*pickerState) (connectivity.State, balancer.Picker) { +func buildPickerAndState(m map[internal.Locality]*pickerState) (connectivity.State, balancer.Picker) { var readyN, connectingN int readyPickerWithWeights := make([]pickerState, 0, len(m)) for _, ps := range m { @@ -313,6 +317,36 @@ func (pg *pickerGroup) Pick(ctx context.Context, opts balancer.PickOptions) (con return p.Pick(ctx, opts) } +type loadReportPicker struct { + balancer.Picker + + id internal.Locality + loadStore lrs.Store +} + +func newLoadReportPicker(p balancer.Picker, id internal.Locality, loadStore lrs.Store) *loadReportPicker { + return &loadReportPicker{ + Picker: p, + id: id, + loadStore: loadStore, + } +} + +func (lrp *loadReportPicker) Pick(ctx context.Context, opts balancer.PickOptions) (conn balancer.SubConn, done func(balancer.DoneInfo), err error) { + conn, done, err = lrp.Picker.Pick(ctx, opts) + if lrp.loadStore != nil && err == nil { + lrp.loadStore.CallStarted(lrp.id) + td := done + done = func(info balancer.DoneInfo) { + lrp.loadStore.CallFinished(lrp.id, info.Err) + if td != nil { + td(info) + } + } + } + return +} + // balancerGroupCC implements the balancer.ClientConn API and get passed to each // sub-balancer. It contains the sub-balancer ID, so the parent balancer can // keep track of SubConn/pickers and the sub-balancers they belong to. @@ -320,7 +354,7 @@ func (pg *pickerGroup) Pick(ctx context.Context, opts balancer.PickOptions) (con // Some of the actions are forwarded to the parent ClientConn with no change. // Some are forward to balancer group with the sub-balancer ID. 
type balancerGroupCC struct { - id string + id internal.Locality group *balancerGroup } diff --git a/balancer/xds/edsbalancer/balancergroup_test.go b/balancer/xds/edsbalancer/balancergroup_test.go index 580d9069..27c86719 100644 --- a/balancer/xds/edsbalancer/balancergroup_test.go +++ b/balancer/xds/edsbalancer/balancergroup_test.go @@ -23,20 +23,21 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/xds/internal" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/resolver" ) var ( rrBuilder = balancer.Get(roundrobin.Name) - testBalancerIDs = []string{"b1", "b2", "b3"} + testBalancerIDs = []internal.Locality{{Region: "b1"}, {Region: "b2"}, {Region: "b3"}} testBackendAddrs = []resolver.Address{{Addr: "1.1.1.1:1"}, {Addr: "2.2.2.2:2"}, {Addr: "3.3.3.3:3"}, {Addr: "4.4.4.4:4"}} ) // 1 balancer, 1 backend -> 2 backends -> 1 backend. func TestBalancerGroup_OneRR_AddRemoveBackend(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add one balancer to group. bg.add(testBalancerIDs[0], 1, rrBuilder) @@ -95,7 +96,7 @@ func TestBalancerGroup_OneRR_AddRemoveBackend(t *testing.T) { // 2 balancers, each with 1 backend. func TestBalancerGroup_TwoRR_OneBackend(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add two balancers to group and send one resolved address to both // balancers. @@ -127,7 +128,7 @@ func TestBalancerGroup_TwoRR_OneBackend(t *testing.T) { // 2 balancers, each with more than 1 backends. func TestBalancerGroup_TwoRR_MoreBackends(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add two balancers to group and send one resolved address to both // balancers. @@ -223,7 +224,7 @@ func TestBalancerGroup_TwoRR_MoreBackends(t *testing.T) { // 2 balancers with different weights. func TestBalancerGroup_TwoRR_DifferentWeight_MoreBackends(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add two balancers to group and send two resolved addresses to both // balancers. @@ -261,7 +262,7 @@ func TestBalancerGroup_TwoRR_DifferentWeight_MoreBackends(t *testing.T) { // totally 3 balancers, add/remove balancer. func TestBalancerGroup_ThreeRR_RemoveBalancer(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add three balancers to group and send one resolved address to both // balancers. @@ -328,7 +329,7 @@ func TestBalancerGroup_ThreeRR_RemoveBalancer(t *testing.T) { // 2 balancers, change balancer weight. func TestBalancerGroup_TwoRR_ChangeWeight_MoreBackends(t *testing.T) { cc := newTestClientConn(t) - bg := newBalancerGroup(cc) + bg := newBalancerGroup(cc, nil) // Add two balancers to group and send two resolved addresses to both // balancers. @@ -374,3 +375,61 @@ func TestBalancerGroup_TwoRR_ChangeWeight_MoreBackends(t *testing.T) { t.Fatalf("want %v, got %v", want, err) } } + +func TestBalancerGroup_LoadReport(t *testing.T) { + testLoadStore := newTestLoadStore() + + cc := newTestClientConn(t) + bg := newBalancerGroup(cc, testLoadStore) + + backendToBalancerID := make(map[balancer.SubConn]internal.Locality) + + // Add two balancers to group and send two resolved addresses to both + // balancers. 
+ bg.add(testBalancerIDs[0], 2, rrBuilder) + bg.handleResolvedAddrs(testBalancerIDs[0], testBackendAddrs[0:2]) + sc1 := <-cc.newSubConnCh + sc2 := <-cc.newSubConnCh + backendToBalancerID[sc1] = testBalancerIDs[0] + backendToBalancerID[sc2] = testBalancerIDs[0] + + bg.add(testBalancerIDs[1], 1, rrBuilder) + bg.handleResolvedAddrs(testBalancerIDs[1], testBackendAddrs[2:4]) + sc3 := <-cc.newSubConnCh + sc4 := <-cc.newSubConnCh + backendToBalancerID[sc3] = testBalancerIDs[1] + backendToBalancerID[sc4] = testBalancerIDs[1] + + // Send state changes for both subconns. + bg.handleSubConnStateChange(sc1, connectivity.Connecting) + bg.handleSubConnStateChange(sc1, connectivity.Ready) + bg.handleSubConnStateChange(sc2, connectivity.Connecting) + bg.handleSubConnStateChange(sc2, connectivity.Ready) + bg.handleSubConnStateChange(sc3, connectivity.Connecting) + bg.handleSubConnStateChange(sc3, connectivity.Ready) + bg.handleSubConnStateChange(sc4, connectivity.Connecting) + bg.handleSubConnStateChange(sc4, connectivity.Ready) + + // Test roundrobin on the last picker. + p1 := <-cc.newPickerCh + var ( + wantStart []internal.Locality + wantEnd []internal.Locality + ) + for i := 0; i < 10; i++ { + sc, done, _ := p1.Pick(context.Background(), balancer.PickOptions{}) + locality := backendToBalancerID[sc] + wantStart = append(wantStart, locality) + if done != nil && sc != sc1 { + done(balancer.DoneInfo{}) + wantEnd = append(wantEnd, backendToBalancerID[sc]) + } + } + + if !reflect.DeepEqual(testLoadStore.callsStarted, wantStart) { + t.Fatalf("want started: %v, got: %v", testLoadStore.callsStarted, wantStart) + } + if !reflect.DeepEqual(testLoadStore.callsEnded, wantEnd) { + t.Fatalf("want ended: %v, got: %v", testLoadStore.callsEnded, wantEnd) + } +} diff --git a/balancer/xds/edsbalancer/edsbalancer.go b/balancer/xds/edsbalancer/edsbalancer.go index 3867b469..7a68129b 100644 --- a/balancer/xds/edsbalancer/edsbalancer.go +++ b/balancer/xds/edsbalancer/edsbalancer.go @@ -20,7 +20,6 @@ package edsbalancer import ( "context" "encoding/json" - "fmt" "net" "reflect" "strconv" @@ -28,6 +27,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/xds/internal" edspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/eds" percentpb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/type/percent" "google.golang.org/grpc/balancer/xds/lrs" @@ -54,7 +54,7 @@ type EDSBalancer struct { bg *balancerGroup subBalancerBuilder balancer.Builder - lidToConfig map[string]*localityConfig + lidToConfig map[internal.Locality]*localityConfig loadStore lrs.Store pickerMu sync.Mutex @@ -69,7 +69,7 @@ func NewXDSBalancer(cc balancer.ClientConn, loadStore lrs.Store) *EDSBalancer { ClientConn: cc, subBalancerBuilder: balancer.Get(roundrobin.Name), - lidToConfig: make(map[string]*localityConfig), + lidToConfig: make(map[internal.Locality]*localityConfig), loadStore: loadStore, } // Don't start balancer group here. Start it when handling the first EDS @@ -173,7 +173,7 @@ func (xdsB *EDSBalancer) HandleEDSResponse(edsResp *edspb.ClusterLoadAssignment) // Create balancer group if it's never created (this is the first EDS // response). 
if xdsB.bg == nil { - xdsB.bg = newBalancerGroup(xdsB) + xdsB.bg = newBalancerGroup(xdsB, xdsB.loadStore) } // TODO: Unhandled fields from EDS response: @@ -189,7 +189,7 @@ func (xdsB *EDSBalancer) HandleEDSResponse(edsResp *edspb.ClusterLoadAssignment) // newLocalitiesSet contains all names of localities in the new EDS response. // It's used to delete localities that are removed in the new EDS response. - newLocalitiesSet := make(map[string]struct{}) + newLocalitiesSet := make(map[internal.Locality]struct{}) for _, locality := range edsResp.Endpoints { // One balancer for each locality. @@ -198,7 +198,11 @@ func (xdsB *EDSBalancer) HandleEDSResponse(edsResp *edspb.ClusterLoadAssignment) grpclog.Warningf("xds: received LocalityLbEndpoints with <nil> Locality") continue } - lid := fmt.Sprintf("%s-%s-%s", l.Region, l.Zone, l.SubZone) + lid := internal.Locality{ + Region: l.Region, + Zone: l.Zone, + SubZone: l.SubZone, + } newLocalitiesSet[lid] = struct{}{} newWeight := locality.GetLoadBalancingWeight().GetValue() diff --git a/balancer/xds/edsbalancer/edsbalancer_test.go b/balancer/xds/edsbalancer/edsbalancer_test.go index 7d907483..2fbcbd32 100644 --- a/balancer/xds/edsbalancer/edsbalancer_test.go +++ b/balancer/xds/edsbalancer/edsbalancer_test.go @@ -27,6 +27,7 @@ import ( typespb "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/xds/internal" addresspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/address" basepb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/base" edspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/eds" @@ -532,3 +533,55 @@ func TestDropPicker(t *testing.T) { }) } } + +func TestEDS_LoadReport(t *testing.T) { + testLoadStore := newTestLoadStore() + + cc := newTestClientConn(t) + edsb := NewXDSBalancer(cc, testLoadStore) + + backendToBalancerID := make(map[balancer.SubConn]internal.Locality) + + // Two localities, each with one backend. + clab1 := newClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.addLocality(testSubZones[0], 1, testEndpointAddrs[:1]) + clab1.addLocality(testSubZones[1], 1, testEndpointAddrs[1:2]) + edsb.HandleEDSResponse(clab1.build()) + + sc1 := <-cc.newSubConnCh + edsb.HandleSubConnStateChange(sc1, connectivity.Connecting) + edsb.HandleSubConnStateChange(sc1, connectivity.Ready) + backendToBalancerID[sc1] = internal.Locality{ + SubZone: testSubZones[0], + } + sc2 := <-cc.newSubConnCh + edsb.HandleSubConnStateChange(sc2, connectivity.Connecting) + edsb.HandleSubConnStateChange(sc2, connectivity.Ready) + backendToBalancerID[sc2] = internal.Locality{ + SubZone: testSubZones[1], + } + + // Test roundrobin with two subconns.
+ p1 := <-cc.newPickerCh + var ( + wantStart []internal.Locality + wantEnd []internal.Locality + ) + + for i := 0; i < 10; i++ { + sc, done, _ := p1.Pick(context.Background(), balancer.PickOptions{}) + locality := backendToBalancerID[sc] + wantStart = append(wantStart, locality) + if done != nil && sc != sc1 { + done(balancer.DoneInfo{}) + wantEnd = append(wantEnd, backendToBalancerID[sc]) + } + } + + if !reflect.DeepEqual(testLoadStore.callsStarted, wantStart) { + t.Fatalf("want started: %v, got: %v", testLoadStore.callsStarted, wantStart) + } + if !reflect.DeepEqual(testLoadStore.callsEnded, wantEnd) { + t.Fatalf("want ended: %v, got: %v", testLoadStore.callsEnded, wantEnd) + } +} diff --git a/balancer/xds/edsbalancer/test_util_test.go b/balancer/xds/edsbalancer/test_util_test.go index ed0415d4..559a431e 100644 --- a/balancer/xds/edsbalancer/test_util_test.go +++ b/balancer/xds/edsbalancer/test_util_test.go @@ -17,10 +17,13 @@ package edsbalancer import ( + "context" "fmt" "testing" + "google.golang.org/grpc" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/xds/internal" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/resolver" ) @@ -75,7 +78,7 @@ func (tcc *testClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubCon sc := testSubConns[tcc.subConnIdx] tcc.subConnIdx++ - tcc.t.Logf("testClientConn: NewSubConn(%v, %+v) => %p", a, o, sc) + tcc.t.Logf("testClientConn: NewSubConn(%v, %+v) => %s", a, o, sc) select { case tcc.newSubConnAddrsCh <- a: default: @@ -120,6 +123,31 @@ func (tcc *testClientConn) Target() string { panic("not implemented") } +type testLoadStore struct { + callsStarted []internal.Locality + callsEnded []internal.Locality +} + +func newTestLoadStore() *testLoadStore { + return &testLoadStore{} +} + +func (*testLoadStore) CallDropped(category string) { + panic("not implemented") +} + +func (tls *testLoadStore) CallStarted(l internal.Locality) { + tls.callsStarted = append(tls.callsStarted, l) +} + +func (tls *testLoadStore) CallFinished(l internal.Locality, err error) { + tls.callsEnded = append(tls.callsEnded, l) +} + +func (*testLoadStore) ReportTo(ctx context.Context, cc *grpc.ClientConn) { + panic("not implemented") +} + // isRoundRobin checks whether f's return value is roundrobin of elements from // want. But it doesn't check for the order. Note that want can contain // duplicate items, which makes it weight-round-robin. diff --git a/balancer/xds/internal/internal.go b/balancer/xds/internal/internal.go new file mode 100644 index 00000000..a0d209cc --- /dev/null +++ b/balancer/xds/internal/internal.go @@ -0,0 +1,50 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package internal + +import ( + "fmt" + + basepb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/base" +) + +// Locality is xds.Locality without XXX fields, so it can be used as map +// keys. +// +// xds.Locality cannot be map keys because one of the XXX fields is a slice. 
+// +// This struct should only be used as map keys. Use the proto message directly +// in all other places. +type Locality struct { + Region string + Zone string + SubZone string +} + +func (lamk Locality) String() string { + return fmt.Sprintf("%s-%s-%s", lamk.Region, lamk.Zone, lamk.SubZone) +} + +// ToProto converts Locality to the proto representation. +func (lamk Locality) ToProto() *basepb.Locality { + return &basepb.Locality{ + Region: lamk.Region, + Zone: lamk.Zone, + SubZone: lamk.SubZone, + } +} diff --git a/balancer/xds/internal/internal_test.go b/balancer/xds/internal/internal_test.go new file mode 100644 index 00000000..89c99f76 --- /dev/null +++ b/balancer/xds/internal/internal_test.go @@ -0,0 +1,51 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package internal + +import ( + "reflect" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + basepb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/base" +) + +// A reflection-based test to make sure internal.Locality contains all the +// fields (except for XXX_) from the proto message. +func TestLocalityMatchProtoMessage(t *testing.T) { + want1 := make(map[string]string) + for ty, i := reflect.TypeOf(Locality{}), 0; i < ty.NumField(); i++ { + f := ty.Field(i) + want1[f.Name] = f.Type.Name() + } + + const ignorePrefix = "XXX_" + want2 := make(map[string]string) + for ty, i := reflect.TypeOf(basepb.Locality{}), 0; i < ty.NumField(); i++ { + f := ty.Field(i) + if strings.HasPrefix(f.Name, ignorePrefix) { + continue + } + want2[f.Name] = f.Type.Name() + } + + if !reflect.DeepEqual(want1, want2) { + t.Fatalf("internal type and proto message have different fields:\n%+v", cmp.Diff(want1, want2)) + } +} diff --git a/balancer/xds/lrs/lrs.go b/balancer/xds/lrs/lrs.go index 4f336d8d..74f80d5d 100644 --- a/balancer/xds/lrs/lrs.go +++ b/balancer/xds/lrs/lrs.go @@ -35,13 +35,32 @@ import ( "google.golang.org/grpc/internal/backoff" ) +const negativeOneUInt64 = ^uint64(0) + // Store defines the interface for a load store. It keeps loads and can report // them to a server when requested. type Store interface { CallDropped(category string) + CallStarted(l internal.Locality) + CallFinished(l internal.Locality, err error) ReportTo(ctx context.Context, cc *grpc.ClientConn) } +type rpcCountData struct { + // Only atomic accesses are allowed for the fields. + succeeded *uint64 + errored *uint64 + inProgress *uint64 +} + +func newRPCCountData() *rpcCountData { + return &rpcCountData{ + succeeded: new(uint64), + errored: new(uint64), + inProgress: new(uint64), + } +} + // lrsStore collects loads from xds balancer, and periodically sends load to the // server. type lrsStore struct { @@ -50,7 +69,8 @@ type lrsStore struct { backoff backoff.Strategy lastReported time.Time - drops sync.Map // map[string]*uint64 + drops sync.Map // map[string]*uint64 + localityRPCCount sync.Map // map[internal.Locality]*rpcCountData } // NewStore creates a store for load reports.
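Reviewer note (not part of the patch): negativeOneUInt64 works because adding ^uint64(0) to a uint64 wraps around and subtracts one, which is the decrement idiom documented for sync/atomic.AddUint64. A standalone sketch of the counter pattern rpcCountData relies on (a sync.Map of *uint64 counters updated with atomic adds), with purely illustrative names:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var inProgress sync.Map // map[string]*uint64; a plain string key stands in for the locality here

	inc := func(key string) {
		p, ok := inProgress.Load(key)
		if !ok {
			// LoadOrStore returns the existing counter if another goroutine stored one first.
			p, _ = inProgress.LoadOrStore(key, new(uint64))
		}
		atomic.AddUint64(p.(*uint64), 1)
	}
	dec := func(key string) {
		if p, ok := inProgress.Load(key); ok {
			atomic.AddUint64(p.(*uint64), ^uint64(0)) // adding ^uint64(0) decrements by one
		}
	}

	inc("locality-a")
	inc("locality-a")
	dec("locality-a")
	p, _ := inProgress.Load("locality-a")
	fmt.Println(atomic.LoadUint64(p.(*uint64))) // prints 1
}

The same wraparound trick is what CallFinished below uses to decrement the in-progress count.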
@@ -86,14 +106,35 @@ func (ls *lrsStore) CallDropped(category string) { atomic.AddUint64(p.(*uint64), 1) } -// TODO: add query counts -// callStarted(l locality) -// callFinished(l locality, err error) +func (ls *lrsStore) CallStarted(l internal.Locality) { + p, ok := ls.localityRPCCount.Load(l) + if !ok { + tp := newRPCCountData() + p, _ = ls.localityRPCCount.LoadOrStore(l, tp) + } + atomic.AddUint64(p.(*rpcCountData).inProgress, 1) +} + +func (ls *lrsStore) CallFinished(l internal.Locality, err error) { + p, ok := ls.localityRPCCount.Load(l) + if !ok { + // The map is never cleared, only values in the map are reset. So the + // case where the entry for call-finish is not found should never happen. + return + } + atomic.AddUint64(p.(*rpcCountData).inProgress, negativeOneUInt64) // atomic.Add(x, -1) + if err == nil { + atomic.AddUint64(p.(*rpcCountData).succeeded, 1) + } else { + atomic.AddUint64(p.(*rpcCountData).errored, 1) + } +} func (ls *lrsStore) buildStats() []*loadreportpb.ClusterStats { var ( - totalDropped uint64 - droppedReqs []*loadreportpb.ClusterStats_DroppedRequests + totalDropped uint64 + droppedReqs []*loadreportpb.ClusterStats_DroppedRequests + localityStats []*loadreportpb.UpstreamLocalityStats ) ls.drops.Range(func(category, countP interface{}) bool { tempCount := atomic.SwapUint64(countP.(*uint64), 0) @@ -107,6 +148,31 @@ func (ls *lrsStore) buildStats() []*loadreportpb.ClusterStats { }) return true }) + ls.localityRPCCount.Range(func(locality, countP interface{}) bool { + tempLocality := locality.(internal.Locality) + tempCount := countP.(*rpcCountData) + + tempSucceeded := atomic.SwapUint64(tempCount.succeeded, 0) + tempInProgress := atomic.LoadUint64(tempCount.inProgress) // InProgress count is not cleared when reading. + tempErrored := atomic.SwapUint64(tempCount.errored, 0) + if tempSucceeded == 0 && tempInProgress == 0 && tempErrored == 0 { + return true + } + + localityStats = append(localityStats, &loadreportpb.UpstreamLocalityStats{ + Locality: &basepb.Locality{ + Region: tempLocality.Region, + Zone: tempLocality.Zone, + SubZone: tempLocality.SubZone, + }, + TotalSuccessfulRequests: tempSucceeded, + TotalRequestsInProgress: tempInProgress, + TotalErrorRequests: tempErrored, + LoadMetricStats: nil, // TODO: populate for user loads. + UpstreamEndpointStats: nil, // TODO: populate for per endpoint loads. + }) + return true + }) dur := time.Since(ls.lastReported) ls.lastReported = time.Now() @@ -114,7 +180,7 @@ func (ls *lrsStore) buildStats() []*loadreportpb.ClusterStats { var ret []*loadreportpb.ClusterStats ret = append(ret, &loadreportpb.ClusterStats{ ClusterName: ls.serviceName, - UpstreamLocalityStats: nil, // TODO: populate this to support per locality loads.
+ UpstreamLocalityStats: localityStats, TotalDroppedRequests: totalDropped, DroppedRequests: droppedReqs, diff --git a/balancer/xds/lrs/lrs_test.go b/balancer/xds/lrs/lrs_test.go index 5bb8c00c..925f52b4 100644 --- a/balancer/xds/lrs/lrs_test.go +++ b/balancer/xds/lrs/lrs_test.go @@ -19,11 +19,13 @@ package lrs import ( "context" + "fmt" "io" "net" "reflect" "sort" "sync" + "sync/atomic" "testing" "time" @@ -43,7 +45,23 @@ import ( const testService = "grpc.service.test" -var dropCategories = []string{"drop_for_real", "drop_for_fun"} +var ( + dropCategories = []string{"drop_for_real", "drop_for_fun"} + localities = []internal.Locality{{Region: "a"}, {Region: "b"}} + errTest = fmt.Errorf("test error") +) + +func newRPCCountDataWithInitData(succeeded, errored, inprogress uint64) *rpcCountData { + return &rpcCountData{ + succeeded: &succeeded, + errored: &errored, + inProgress: &inprogress, + } +} + +func (rcd *rpcCountData) Equal(b *rpcCountData) bool { + return *rcd.inProgress == *b.inProgress && *rcd.errored == *b.errored && *rcd.succeeded == *b.succeeded +} // equalClusterStats sorts requests and clears report interval before comparing. func equalClusterStats(a, b []*loadreportpb.ClusterStats) bool { @@ -51,31 +69,37 @@ func equalClusterStats(a, b []*loadreportpb.ClusterStats) bool { sort.Slice(s.DroppedRequests, func(i, j int) bool { return s.DroppedRequests[i].Category < s.DroppedRequests[j].Category }) + sort.Slice(s.UpstreamLocalityStats, func(i, j int) bool { + return s.UpstreamLocalityStats[i].Locality.String() < s.UpstreamLocalityStats[j].Locality.String() + }) s.LoadReportInterval = nil } for _, s := range b { sort.Slice(s.DroppedRequests, func(i, j int) bool { return s.DroppedRequests[i].Category < s.DroppedRequests[j].Category }) + sort.Slice(s.UpstreamLocalityStats, func(i, j int) bool { + return s.UpstreamLocalityStats[i].Locality.String() < s.UpstreamLocalityStats[j].Locality.String() + }) s.LoadReportInterval = nil } return reflect.DeepEqual(a, b) } -func Test_lrsStore_buildStats(t *testing.T) { +func Test_lrsStore_buildStats_drops(t *testing.T) { tests := []struct { name string drops []map[string]uint64 }{ { - name: "one report", + name: "one drop report", drops: []map[string]uint64{{ dropCategories[0]: 31, dropCategories[1]: 41, }}, }, { - name: "two reports", + name: "two drop reports", drops: []map[string]uint64{{ dropCategories[0]: 31, dropCategories[1]: 41, @@ -84,6 +108,16 @@ func Test_lrsStore_buildStats(t *testing.T) { dropCategories[1]: 26, }}, }, + { + name: "no empty report", + drops: []map[string]uint64{{ + dropCategories[0]: 31, + dropCategories[1]: 41, + }, { + dropCategories[0]: 0, // This shouldn't cause an empty report for category[0].
+ dropCategories[1]: 26, + }}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -95,6 +129,9 @@ func Test_lrsStore_buildStats(t *testing.T) { droppedReqs []*loadreportpb.ClusterStats_DroppedRequests ) for cat, count := range ds { + if count == 0 { + continue + } totalDropped += count droppedReqs = append(droppedReqs, &loadreportpb.ClusterStats_DroppedRequests{ Category: cat, @@ -130,11 +167,141 @@ func Test_lrsStore_buildStats(t *testing.T) { } } +func Test_lrsStore_buildStats_rpcCounts(t *testing.T) { + tests := []struct { + name string + rpcs []map[internal.Locality]struct { + start, success, failure uint64 + } + }{ + { + name: "one rpcCount report", + rpcs: []map[internal.Locality]struct { + start, success, failure uint64 + }{{ + localities[0]: {8, 3, 1}, + }}, + }, + { + name: "two localities rpcCount reports", + rpcs: []map[internal.Locality]struct { + start, success, failure uint64 + }{{ + localities[0]: {8, 3, 1}, + localities[1]: {15, 1, 5}, + }}, + }, + { + name: "two rpcCount reports", + rpcs: []map[internal.Locality]struct { + start, success, failure uint64 + }{{ + localities[0]: {8, 3, 1}, + localities[1]: {15, 1, 5}, + }, { + localities[0]: {8, 3, 1}, + }, { + localities[1]: {15, 1, 5}, + }}, + }, + { + name: "no empty report", + rpcs: []map[internal.Locality]struct { + start, success, failure uint64 + }{{ + localities[0]: {4, 3, 1}, + localities[1]: {7, 1, 5}, + }, { + localities[0]: {0, 0, 0}, // This shouldn't cause an empty report for locality[0]. + localities[1]: {1, 1, 0}, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ls := NewStore(testService).(*lrsStore) + + // InProgress count doesn't get cleared at each buildStats, so keep + // it to carry over. + inProgressCounts := make(map[internal.Locality]uint64) + + for _, counts := range tt.rpcs { + var upstreamLocalityStats []*loadreportpb.UpstreamLocalityStats + + for l, count := range counts { + tempInProgress := count.start - count.success - count.failure + inProgressCounts[l] + inProgressCounts[l] = tempInProgress + if count.success == 0 && tempInProgress == 0 && count.failure == 0 { + continue + } + upstreamLocalityStats = append(upstreamLocalityStats, &loadreportpb.UpstreamLocalityStats{ + Locality: l.ToProto(), + TotalSuccessfulRequests: count.success, + TotalRequestsInProgress: tempInProgress, + TotalErrorRequests: count.failure, + }) + } + // InProgress count doesn't get cleared at each buildStats, and + // needs to be carried over to the next result.
+ for l, c := range inProgressCounts { + if _, ok := counts[l]; !ok { + upstreamLocalityStats = append(upstreamLocalityStats, &loadreportpb.UpstreamLocalityStats{ + Locality: l.ToProto(), + TotalRequestsInProgress: c, + }) + } + } + want := []*loadreportpb.ClusterStats{ + { + ClusterName: testService, + UpstreamLocalityStats: upstreamLocalityStats, + }, + } + + var wg sync.WaitGroup + for l, count := range counts { + for i := 0; i < int(count.success); i++ { + wg.Add(1) + go func(i int, l internal.Locality) { + ls.CallStarted(l) + ls.CallFinished(l, nil) + wg.Done() + }(i, l) + } + for i := 0; i < int(count.failure); i++ { + wg.Add(1) + go func(i int, l internal.Locality) { + ls.CallStarted(l) + ls.CallFinished(l, errTest) + wg.Done() + }(i, l) + } + for i := 0; i < int(count.start-count.success-count.failure); i++ { + wg.Add(1) + go func(i int, l internal.Locality) { + ls.CallStarted(l) + wg.Done() + }(i, l) + } + } + wg.Wait() + + if got := ls.buildStats(); !equalClusterStats(got, want) { + t.Errorf("lrsStore.buildStats() = %v, want %v", got, want) + t.Errorf("%s", cmp.Diff(got, want)) + } + } + }) + } +} + type lrsServer struct { - mu sync.Mutex - dropTotal uint64 - drops map[string]uint64 reportingInterval *durationpb.Duration + + mu sync.Mutex + dropTotal uint64 + drops map[string]uint64 + rpcs map[internal.Locality]*rpcCountData } func (lrss *lrsServer) StreamLoadStats(stream lrsgrpc.LoadReportingService_StreamLoadStatsServer) error { @@ -176,6 +343,21 @@ func (lrss *lrsServer) StreamLoadStats(stream lrsgrpc.LoadReportingService_Strea for _, d := range stats.DroppedRequests { lrss.drops[d.Category] += d.DroppedCount } + for _, ss := range stats.UpstreamLocalityStats { + l := internal.Locality{ + Region: ss.Locality.Region, + Zone: ss.Locality.Zone, + SubZone: ss.Locality.SubZone, + } + counts, ok := lrss.rpcs[l] + if !ok { + counts = newRPCCountDataWithInitData(0, 0, 0) + lrss.rpcs[l] = counts + } + atomic.AddUint64(counts.succeeded, ss.TotalSuccessfulRequests) + atomic.StoreUint64(counts.inProgress, ss.TotalRequestsInProgress) + atomic.AddUint64(counts.errored, ss.TotalErrorRequests) + } lrss.mu.Unlock() } } @@ -187,8 +369,9 @@ func setupServer(t *testing.T, reportingInterval *durationpb.Duration) (addr str } svr := grpc.NewServer() lrss = &lrsServer{ - drops: make(map[string]uint64), reportingInterval: reportingInterval, + drops: make(map[string]uint64), + rpcs: make(map[internal.Locality]*rpcCountData), } lrsgrpc.RegisterLoadReportingServiceServer(svr, lrss) go svr.Serve(lis) @@ -220,16 +403,40 @@ func Test_lrsStore_ReportTo(t *testing.T) { }() drops := map[string]uint64{ - dropCategories[0]: 31, - dropCategories[1]: 41, + dropCategories[0]: 13, + dropCategories[1]: 14, } - for c, d := range drops { for i := 0; i < int(d); i++ { ls.CallDropped(c) time.Sleep(time.Nanosecond * intervalNano / 10) } } + + rpcs := map[internal.Locality]*rpcCountData{ + localities[0]: newRPCCountDataWithInitData(3, 1, 4), + localities[1]: newRPCCountDataWithInitData(1, 5, 9), + } + for l, count := range rpcs { + for i := 0; i < int(*count.succeeded); i++ { + go func(i int, l internal.Locality) { + ls.CallStarted(l) + ls.CallFinished(l, nil) + }(i, l) + } + for i := 0; i < int(*count.inProgress); i++ { + go func(i int, l internal.Locality) { + ls.CallStarted(l) + }(i, l) + } + for i := 0; i < int(*count.errored); i++ { + go func(i int, l internal.Locality) { + ls.CallStarted(l) + ls.CallFinished(l, errTest) + }(i, l) + } + } + time.Sleep(time.Nanosecond * intervalNano * 2) cancel() <-done @@ -239,4 
+446,7 @@ func Test_lrsStore_ReportTo(t *testing.T) { if !cmp.Equal(lrss.drops, drops) { t.Errorf("different: %v", cmp.Diff(lrss.drops, drops)) } + if !cmp.Equal(lrss.rpcs, rpcs) { + t.Errorf("different: %v", cmp.Diff(lrss.rpcs, rpcs)) + } }
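Reviewer note (not part of the patch): the mechanism tying the pieces above together is the chaining of the per-RPC done callback. loadReportPicker records a call start when a SubConn is picked and records the finish (success or error, decided by DoneInfo.Err) when the RPC completes, while still invoking the wrapped picker's own callback. A standalone sketch of that chaining pattern, with purely illustrative names and without the real balancer API:

package main

import (
	"errors"
	"fmt"
)

// doneFunc mirrors the shape of a picker's per-RPC completion callback.
type doneFunc func(err error)

// wrapDone mirrors what loadReportPicker.Pick does: record the start
// immediately, then return a callback that records the finish before invoking
// the inner picker's callback (if any).
func wrapDone(start func(), finish func(error), inner doneFunc) doneFunc {
	start()
	return func(err error) {
		finish(err)
		if inner != nil {
			inner(err)
		}
	}
}

func main() {
	var started, succeeded, errored int
	record := func(err error) {
		if err == nil {
			succeeded++
		} else {
			errored++
		}
	}

	done := wrapDone(func() { started++ }, record, func(err error) { fmt.Println("inner done:", err) })
	done(nil) // a successful RPC

	done = wrapDone(func() { started++ }, record, nil) // inner picker supplied no callback
	done(errors.New("rpc failed"))

	fmt.Println(started, succeeded, errored) // prints 2 1 1
}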