diff --git a/balancer/xds/xds.go b/balancer/xds/xds.go index 7795724f..ba8485a9 100644 --- a/balancer/xds/xds.go +++ b/balancer/xds/xds.go @@ -25,7 +25,6 @@ package xds import ( "context" "encoding/json" - "errors" "fmt" "reflect" "sync" @@ -130,7 +129,7 @@ type xdsBalancer struct { config *xdsConfig // may change when passed a different service config xdsLB edsBalancerInterface fallbackLB balancer.Balancer - fallbackInitData *addressUpdate // may change when HandleResolved address is called + fallbackInitData *resolver.State // may change when HandleResolved address is called } func (x *xdsBalancer) startNewXDSClient(u *xdsConfig) { @@ -218,43 +217,93 @@ func (x *xdsBalancer) run() { } } +func getBalancerConfig(serviceConfig string) *xdsConfig { + sc := parseFullServiceConfig(serviceConfig) + if sc == nil { + return nil + } + var xdsConfigRaw json.RawMessage + for _, lbcfg := range sc.LoadBalancingConfig { + if lbcfg.Name == xdsName { + xdsConfigRaw = lbcfg.Config + break + } + } + var cfg xdsConfig + if err := json.Unmarshal(xdsConfigRaw, &cfg); err != nil { + grpclog.Warningf("unable to unmarshal balancer config %s into xds config", string(xdsConfigRaw)) + return nil + } + return &cfg +} + func (x *xdsBalancer) handleGRPCUpdate(update interface{}) { switch u := update.(type) { - case *addressUpdate: - if x.fallbackLB != nil { - x.fallbackLB.HandleResolvedAddrs(u.addrs, u.err) - } - x.fallbackInitData = u case *subConnStateUpdate: if x.xdsLB != nil { - x.xdsLB.HandleSubConnStateChange(u.sc, u.state) + x.xdsLB.HandleSubConnStateChange(u.sc, u.state.ConnectivityState) } if x.fallbackLB != nil { - x.fallbackLB.HandleSubConnStateChange(u.sc, u.state) + if lb, ok := x.fallbackLB.(balancer.V2Balancer); ok { + lb.UpdateSubConnState(u.sc, u.state) + } else { + x.fallbackLB.HandleSubConnStateChange(u.sc, u.state.ConnectivityState) + } } - case *xdsConfig: - if x.config == nil { - // The first time we get config, we just need to start the xdsClient. - x.startNewXDSClient(u) - x.config = u + case *resolver.State: + cfg := getBalancerConfig(u.ServiceConfig) + if cfg == nil { + // service config parsing failed. should never happen. And this parsing will be removed, once + // we support service config validation. return } - // With a different BalancerName, we need to create a new xdsClient. - // If current or previous ChildPolicy is nil, then we also need to recreate a new xdsClient. - // This is because with nil ChildPolicy xdsClient will do CDS request, while non-nil won't. - if u.BalancerName != x.config.BalancerName || (u.ChildPolicy == nil) != (x.config.ChildPolicy == nil) { - x.startNewXDSClient(u) + + var fallbackChanged bool + // service config has been updated. + if !reflect.DeepEqual(cfg, x.config) { + if x.config == nil { + // The first time we get config, we just need to start the xdsClient. + x.startNewXDSClient(cfg) + x.config = cfg + x.fallbackInitData = &resolver.State{ + Addresses: u.Addresses, + // TODO(yuxuanli): get the fallback balancer config once the validation change completes, where + // we can pass along the config struct. + } + return + } + + // With a different BalancerName, we need to create a new xdsClient. + // If current or previous ChildPolicy is nil, then we also need to recreate a new xdsClient. + // This is because with nil ChildPolicy xdsClient will do CDS request, while non-nil won't. + if cfg.BalancerName != x.config.BalancerName || (cfg.ChildPolicy == nil) != (x.config.ChildPolicy == nil) { + x.startNewXDSClient(cfg) + } + // We will update the xdsLB with the new child policy, if we got a different one and it's not nil. + // The nil case will be handled when the CDS response gets processed, we will update xdsLB at that time. + if x.xdsLB != nil && !reflect.DeepEqual(cfg.ChildPolicy, x.config.ChildPolicy) && cfg.ChildPolicy != nil { + x.xdsLB.HandleChildPolicy(cfg.ChildPolicy.Name, cfg.ChildPolicy.Config) + } + + if x.fallbackLB != nil && !reflect.DeepEqual(cfg.FallBackPolicy, x.config.FallBackPolicy) { + x.fallbackLB.Close() + x.buildFallBackBalancer(cfg) + fallbackChanged = true + } } - // We will update the xdsLB with the new child policy, if we got a different one and it's not nil. - // The nil case will be handled when the CDS response gets processed, we will update xdsLB at that time. - if !reflect.DeepEqual(u.ChildPolicy, x.config.ChildPolicy) && u.ChildPolicy != nil && x.xdsLB != nil { - x.xdsLB.HandleChildPolicy(u.ChildPolicy.Name, u.ChildPolicy.Config) + + if x.fallbackLB != nil && (!reflect.DeepEqual(x.fallbackInitData.Addresses, u.Addresses) || fallbackChanged) { + x.updateFallbackWithResolverState(&resolver.State{ + Addresses: u.Addresses, + }) } - if !reflect.DeepEqual(u.FallBackPolicy, x.config.FallBackPolicy) && x.fallbackLB != nil { - x.fallbackLB.Close() - x.startFallBackBalancer(u) + + x.config = cfg + x.fallbackInitData = &resolver.State{ + Addresses: u.Addresses, + // TODO(yuxuanli): get the fallback balancer config once the validation change completes, where + // we can pass along the config struct. } - x.config = u default: // unreachable path panic("wrong update type") @@ -341,17 +390,20 @@ func (w *xdsClientConn) UpdateBalancerState(s connectivity.State, p balancer.Pic w.ClientConn.UpdateBalancerState(s, p) } -type addressUpdate struct { - addrs []resolver.Address - err error -} - type subConnStateUpdate struct { sc balancer.SubConn - state connectivity.State + state balancer.SubConnState } func (x *xdsBalancer) HandleSubConnStateChange(sc balancer.SubConn, state connectivity.State) { + grpclog.Error("UpdateSubConnState should be called instead of HandleSubConnStateChange") +} + +func (x *xdsBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + grpclog.Error("UpdateResolverState should be called instead of HandleResolvedAddrs") +} + +func (x *xdsBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { update := &subConnStateUpdate{ sc: sc, state: state, @@ -362,29 +414,24 @@ func (x *xdsBalancer) HandleSubConnStateChange(sc balancer.SubConn, state connec } } -func (x *xdsBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { - update := &addressUpdate{ - addrs: addrs, - err: err, - } - select { - case x.grpcUpdate <- update: - case <-x.ctx.Done(): - } +type serviceConfig struct { + LoadBalancingConfig []*loadBalancingConfig } -// TODO: once the API is merged, check whether we need to change the function name/signature here. -func (x *xdsBalancer) HandleBalancerConfig(config json.RawMessage) error { - var cfg xdsConfig - if err := json.Unmarshal(config, &cfg); err != nil { - return errors.New("unable to unmarshal balancer config into xds config") +func parseFullServiceConfig(s string) *serviceConfig { + var ret serviceConfig + err := json.Unmarshal([]byte(s), &ret) + if err != nil { + return nil } + return &ret +} +func (x *xdsBalancer) UpdateResolverState(s resolver.State) { select { - case x.grpcUpdate <- &cfg: + case x.grpcUpdate <- &s: case <-x.ctx.Done(): } - return nil } type cdsResp struct { @@ -441,10 +488,23 @@ func (x *xdsBalancer) switchFallback() { x.xdsLB.Close() x.xdsLB = nil } - x.startFallBackBalancer(x.config) + x.buildFallBackBalancer(x.config) + x.updateFallbackWithResolverState(x.fallbackInitData) x.cancelFallbackMonitoring() } +func (x *xdsBalancer) updateFallbackWithResolverState(s *resolver.State) { + if lb, ok := x.fallbackLB.(balancer.V2Balancer); ok { + lb.UpdateResolverState(resolver.State{ + Addresses: s.Addresses, + // TODO(yuxuanli): get the fallback balancer config once the validation change completes, where + // we can pass along the config struct. + }) + } else { + x.fallbackLB.HandleResolvedAddrs(s.Addresses, nil) + } +} + // x.cancelFallbackAndSwitchEDSBalancerIfNecessary() will be no-op if we have a working xds client. // It will cancel fallback monitoring if we are in fallback monitoring stage. // If there's no running edsBalancer currently, it will create one and initialize it. Also, it will @@ -466,9 +526,9 @@ func (x *xdsBalancer) cancelFallbackAndSwitchEDSBalancerIfNecessary() { } } -func (x *xdsBalancer) startFallBackBalancer(c *xdsConfig) { +func (x *xdsBalancer) buildFallBackBalancer(c *xdsConfig) { if c.FallBackPolicy == nil { - x.startFallBackBalancer(&xdsConfig{ + x.buildFallBackBalancer(&xdsConfig{ FallBackPolicy: &loadBalancingConfig{ Name: "round_robin", }, @@ -479,11 +539,6 @@ func (x *xdsBalancer) startFallBackBalancer(c *xdsConfig) { // balancer is registered or not. builder := balancer.Get(c.FallBackPolicy.Name) x.fallbackLB = builder.Build(x.cc, x.buildOpts) - if x.fallbackInitData != nil { - // TODO: uncomment when HandleBalancerConfig API is merged. - //x.fallbackLB.HandleBalancerConfig(c.FallBackPolicy.Config) - x.fallbackLB.HandleResolvedAddrs(x.fallbackInitData.addrs, x.fallbackInitData.err) - } } // There are three ways that could lead to fallback: @@ -596,7 +651,9 @@ type loadBalancingConfig struct { } func (l *loadBalancingConfig) MarshalJSON() ([]byte, error) { - return nil, nil + m := make(map[string]json.RawMessage) + m[l.Name] = l.Config + return json.Marshal(m) } func (l *loadBalancingConfig) UnmarshalJSON(data []byte) error { diff --git a/balancer/xds/xds_test.go b/balancer/xds/xds_test.go index abe2e1b3..be9c4e2f 100644 --- a/balancer/xds/xds_test.go +++ b/balancer/xds/xds_test.go @@ -64,12 +64,13 @@ const ( ) var ( - testBalancerNameFooBar = "foo.bar" - testBalancerConfigFooBar, _ = json.Marshal(&testBalancerConfig{ + testBalancerNameFooBar = "foo.bar" + testServiceConfigFooBar = constructServiceConfigFromXdsConfig(&testBalancerConfig{ BalancerName: testBalancerNameFooBar, ChildPolicy: []lbPolicy{fakeBalancerA}, FallbackPolicy: []lbPolicy{fakeBalancerA}, }) + specialAddrForBalancerA = resolver.Address{Addr: "this.is.balancer.A"} specialAddrForBalancerB = resolver.Address{Addr: "this.is.balancer.B"} @@ -95,6 +96,19 @@ func (l lbPolicy) MarshalJSON() ([]byte, error) { return json.Marshal(m) } +func constructServiceConfigFromXdsConfig(xdsCfg *testBalancerConfig) string { + cfgRaw, _ := json.Marshal(xdsCfg) + sc, _ := json.Marshal(&serviceConfig{ + LoadBalancingConfig: []*loadBalancingConfig{ + { + Name: xdsName, + Config: cfgRaw, + }, + }, + }) + return string(sc) +} + type balancerABuilder struct { mu sync.Mutex lastBalancer *balancerA @@ -264,19 +278,19 @@ func (s) TestXdsBalanceHandleResolvedAddrs(t *testing.T) { t.Fatalf("unable to type assert to *xdsBalancer") } defer lb.Close() - if err := lb.HandleBalancerConfig(json.RawMessage(testBalancerConfigFooBar)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(testBalancerConfigFooBar), err) - } addrs := []resolver.Address{{Addr: "1.1.1.1:10001"}, {Addr: "2.2.2.2:10002"}, {Addr: "3.3.3.3:10003"}} for i := 0; i < 3; i++ { - lb.HandleResolvedAddrs(addrs, nil) + lb.UpdateResolverState(resolver.State{ + Addresses: addrs, + ServiceConfig: string(testServiceConfigFooBar), + }) select { case nsc := <-cc.newSubConns: if !reflect.DeepEqual(append(addrs, specialAddrForBalancerA), nsc) { t.Fatalf("got new subconn address %v, want %v", nsc, append(addrs, specialAddrForBalancerA)) } case <-time.After(2 * time.Second): - t.Fatalf("timeout when geting new subconn result") + t.Fatal("timeout when geting new subconn result") } addrs = addrs[:2-i] } @@ -298,11 +312,11 @@ func (s) TestXdsBalanceHandleBalancerConfigBalancerNameUpdate(t *testing.T) { t.Fatalf("unable to type assert to *xdsBalancer") } defer lb.Close() - if err := lb.HandleBalancerConfig(json.RawMessage(testBalancerConfigFooBar)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(testBalancerConfigFooBar), err) - } addrs := []resolver.Address{{Addr: "1.1.1.1:10001"}, {Addr: "2.2.2.2:10002"}, {Addr: "3.3.3.3:10003"}} - lb.HandleResolvedAddrs(addrs, nil) + lb.UpdateResolverState(resolver.State{ + Addresses: addrs, + ServiceConfig: string(testServiceConfigFooBar), + }) // verify fallback takes over select { @@ -325,15 +339,15 @@ func (s) TestXdsBalanceHandleBalancerConfigBalancerNameUpdate(t *testing.T) { for i := 0; i < 2; i++ { addr, td, cleanup := setupServer(t) cleanups = append(cleanups, cleanup) - workingBalancerConfig, _ := json.Marshal(&testBalancerConfig{ + workingServiceConfig := constructServiceConfigFromXdsConfig(&testBalancerConfig{ BalancerName: addr, ChildPolicy: []lbPolicy{fakeBalancerA}, FallbackPolicy: []lbPolicy{fakeBalancerA}, }) - - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + Addresses: addrs, + ServiceConfig: string(workingServiceConfig), + }) td.sendResp(&response{resp: testEDSRespWithoutEndpoints}) var j int @@ -415,11 +429,10 @@ func (s) TestXdsBalanceHandleBalancerConfigChildPolicyUpdate(t *testing.T) { addr, td, cleanup := setupServer(t) cleanups = append(cleanups, cleanup) test.cfg.BalancerName = addr - workingBalancerConfig, _ := json.Marshal(test.cfg) - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + ServiceConfig: constructServiceConfigFromXdsConfig(test.cfg), + }) if test.responseToSend != nil { td.sendResp(&response{resp: test.responseToSend}) } @@ -468,18 +481,14 @@ func (s) TestXdsBalanceHandleBalancerConfigFallbackUpdate(t *testing.T) { ChildPolicy: []lbPolicy{fakeBalancerA}, FallbackPolicy: []lbPolicy{fakeBalancerA}, } - workingBalancerConfig, _ := json.Marshal(cfg) - - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) cfg.FallbackPolicy = []lbPolicy{fakeBalancerB} - workingBalancerConfig, _ = json.Marshal(cfg) - - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) td.sendResp(&response{resp: testEDSRespWithoutEndpoints}) @@ -497,7 +506,10 @@ func (s) TestXdsBalanceHandleBalancerConfigFallbackUpdate(t *testing.T) { cleanup() addrs := []resolver.Address{{Addr: "1.1.1.1:10001"}, {Addr: "2.2.2.2:10002"}, {Addr: "3.3.3.3:10003"}} - lb.HandleResolvedAddrs(addrs, nil) + lb.UpdateResolverState(resolver.State{ + Addresses: addrs, + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) // verify fallback balancer B takes over select { @@ -510,10 +522,10 @@ func (s) TestXdsBalanceHandleBalancerConfigFallbackUpdate(t *testing.T) { } cfg.FallbackPolicy = []lbPolicy{fakeBalancerA} - workingBalancerConfig, _ = json.Marshal(cfg) - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + Addresses: addrs, + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) // verify fallback balancer A takes over select { @@ -548,11 +560,9 @@ func (s) TestXdsBalancerHandlerSubConnStateChange(t *testing.T) { ChildPolicy: []lbPolicy{fakeBalancerA}, FallbackPolicy: []lbPolicy{fakeBalancerA}, } - workingBalancerConfig, _ := json.Marshal(cfg) - - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) td.sendResp(&response{resp: testEDSRespWithoutEndpoints}) @@ -564,7 +574,7 @@ func (s) TestXdsBalancerHandlerSubConnStateChange(t *testing.T) { var i int for i = 0; i < 10; i++ { if edsLB := getLatestEdsBalancer(); edsLB != nil { - lb.HandleSubConnStateChange(expectedScStateChange.sc, expectedScStateChange.state) + lb.UpdateSubConnState(expectedScStateChange.sc, balancer.SubConnState{ConnectivityState: expectedScStateChange.state}) select { case scsc := <-edsLB.subconnStateChange: if !reflect.DeepEqual(scsc, expectedScStateChange) { @@ -590,7 +600,7 @@ func (s) TestXdsBalancerHandlerSubConnStateChange(t *testing.T) { // fallback balancer A takes over for i = 0; i < 10; i++ { if fblb := lbABuilder.getLastBalancer(); fblb != nil { - lb.HandleSubConnStateChange(expectedScStateChange.sc, expectedScStateChange.state) + lb.UpdateSubConnState(expectedScStateChange.sc, balancer.SubConnState{ConnectivityState: expectedScStateChange.state}) select { case scsc := <-fblb.subconnStateChange: if !reflect.DeepEqual(scsc, expectedScStateChange) { @@ -630,11 +640,9 @@ func (s) TestXdsBalancerFallbackSignalFromEdsBalancer(t *testing.T) { ChildPolicy: []lbPolicy{fakeBalancerA}, FallbackPolicy: []lbPolicy{fakeBalancerA}, } - workingBalancerConfig, _ := json.Marshal(cfg) - - if err := lb.HandleBalancerConfig(json.RawMessage(workingBalancerConfig)); err != nil { - t.Fatalf("failed to HandleBalancerConfig(%v), due to err: %v", string(workingBalancerConfig), err) - } + lb.UpdateResolverState(resolver.State{ + ServiceConfig: constructServiceConfigFromXdsConfig(cfg), + }) td.sendResp(&response{resp: testEDSRespWithoutEndpoints}) @@ -646,7 +654,7 @@ func (s) TestXdsBalancerFallbackSignalFromEdsBalancer(t *testing.T) { var i int for i = 0; i < 10; i++ { if edsLB := getLatestEdsBalancer(); edsLB != nil { - lb.HandleSubConnStateChange(expectedScStateChange.sc, expectedScStateChange.state) + lb.UpdateSubConnState(expectedScStateChange.sc, balancer.SubConnState{ConnectivityState: expectedScStateChange.state}) select { case scsc := <-edsLB.subconnStateChange: if !reflect.DeepEqual(scsc, expectedScStateChange) { @@ -672,7 +680,7 @@ func (s) TestXdsBalancerFallbackSignalFromEdsBalancer(t *testing.T) { // fallback balancer A takes over for i = 0; i < 10; i++ { if fblb := lbABuilder.getLastBalancer(); fblb != nil { - lb.HandleSubConnStateChange(expectedScStateChange.sc, expectedScStateChange.state) + lb.UpdateSubConnState(expectedScStateChange.sc, balancer.SubConnState{ConnectivityState: expectedScStateChange.state}) select { case scsc := <-fblb.subconnStateChange: if !reflect.DeepEqual(scsc, expectedScStateChange) { @@ -710,3 +718,84 @@ func (s) TestXdsBalancerConfigParsingSelectingLBPolicy(t *testing.T) { t.Fatalf("got fallback policy %v, want %v", xdsCfg.FallBackPolicy, wantFallbackPolicy) } } + +func (s) TestXdsFullServiceConfigParsing(t *testing.T) { + tests := []struct { + name string + s string + want *serviceConfig + }{ + { + name: "empty", + s: "", + want: nil, + }, + { + name: "success1", + s: `{"loadBalancingConfig":[{"xds":{"childPolicy":[{"pick_first":{}}]}}]}`, + want: &serviceConfig{ + LoadBalancingConfig: []*loadBalancingConfig{ + {"xds", json.RawMessage(`{"childPolicy":[{"pick_first":{}}]}`)}, + }, + }, + }, + { + name: "success2", + s: `{"loadBalancingConfig":[{"xds":{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}}]}`, + want: &serviceConfig{ + LoadBalancingConfig: []*loadBalancingConfig{ + {"xds", json.RawMessage(`{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`)}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseFullServiceConfig(tt.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("test name: %s, parseFullServiceConfig() = %+v, want %+v", tt.name, got, tt.want) + } + }) + } +} + +func (s) TestXdsLoadbalancingConfigParsing(t *testing.T) { + tests := []struct { + name string + s string + want *xdsConfig + }{ + { + name: "empty", + s: "{}", + want: &xdsConfig{}, + }, + { + name: "success1", + s: `{"childPolicy":[{"pick_first":{}}]}`, + want: &xdsConfig{ + ChildPolicy: &loadBalancingConfig{ + Name: "pick_first", + Config: json.RawMessage(`{}`), + }, + }, + }, + { + name: "success2", + s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`, + want: &xdsConfig{ + ChildPolicy: &loadBalancingConfig{ + Name: "round_robin", + Config: json.RawMessage(`{}`), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg xdsConfig + if err := json.Unmarshal([]byte(tt.s), &cfg); err != nil || !reflect.DeepEqual(&cfg, tt.want) { + t.Errorf("test name: %s, parseFullServiceConfig() = %+v, err: %v, want %+v, ", tt.name, cfg, err, tt.want) + } + }) + } +}