From a8b5bd3c39ac82177c7bad36e1dd695096cd0ef5 Mon Sep 17 00:00:00 2001
From: Menghan Li <menghanl@google.com>
Date: Tue, 23 Apr 2019 13:48:02 -0700
Subject: [PATCH] xds: wrr with random (#2745)

---
 balancer/internal/wrr/random.go           | 59 ++++++++++++++
 balancer/internal/wrr/wrr.go              | 28 +++++++
 balancer/internal/wrr/wrr_test.go         | 99 +++++++++++++++++++++++
 balancer/xds/edsbalancer/balancergroup.go | 33 ++++----
 balancer/xds/edsbalancer/edsbalancer.go   | 11 +--
 balancer/xds/edsbalancer/util.go          | 30 ++-----
 balancer/xds/edsbalancer/util_test.go     | 49 +++++++++++
 go.mod                                    |  1 +
 go.sum                                    |  2 +
 9 files changed, 264 insertions(+), 48 deletions(-)
 create mode 100644 balancer/internal/wrr/random.go
 create mode 100644 balancer/internal/wrr/wrr.go
 create mode 100644 balancer/internal/wrr/wrr_test.go

diff --git a/balancer/internal/wrr/random.go b/balancer/internal/wrr/random.go
new file mode 100644
index 00000000..fe345f78
--- /dev/null
+++ b/balancer/internal/wrr/random.go
@@ -0,0 +1,59 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package wrr
+
+import "google.golang.org/grpc/internal/grpcrand"
+
+// weightedItem is a wrapped weighted item that is used to implement weighted random algorithm.
+type weightedItem struct {
+	Item   interface{}
+	Weight int64
+}
+
+// randomWRR is a struct that contains weighted items implement weighted random algorithm.
+type randomWRR struct {
+	items        []*weightedItem
+	sumOfWeights int64
+}
+
+// NewRandom creates a new WRR with random.
+func NewRandom() WRR {
+	return &randomWRR{}
+}
+
+func (rw *randomWRR) Next() (item interface{}) {
+	if rw.sumOfWeights == 0 {
+		return nil
+	}
+	// Random number in [0, sum).
+	randomWeight := grpcrand.Int63n(rw.sumOfWeights)
+	for _, item := range rw.items {
+		randomWeight = randomWeight - item.Weight
+		if randomWeight < 0 {
+			return item.Item
+		}
+	}
+
+	return rw.items[len(rw.items)-1].Item
+}
+
+func (rw *randomWRR) Add(item interface{}, weight int64) {
+	rItem := &weightedItem{Item: item, Weight: weight}
+	rw.items = append(rw.items, rItem)
+	rw.sumOfWeights += weight
+}
diff --git a/balancer/internal/wrr/wrr.go b/balancer/internal/wrr/wrr.go
new file mode 100644
index 00000000..7b9cdc73
--- /dev/null
+++ b/balancer/internal/wrr/wrr.go
@@ -0,0 +1,28 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package wrr
+
+// WRR defines an interface that implements weighted round robin.
+type WRR interface {
+	// Add adds an item with weight to the WRR set.
+	Add(item interface{}, weight int64)
+	// Next returns the next picked item.
+	//
+	// Next needs to be thread safe.
+	Next() interface{}
+}
diff --git a/balancer/internal/wrr/wrr_test.go b/balancer/internal/wrr/wrr_test.go
new file mode 100644
index 00000000..03c2b900
--- /dev/null
+++ b/balancer/internal/wrr/wrr_test.go
@@ -0,0 +1,99 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package wrr
+
+import (
+	"errors"
+	"math"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+const iterCount = 10000
+
+func equalApproximate(a, b float64) error {
+	opt := cmp.Comparer(func(x, y float64) bool {
+		delta := math.Abs(x - y)
+		mean := math.Abs(x+y) / 2.0
+		return delta/mean < 0.05
+	})
+	if !cmp.Equal(a, b, opt) {
+		return errors.New(cmp.Diff(a, b))
+	}
+	return nil
+}
+
+func testWRRNext(t *testing.T, newWRR func() WRR) {
+	tests := []struct {
+		name    string
+		weights []int64
+	}{
+		{
+			name:    "1-1-1",
+			weights: []int64{1, 1, 1},
+		},
+		{
+			name:    "1-2-3",
+			weights: []int64{1, 2, 3},
+		},
+		{
+			name:    "5-3-2",
+			weights: []int64{5, 3, 2},
+		},
+		{
+			name:    "17-23-37",
+			weights: []int64{17, 23, 37},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var sumOfWeights int64
+
+			w := newWRR()
+			for i, weight := range tt.weights {
+				w.Add(i, weight)
+				sumOfWeights += weight
+			}
+
+			results := make(map[int]int)
+			for i := 0; i < iterCount; i++ {
+				results[w.Next().(int)]++
+			}
+
+			wantRatio := make([]float64, len(tt.weights))
+			for i, weight := range tt.weights {
+				wantRatio[i] = float64(weight) / float64(sumOfWeights)
+			}
+			gotRatio := make([]float64, len(tt.weights))
+			for i, count := range results {
+				gotRatio[i] = float64(count) / iterCount
+			}
+
+			for i := range wantRatio {
+				if err := equalApproximate(gotRatio[i], wantRatio[i]); err != nil {
+					t.Errorf("%v not equal %v", i, err)
+				}
+			}
+		})
+	}
+}
+
+func TestRandomWRRNext(t *testing.T) {
+	testWRRNext(t, NewRandom)
+}
diff --git a/balancer/xds/edsbalancer/balancergroup.go b/balancer/xds/edsbalancer/balancergroup.go
index 77b185c8..5367dee8 100644
--- a/balancer/xds/edsbalancer/balancergroup.go
+++ b/balancer/xds/edsbalancer/balancergroup.go
@@ -24,6 +24,7 @@ import (
 
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/balancer/base"
+	"google.golang.org/grpc/balancer/internal/wrr"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/resolver"
@@ -279,13 +280,12 @@ func buildPickerAndState(m map[string]*pickerState) (connectivity.State, balance
 	return aggregatedState, newPickerGroup(readyPickerWithWeights)
 }
 
-type pickerGroup struct {
-	readyPickerWithWeights []pickerState
-	length                 int
+// RandomWRR constructor, to be modified in tests.
+var newRandomWRR = wrr.NewRandom
 
-	mu    sync.Mutex
-	idx   int    // The index of the picker that will be picked
-	count uint32 // The number of times the current picker has been picked.
+type pickerGroup struct {
+	length int
+	w      wrr.WRR
 }
 
 // newPickerGroup takes pickers with weights, and group them into one picker.
@@ -296,9 +296,14 @@ type pickerGroup struct {
 // TODO: (bg) confirm this is the expected behavior: non-ready balancers should
 // be ignored when picking. Only ready balancers are picked.
 func newPickerGroup(readyPickerWithWeights []pickerState) *pickerGroup {
+	w := newRandomWRR()
+	for _, ps := range readyPickerWithWeights {
+		w.Add(ps.picker, int64(ps.weight))
+	}
+
 	return &pickerGroup{
-		readyPickerWithWeights: readyPickerWithWeights,
-		length:                 len(readyPickerWithWeights),
+		length: len(readyPickerWithWeights),
+		w:      w,
 	}
 }
 
@@ -306,17 +311,7 @@ func (pg *pickerGroup) Pick(ctx context.Context, opts balancer.PickOptions) (con
 	if pg.length <= 0 {
 		return nil, nil, balancer.ErrNoSubConnAvailable
 	}
-	// TODO: the WRR algorithm needs a design.
-	// MAYBE: move WRR implmentation to util.go as a separate struct.
-	pg.mu.Lock()
-	pickerSt := pg.readyPickerWithWeights[pg.idx]
-	p := pickerSt.picker
-	pg.count++
-	if pg.count >= pickerSt.weight {
-		pg.idx = (pg.idx + 1) % pg.length
-		pg.count = 0
-	}
-	pg.mu.Unlock()
+	p := pg.w.Next().(balancer.Picker)
 	return p.Pick(ctx, opts)
 }
 
diff --git a/balancer/xds/edsbalancer/edsbalancer.go b/balancer/xds/edsbalancer/edsbalancer.go
index 67e39260..789d6d86 100644
--- a/balancer/xds/edsbalancer/edsbalancer.go
+++ b/balancer/xds/edsbalancer/edsbalancer.go
@@ -292,13 +292,10 @@ func newDropPicker(p balancer.Picker, drops []*dropper) *dropPicker {
 func (d *dropPicker) Pick(ctx context.Context, opts balancer.PickOptions) (conn balancer.SubConn, done func(balancer.DoneInfo), err error) {
 	var drop bool
 	for _, dp := range d.drops {
-		// It's necessary to call drop on all droppers if the droppers are
-		// stateful. For example, if the second drop only drops 1/2, and only
-		// drops even number picks, we need to call it's drop() even if the
-		// first dropper already returned true.
-		//
-		// It won't be necessary if droppers are stateless, like toss a coin.
-		drop = drop || dp.drop()
+		if dp.drop() {
+			drop = true
+			break
+		}
 	}
 	if drop {
 		return nil, nil, status.Errorf(codes.Unavailable, "RPC is dropped")
diff --git a/balancer/xds/edsbalancer/util.go b/balancer/xds/edsbalancer/util.go
index 0b1a397f..dfae1b51 100644
--- a/balancer/xds/edsbalancer/util.go
+++ b/balancer/xds/edsbalancer/util.go
@@ -18,41 +18,27 @@
 
 package edsbalancer
 
-import (
-	"sync"
-)
+import "google.golang.org/grpc/balancer/internal/wrr"
 
 type dropper struct {
 	// Drop rate will be numerator/denominator.
 	numerator   uint32
 	denominator uint32
-
-	mu sync.Mutex
-	i  uint32
+	w           wrr.WRR
 }
 
 func newDropper(numerator, denominator uint32) *dropper {
+	w := newRandomWRR()
+	w.Add(true, int64(numerator))
+	w.Add(false, int64(denominator-numerator))
+
 	return &dropper{
 		numerator:   numerator,
 		denominator: denominator,
+		w:           w,
 	}
 }
 
 func (d *dropper) drop() (ret bool) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-
-	// TODO: the drop algorithm needs a design.
-	// Currently, for drop rate 3/5:
-	// 0 1 2 3 4
-	// d d d n n
-	if d.i < d.numerator {
-		ret = true
-	}
-	d.i++
-	if d.i >= d.denominator {
-		d.i = 0
-	}
-
-	return
+	return d.w.Next().(bool)
 }
diff --git a/balancer/xds/edsbalancer/util_test.go b/balancer/xds/edsbalancer/util_test.go
index 68471b6c..318b9784 100644
--- a/balancer/xds/edsbalancer/util_test.go
+++ b/balancer/xds/edsbalancer/util_test.go
@@ -19,9 +19,58 @@
 package edsbalancer
 
 import (
+	"sync"
 	"testing"
+
+	"google.golang.org/grpc/balancer/internal/wrr"
 )
 
+// testWRR is a deterministic WRR implementation.
+//
+// The real implementation does random WRR. testWRR makes the balancer behavior
+// deterministic and easier to test.
+//
+// With {a: 2, b: 3}, the Next() results will be {a, a, b, b, b}.
+type testWRR struct {
+	itemsWithWeight []struct {
+		item   interface{}
+		weight int64
+	}
+	length int
+
+	mu    sync.Mutex
+	idx   int   // The index of the item that will be picked
+	count int64 // The number of times the current item has been picked.
+}
+
+func newTestWRR() wrr.WRR {
+	return &testWRR{}
+}
+
+func (twrr *testWRR) Add(item interface{}, weight int64) {
+	twrr.itemsWithWeight = append(twrr.itemsWithWeight, struct {
+		item   interface{}
+		weight int64
+	}{item: item, weight: weight})
+	twrr.length++
+}
+
+func (twrr *testWRR) Next() interface{} {
+	twrr.mu.Lock()
+	iww := twrr.itemsWithWeight[twrr.idx]
+	twrr.count++
+	if twrr.count >= iww.weight {
+		twrr.idx = (twrr.idx + 1) % twrr.length
+		twrr.count = 0
+	}
+	twrr.mu.Unlock()
+	return iww.item
+}
+
+func init() {
+	newRandomWRR = newTestWRR
+}
+
 func TestDropper(t *testing.T) {
 	const repeat = 2
 
diff --git a/go.mod b/go.mod
index 9f3ef3a5..b75c069a 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
 	github.com/golang/mock v1.1.1
 	github.com/golang/protobuf v1.2.0
+	github.com/google/go-cmp v0.2.0
 	golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3
 	golang.org/x/net v0.0.0-20190311183353-d8887717615a
 	golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be
diff --git a/go.sum b/go.sum
index b8638ce7..2a172347 100644
--- a/go.sum
+++ b/go.sum
@@ -10,6 +10,8 @@ github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3 h1:XQyxROzUlZH+WIQwySDgnISgOivlhjIEwaQaJEJrrN0=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=