From 157dc9f3e45215af512dafc638af3133fb2e396d Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 26 Nov 2019 15:53:20 -0800 Subject: [PATCH] xds: cleanup eds tests to use fakexds server (#3209) And move LRS test server to fakexds package. --- xds/internal/balancer/xds_client_test.go | 277 +++++++---------------- xds/internal/balancer/xds_lrs_test.go | 70 +----- xds/internal/balancer/xds_test.go | 13 +- xds/internal/client/fakexds/fakexds.go | 10 +- xds/internal/client/fakexds/lrsserver.go | 104 +++++++++ 5 files changed, 213 insertions(+), 261 deletions(-) create mode 100644 xds/internal/client/fakexds/lrsserver.go diff --git a/xds/internal/balancer/xds_client_test.go b/xds/internal/balancer/xds_client_test.go index 75b18129..2f1b077e 100644 --- a/xds/internal/balancer/xds_client_test.go +++ b/xds/internal/balancer/xds_client_test.go @@ -19,29 +19,24 @@ package balancer import ( - "net" "testing" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" - xdsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2" - lrsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" - durationpb "github.com/golang/protobuf/ptypes/duration" structpb "github.com/golang/protobuf/ptypes/struct" wrpb "github.com/golang/protobuf/ptypes/wrappers" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" - "google.golang.org/grpc/status" xdsinternal "google.golang.org/grpc/xds/internal" xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/client/fakexds" ) const ( @@ -108,166 +103,64 @@ var ( } ) -// TODO: remove all usage of testTrafficDirector, and use fakexds server -// instead. -type testTrafficDirector struct { - reqChan chan *request - respChan chan *response -} - -type request struct { - req *xdspb.DiscoveryRequest - err error -} - -type response struct { - resp *xdspb.DiscoveryResponse - err error -} - -func (ttd *testTrafficDirector) StreamAggregatedResources(s xdsgrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error { - go func() { - for { - req, err := s.Recv() - if err != nil { - ttd.reqChan <- &request{ - req: nil, - err: err, - } - return - } - ttd.reqChan <- &request{ - req: req, - err: nil, - } - } - }() - - for { - select { - case resp := <-ttd.respChan: - if resp.err != nil { - return resp.err - } - if err := s.Send(resp.resp); err != nil { - return err - } - case <-s.Context().Done(): - return s.Context().Err() - } - } -} - -func (ttd *testTrafficDirector) DeltaAggregatedResources(xdsgrpc.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error { - return status.Error(codes.Unimplemented, "") -} - -func (ttd *testTrafficDirector) sendResp(resp *response) { - ttd.respChan <- resp -} - -func (ttd *testTrafficDirector) getReq() *request { - return <-ttd.reqChan -} - -func newTestTrafficDirector() *testTrafficDirector { - return &testTrafficDirector{ - reqChan: make(chan *request, 10), - respChan: make(chan *response, 10), - } -} - -type testConfig struct { - edsServiceName string - expectedRequests []*xdspb.DiscoveryRequest - responsesToSend []*xdspb.DiscoveryResponse - expectedADSResponses []proto.Message -} - -func setupServer(t *testing.T) (addr string, td *testTrafficDirector, lrss *lrsServer, cleanup func()) { - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("listen failed due to: %v", err) - } - svr := grpc.NewServer() - td = newTestTrafficDirector() - lrss = &lrsServer{ - drops: make(map[string]uint64), - reportingInterval: &durationpb.Duration{ - Seconds: 60 * 60, // 1 hour, each test can override this to a shorter duration. - Nanos: 0, - }, - } - xdsgrpc.RegisterAggregatedDiscoveryServiceServer(svr, td) - lrsgrpc.RegisterLoadReportingServiceServer(svr, lrss) - go svr.Serve(lis) - return lis.Addr().String(), td, lrss, func() { - svr.Stop() - lis.Close() - } -} - -func (s) TestXdsClientResponseHandling(t *testing.T) { - for _, test := range []*testConfig{ - { - // Test that if clusterName is not set, dialing target is used. - expectedRequests: []*xdspb.DiscoveryRequest{{ - TypeUrl: edsType, - ResourceNames: []string{testServiceName}, // ResourceName is dialing target. - Node: &corepb.Node{}, - }}, - }, - { - edsServiceName: testEDSClusterName, - expectedRequests: []*xdspb.DiscoveryRequest{{ - TypeUrl: edsType, - ResourceNames: []string{testEDSClusterName}, - Node: &corepb.Node{}, - }}, - responsesToSend: []*xdspb.DiscoveryResponse{testEDSResp}, - expectedADSResponses: []proto.Message{testClusterLoadAssignment}, - }, - } { - testXdsClientResponseHandling(t, test) - } -} - -func testXdsClientResponseHandling(t *testing.T, test *testConfig) { - addr, td, _, cleanup := setupServer(t) +func (s) TestEDSClientResponseHandling(t *testing.T) { + td, cleanup := fakexds.StartServer(t) defer cleanup() - adsChan := make(chan *xdsclient.EDSUpdate, 10) - newADS := func(i *xdsclient.EDSUpdate) error { - adsChan <- i + edsRespChan := make(chan *xdsclient.EDSUpdate, 10) + newEDS := func(i *xdsclient.EDSUpdate) error { + edsRespChan <- i return nil } - client := newXDSClientWrapper(newADS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) + client := newXDSClientWrapper(newEDS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) defer client.close() - client.handleUpdate(&XDSConfig{ - BalancerName: addr, - EDSServiceName: test.edsServiceName, - LrsLoadReportingServerName: "", - }, nil) - for _, expectedReq := range test.expectedRequests { - req := td.getReq() - if req.err != nil { - t.Fatalf("ads RPC failed with err: %v", req.err) + // Test that EDS requests sent match XDSConfig. + for _, test := range []struct { + name string + edsServiceName string + wantResourceName string + }{ + { + name: "empty-edsServiceName", + edsServiceName: "", + // EDSServiceName is an empty string, user's dialing target will be + // set in resource names. + wantResourceName: testServiceName, + }, + { + name: "non-empty-edsServiceName", + edsServiceName: testEDSClusterName, + wantResourceName: testEDSClusterName, + }, + } { + client.handleUpdate(&XDSConfig{ + BalancerName: td.Address, + EDSServiceName: test.edsServiceName, + LrsLoadReportingServerName: "", + }, nil) + req := <-td.RequestChan + if req.Err != nil { + t.Fatalf("EDS RPC failed with err: %v", req.Err) } - if !proto.Equal(req.req, expectedReq) { - t.Fatalf("got ADS request %T, expected: %T, diff: %s", req.req, expectedReq, cmp.Diff(req.req, expectedReq, cmp.Comparer(proto.Equal))) + wantReq1 := &xdspb.DiscoveryRequest{ + TypeUrl: edsType, + ResourceNames: []string{test.wantResourceName}, + Node: &corepb.Node{}, + } + if !proto.Equal(req.Req, wantReq1) { + t.Fatalf("%v: got EDS request %v, expected: %v, diff: %s", test.name, req.Req, wantReq1, cmp.Diff(req.Req, wantReq1, cmp.Comparer(proto.Equal))) } } - for i, resp := range test.responsesToSend { - td.sendResp(&response{resp: resp}) - ads := <-adsChan - want, err := xdsclient.ParseEDSRespProto(test.expectedADSResponses[i].(*xdspb.ClusterLoadAssignment)) - if err != nil { - t.Fatalf("parsing wanted EDS response failed: %v", err) - } - if !cmp.Equal(ads, want) { - t.Fatalf("received unexpected ads response, got %v, want %v", ads, test.expectedADSResponses[i]) - } + // Make sure that the responses from the stream are also handled. + td.ResponseChan <- &fakexds.Response{Resp: testEDSResp} + gotResp := <-edsRespChan + want, err := xdsclient.ParseEDSRespProto(testClusterLoadAssignment) + if err != nil { + t.Fatalf("parsing wanted EDS response failed: %v", err) + } + if !cmp.Equal(gotResp, want) { + t.Fatalf("received unexpected EDS response, got %v, want %v", gotResp, want) } } @@ -276,10 +169,10 @@ func testXdsClientResponseHandling(t *testing.T, test *testConfig) { // // And also that when xds_client in attributes is updated, the new one will be // used, and watch will be restarted. -func (s) TestXdsClientInAttributes(t *testing.T) { - adsChan := make(chan *xdsclient.EDSUpdate, 10) - newADS := func(i *xdsclient.EDSUpdate) error { - adsChan <- i +func (s) TestEDSClientInAttributes(t *testing.T) { + edsRespChan := make(chan *xdsclient.EDSUpdate, 10) + newEDS := func(i *xdsclient.EDSUpdate) error { + edsRespChan <- i return nil } @@ -290,12 +183,12 @@ func (s) TestXdsClientInAttributes(t *testing.T) { } defer func() { xdsclientNew = oldxdsclientNew }() - addr, td, _, cleanup := setupServer(t) + td, cleanup := fakexds.StartServer(t) defer cleanup() // Create a client to be passed in attributes. c, _ := oldxdsclientNew(xdsclient.Options{ Config: bootstrap.Config{ - BalancerName: addr, + BalancerName: td.Address, Creds: grpc.WithInsecure(), NodeProto: &corepb.Node{}, }, @@ -304,7 +197,7 @@ func (s) TestXdsClientInAttributes(t *testing.T) { // from attributes). defer c.Close() - client := newXDSClientWrapper(newADS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) + client := newXDSClientWrapper(newEDS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) defer client.close() client.handleUpdate( @@ -319,20 +212,20 @@ func (s) TestXdsClientInAttributes(t *testing.T) { } // Make sure the requests are sent to the correct td. - req := td.getReq() - if req.err != nil { - t.Fatalf("ads RPC failed with err: %v", req.err) + req := <-td.RequestChan + if req.Err != nil { + t.Fatalf("EDS RPC failed with err: %v", req.Err) } - if !proto.Equal(req.req, expectedReq) { - t.Fatalf("got ADS request %T, expected: %T, diff: %s", req.req, expectedReq, cmp.Diff(req.req, expectedReq, cmp.Comparer(proto.Equal))) + if !proto.Equal(req.Req, expectedReq) { + t.Fatalf("got EDS request %T, expected: %T, diff: %s", req.Req, expectedReq, cmp.Diff(req.Req, expectedReq, cmp.Comparer(proto.Equal))) } - addr2, td2, _, cleanup2 := setupServer(t) + td2, cleanup2 := fakexds.StartServer(t) defer cleanup2() // Create a client to be passed in attributes. c2, _ := oldxdsclientNew(xdsclient.Options{ Config: bootstrap.Config{ - BalancerName: addr2, + BalancerName: td2.Address, Creds: grpc.WithInsecure(), NodeProto: &corepb.Node{}, }, @@ -356,21 +249,21 @@ func (s) TestXdsClientInAttributes(t *testing.T) { } // Make sure the requests are sent to the correct td. - req2 := td2.getReq() - if req.err != nil { - t.Fatalf("ads RPC failed with err: %v", req.err) + req2 := <-td2.RequestChan + if req.Err != nil { + t.Fatalf("EDS RPC failed with err: %v", req.Err) } - if !proto.Equal(req2.req, expectedReq2) { - t.Fatalf("got ADS request %T, expected: %T, diff: %s", req2.req, expectedReq, cmp.Diff(req2.req, expectedReq2, cmp.Comparer(proto.Equal))) + if !proto.Equal(req2.Req, expectedReq2) { + t.Fatalf("got EDS request %T, expected: %T, diff: %s", req2.Req, expectedReq, cmp.Diff(req2.Req, expectedReq2, cmp.Comparer(proto.Equal))) } } // Test that when edsServiceName from service config is updated, the new one // will be watched. func (s) TestEDSServiceNameUpdate(t *testing.T) { - adsChan := make(chan *xdsclient.EDSUpdate, 10) - newADS := func(i *xdsclient.EDSUpdate) error { - adsChan <- i + edsRespChan := make(chan *xdsclient.EDSUpdate, 10) + newEDS := func(i *xdsclient.EDSUpdate) error { + edsRespChan <- i return nil } @@ -381,12 +274,12 @@ func (s) TestEDSServiceNameUpdate(t *testing.T) { } defer func() { xdsclientNew = oldxdsclientNew }() - addr, td, _, cleanup := setupServer(t) + td, cleanup := fakexds.StartServer(t) defer cleanup() // Create a client to be passed in attributes. c, _ := oldxdsclientNew(xdsclient.Options{ Config: bootstrap.Config{ - BalancerName: addr, + BalancerName: td.Address, Creds: grpc.WithInsecure(), NodeProto: &corepb.Node{}, }, @@ -395,7 +288,7 @@ func (s) TestEDSServiceNameUpdate(t *testing.T) { // from attributes). defer c.Close() - client := newXDSClientWrapper(newADS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) + client := newXDSClientWrapper(newEDS, func() {}, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil) defer client.close() client.handleUpdate( @@ -410,12 +303,12 @@ func (s) TestEDSServiceNameUpdate(t *testing.T) { } // Make sure the requests are sent to the correct td. - req := td.getReq() - if req.err != nil { - t.Fatalf("ads RPC failed with err: %v", req.err) + req := <-td.RequestChan + if req.Err != nil { + t.Fatalf("EDS RPC failed with err: %v", req.Err) } - if !proto.Equal(req.req, expectedReq) { - t.Fatalf("got ADS request %T, expected: %T, diff: %s", req.req, expectedReq, cmp.Diff(req.req, expectedReq, cmp.Comparer(proto.Equal))) + if !proto.Equal(req.Req, expectedReq) { + t.Fatalf("got EDS request %T, expected: %T, diff: %s", req.Req, expectedReq, cmp.Diff(req.Req, expectedReq, cmp.Comparer(proto.Equal))) } // Update with a new edsServiceName. @@ -433,11 +326,11 @@ func (s) TestEDSServiceNameUpdate(t *testing.T) { } // Make sure the requests are sent to the correct td. - req2 := td.getReq() - if req.err != nil { - t.Fatalf("ads RPC failed with err: %v", req.err) + req2 := <-td.RequestChan + if req.Err != nil { + t.Fatalf("EDS RPC failed with err: %v", req.Err) } - if !proto.Equal(req2.req, expectedReq2) { - t.Fatalf("got ADS request %T, expected: %T, diff: %s", req2.req, expectedReq, cmp.Diff(req2.req, expectedReq2, cmp.Comparer(proto.Equal))) + if !proto.Equal(req2.Req, expectedReq2) { + t.Fatalf("got EDS request %T, expected: %T, diff: %s", req2.Req, expectedReq, cmp.Diff(req2.Req, expectedReq2, cmp.Comparer(proto.Equal))) } } diff --git a/xds/internal/balancer/xds_lrs_test.go b/xds/internal/balancer/xds_lrs_test.go index 89a99cd5..980ce11b 100644 --- a/xds/internal/balancer/xds_lrs_test.go +++ b/xds/internal/balancer/xds_lrs_test.go @@ -19,69 +19,16 @@ package balancer import ( - "io" - "sync" "testing" "time" - corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" - lrsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" - lrspb "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" - "github.com/golang/protobuf/proto" durationpb "github.com/golang/protobuf/ptypes/duration" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" - "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/client/fakexds" ) -type lrsServer struct { - mu sync.Mutex - dropTotal uint64 - drops map[string]uint64 - reportingInterval *durationpb.Duration -} - -func (lrss *lrsServer) StreamLoadStats(stream lrsgrpc.LoadReportingService_StreamLoadStatsServer) error { - req, err := stream.Recv() - if err != nil { - return err - } - if !proto.Equal(req, &lrspb.LoadStatsRequest{ - ClusterStats: []*endpointpb.ClusterStats{{ - ClusterName: testEDSClusterName, - }}, - Node: &corepb.Node{}, - }) { - return status.Errorf(codes.FailedPrecondition, "unexpected req: %+v", req) - } - if err := stream.Send(&lrspb.LoadStatsResponse{ - Clusters: []string{testEDSClusterName}, - LoadReportingInterval: lrss.reportingInterval, - }); err != nil { - return err - } - - for { - req, err := stream.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - stats := req.ClusterStats[0] - lrss.mu.Lock() - lrss.dropTotal += stats.TotalDroppedRequests - for _, d := range stats.DroppedRequests { - lrss.drops[d.Category] += d.DroppedCount - } - lrss.mu.Unlock() - } -} - func (s) TestXdsLoadReporting(t *testing.T) { originalNewEDSBalancer := newEDSBalancer newEDSBalancer = newFakeEDSBalancer @@ -97,20 +44,21 @@ func (s) TestXdsLoadReporting(t *testing.T) { } defer lb.Close() - addr, td, lrss, cleanup := setupServer(t) + td, cleanup := fakexds.StartServer(t) defer cleanup() const intervalNano = 1000 * 1000 * 50 - lrss.reportingInterval = &durationpb.Duration{ + td.LRS.ReportingInterval = &durationpb.Duration{ Seconds: 0, Nanos: intervalNano, } + td.LRS.ExpectedEDSClusterName = testEDSClusterName cfg := &XDSConfig{ - BalancerName: addr, + BalancerName: td.Address, } lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: cfg}) - td.sendResp(&response{resp: testEDSResp}) + td.ResponseChan <- &fakexds.Response{Resp: testEDSResp} var ( i int edsLB *fakeEDSBalancer @@ -140,9 +88,7 @@ func (s) TestXdsLoadReporting(t *testing.T) { } time.Sleep(time.Nanosecond * intervalNano * 2) - lrss.mu.Lock() - defer lrss.mu.Unlock() - if !cmp.Equal(lrss.drops, drops) { - t.Errorf("different: %v %v %v", lrss.drops, drops, cmp.Diff(lrss.drops, drops)) + if got := td.LRS.GetDrops(); !cmp.Equal(got, drops) { + t.Errorf("different: %v %v %v", got, drops, cmp.Diff(got, drops)) } } diff --git a/xds/internal/balancer/xds_test.go b/xds/internal/balancer/xds_test.go index 6da975b0..cf92f479 100644 --- a/xds/internal/balancer/xds_test.go +++ b/xds/internal/balancer/xds_test.go @@ -42,6 +42,7 @@ import ( "google.golang.org/grpc/xds/internal/balancer/lrs" xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/client/fakexds" ) var lbABuilder = &balancerABuilder{} @@ -378,10 +379,10 @@ func (s) TestXdsBalanceHandleBalancerConfigBalancerNameUpdate(t *testing.T) { // In the first iteration, an eds balancer takes over fallback balancer // In the second iteration, a new xds client takes over previous one. for i := 0; i < 2; i++ { - addr, td, _, cleanup := setupServer(t) + td, cleanup := fakexds.StartServer(t) cleanups = append(cleanups, cleanup) workingLBConfig := &XDSConfig{ - BalancerName: addr, + BalancerName: td.Address, ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, EDSServiceName: testEDSClusterName, @@ -390,7 +391,7 @@ func (s) TestXdsBalanceHandleBalancerConfigBalancerNameUpdate(t *testing.T) { ResolverState: resolver.State{Addresses: addrs}, BalancerConfig: workingLBConfig, }) - td.sendResp(&response{resp: testEDSResp}) + td.ResponseChan <- &fakexds.Response{Resp: testEDSResp} var j int for j = 0; j < 10; j++ { @@ -472,13 +473,13 @@ func (s) TestXdsBalanceHandleBalancerConfigChildPolicyUpdate(t *testing.T) { }, }, } { - addr, td, _, cleanup := setupServer(t) + td, cleanup := fakexds.StartServer(t) cleanups = append(cleanups, cleanup) - test.cfg.BalancerName = addr + test.cfg.BalancerName = td.Address lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: test.cfg}) if test.responseToSend != nil { - td.sendResp(&response{resp: test.responseToSend}) + td.ResponseChan <- &fakexds.Response{Resp: test.responseToSend} } var i int for i = 0; i < 10; i++ { diff --git a/xds/internal/client/fakexds/fakexds.go b/xds/internal/client/fakexds/fakexds.go index 814100fa..9505e5ac 100644 --- a/xds/internal/client/fakexds/fakexds.go +++ b/xds/internal/client/fakexds/fakexds.go @@ -32,6 +32,7 @@ import ( discoverypb "github.com/envoyproxy/go-control-plane/envoy/api/v2" adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2" + lrsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" ) // TODO: Make this a var or a field in the server if there is a need to use a @@ -62,6 +63,8 @@ type Server struct { ResponseChan chan *Response // Address is the host:port on which the fake xdsServer is listening on. Address string + // LRS is the LRS server installed. + LRS *LRSServer } // StartServer starts a fakexds.Server. The returned function should be invoked @@ -73,14 +76,19 @@ func StartServer(t *testing.T) (*Server, func()) { if err != nil { t.Fatalf("net.Listen() failed: %v", err) } - server := grpc.NewServer() + + lrss := newLRSServer() + lrsgrpc.RegisterLoadReportingServiceServer(server, lrss) + fs := &Server{ RequestChan: make(chan *Request, defaultChannelBufferSize), ResponseChan: make(chan *Response, defaultChannelBufferSize), Address: lis.Addr().String(), + LRS: lrss, } adsgrpc.RegisterAggregatedDiscoveryServiceServer(server, fs) + go server.Serve(lis) t.Logf("Starting fake xDS server at %v...", fs.Address) diff --git a/xds/internal/client/fakexds/lrsserver.go b/xds/internal/client/fakexds/lrsserver.go new file mode 100644 index 00000000..55a6feef --- /dev/null +++ b/xds/internal/client/fakexds/lrsserver.go @@ -0,0 +1,104 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fakexds + +import ( + "io" + "sync" + + corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" + lrsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" + lrspb "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2" + "github.com/golang/protobuf/proto" + durationpb "github.com/golang/protobuf/ptypes/duration" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// LRSServer implements the LRS service, and is to be installed on the fakexds +// server. It collects load reports, and returned them later for comparison. +type LRSServer struct { + // ReportingInterval will be sent in the first response to control reporting + // interval. + ReportingInterval *durationpb.Duration + // ExpectedEDSClusterName is checked against the first LRS request. The RPC + // is failed if they don't match. + ExpectedEDSClusterName string + + mu sync.Mutex + dropTotal uint64 + drops map[string]uint64 +} + +func newLRSServer() *LRSServer { + return &LRSServer{ + drops: make(map[string]uint64), + ReportingInterval: &durationpb.Duration{ + Seconds: 60 * 60, // 1 hour, each test can override this to a shorter duration. + }, + } +} + +// StreamLoadStats implements LRS service. +func (lrss *LRSServer) StreamLoadStats(stream lrsgrpc.LoadReportingService_StreamLoadStatsServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + wantReq := &lrspb.LoadStatsRequest{ + ClusterStats: []*endpointpb.ClusterStats{{ + ClusterName: lrss.ExpectedEDSClusterName, + }}, + Node: &corepb.Node{}, + } + if !proto.Equal(req, wantReq) { + return status.Errorf(codes.FailedPrecondition, "unexpected req: %+v, want %+v, diff: %s", req, wantReq, cmp.Diff(req, wantReq, cmp.Comparer(proto.Equal))) + } + if err := stream.Send(&lrspb.LoadStatsResponse{ + Clusters: []string{lrss.ExpectedEDSClusterName}, + LoadReportingInterval: lrss.ReportingInterval, + }); err != nil { + return err + } + + for { + req, err := stream.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + stats := req.ClusterStats[0] + lrss.mu.Lock() + lrss.dropTotal += stats.TotalDroppedRequests + for _, d := range stats.DroppedRequests { + lrss.drops[d.Category] += d.DroppedCount + } + lrss.mu.Unlock() + } +} + +// GetDrops returns the drops reported to this server. +func (lrss *LRSServer) GetDrops() map[string]uint64 { + lrss.mu.Lock() + defer lrss.mu.Unlock() + return lrss.drops +}