Register and use default balancers and resolvers (#1551)
This commit is contained in:
		| @ -33,8 +33,6 @@ import ( | |||||||
| var ( | var ( | ||||||
| 	// m is a map from name to balancer builder. | 	// m is a map from name to balancer builder. | ||||||
| 	m = make(map[string]Builder) | 	m = make(map[string]Builder) | ||||||
| 	// defaultBuilder is the default balancer to use. |  | ||||||
| 	defaultBuilder Builder // TODO(bar) install pickfirst as default. |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Register registers the balancer builder to the balancer map. | // Register registers the balancer builder to the balancer map. | ||||||
| @ -44,13 +42,12 @@ func Register(b Builder) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Get returns the resolver builder registered with the given name. | // Get returns the resolver builder registered with the given name. | ||||||
| // If no builder is register with the name, the default pickfirst will | // If no builder is register with the name, nil will be returned. | ||||||
| // be used. |  | ||||||
| func Get(name string) Builder { | func Get(name string) Builder { | ||||||
| 	if b, ok := m[name]; ok { | 	if b, ok := m[name]; ok { | ||||||
| 		return b | 		return b | ||||||
| 	} | 	} | ||||||
| 	return defaultBuilder | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // SubConn represents a gRPC sub connection. | // SubConn represents a gRPC sub connection. | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ | |||||||
|  * |  * | ||||||
|  */ |  */ | ||||||
|  |  | ||||||
| package roundrobin | package roundrobin_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @ -27,6 +27,7 @@ import ( | |||||||
|  |  | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
|  | 	"google.golang.org/grpc/balancer" | ||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| 	_ "google.golang.org/grpc/grpclog/glogger" | 	_ "google.golang.org/grpc/grpclog/glogger" | ||||||
| 	"google.golang.org/grpc/peer" | 	"google.golang.org/grpc/peer" | ||||||
| @ -36,6 +37,8 @@ import ( | |||||||
| 	"google.golang.org/grpc/test/leakcheck" | 	"google.golang.org/grpc/test/leakcheck" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var rr = balancer.Get("roundrobin") | ||||||
|  |  | ||||||
| type testServer struct { | type testServer struct { | ||||||
| 	testpb.TestServiceServer | 	testpb.TestServiceServer | ||||||
| } | } | ||||||
| @ -99,7 +102,7 @@ func TestOneBackend(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -131,7 +134,7 @@ func TestBackendsRoundRobin(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -190,7 +193,7 @@ func TestAddressesRemoved(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -232,7 +235,7 @@ func TestCloseWithPendingRPC(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -266,7 +269,7 @@ func TestNewAddressWhileBlocking(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -315,7 +318,7 @@ func TestOneServerDown(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -408,7 +411,7 @@ func TestAllServersDown(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer test.cleanup() | 	defer test.cleanup() | ||||||
|  |  | ||||||
| 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) | 	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("failed to dial: %v", err) | 		t.Fatalf("failed to dial: %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -73,7 +73,7 @@ func (b *scStateUpdateBuffer) load() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // get returns the channel that receives a recvMsg in the buffer. | // get returns the channel that the scStateUpdate will be sent to. | ||||||
| // | // | ||||||
| // Upon receiving, the caller should call load to send another | // Upon receiving, the caller should call load to send another | ||||||
| // scStateChangeTuple onto the channel if there is any. | // scStateChangeTuple onto the channel if there is any. | ||||||
| @ -96,6 +96,8 @@ type ccBalancerWrapper struct { | |||||||
| 	stateChangeQueue *scStateUpdateBuffer | 	stateChangeQueue *scStateUpdateBuffer | ||||||
| 	resolverUpdateCh chan *resolverUpdate | 	resolverUpdateCh chan *resolverUpdate | ||||||
| 	done             chan struct{} | 	done             chan struct{} | ||||||
|  |  | ||||||
|  | 	subConns map[*acBalancerWrapper]struct{} | ||||||
| } | } | ||||||
|  |  | ||||||
| func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { | func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { | ||||||
| @ -104,6 +106,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui | |||||||
| 		stateChangeQueue: newSCStateUpdateBuffer(), | 		stateChangeQueue: newSCStateUpdateBuffer(), | ||||||
| 		resolverUpdateCh: make(chan *resolverUpdate, 1), | 		resolverUpdateCh: make(chan *resolverUpdate, 1), | ||||||
| 		done:             make(chan struct{}), | 		done:             make(chan struct{}), | ||||||
|  | 		subConns:         make(map[*acBalancerWrapper]struct{}), | ||||||
| 	} | 	} | ||||||
| 	go ccb.watcher() | 	go ccb.watcher() | ||||||
| 	ccb.balancer = b.Build(ccb, bopts) | 	ccb.balancer = b.Build(ccb, bopts) | ||||||
| @ -117,8 +120,20 @@ func (ccb *ccBalancerWrapper) watcher() { | |||||||
| 		select { | 		select { | ||||||
| 		case t := <-ccb.stateChangeQueue.get(): | 		case t := <-ccb.stateChangeQueue.get(): | ||||||
| 			ccb.stateChangeQueue.load() | 			ccb.stateChangeQueue.load() | ||||||
|  | 			select { | ||||||
|  | 			case <-ccb.done: | ||||||
|  | 				ccb.balancer.Close() | ||||||
|  | 				return | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
| 			ccb.balancer.HandleSubConnStateChange(t.sc, t.state) | 			ccb.balancer.HandleSubConnStateChange(t.sc, t.state) | ||||||
| 		case t := <-ccb.resolverUpdateCh: | 		case t := <-ccb.resolverUpdateCh: | ||||||
|  | 			select { | ||||||
|  | 			case <-ccb.done: | ||||||
|  | 				ccb.balancer.Close() | ||||||
|  | 				return | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
| 			ccb.balancer.HandleResolvedAddrs(t.addrs, t.err) | 			ccb.balancer.HandleResolvedAddrs(t.addrs, t.err) | ||||||
| 		case <-ccb.done: | 		case <-ccb.done: | ||||||
| 		} | 		} | ||||||
| @ -126,6 +141,9 @@ func (ccb *ccBalancerWrapper) watcher() { | |||||||
| 		select { | 		select { | ||||||
| 		case <-ccb.done: | 		case <-ccb.done: | ||||||
| 			ccb.balancer.Close() | 			ccb.balancer.Close() | ||||||
|  | 			for acbw := range ccb.subConns { | ||||||
|  | 				ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) | ||||||
|  | 			} | ||||||
| 			return | 			return | ||||||
| 		default: | 		default: | ||||||
| 		} | 		} | ||||||
| @ -171,7 +189,10 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	acbw := &acBalancerWrapper{ac: ac} | 	acbw := &acBalancerWrapper{ac: ac} | ||||||
|  | 	acbw.ac.mu.Lock() | ||||||
| 	ac.acbw = acbw | 	ac.acbw = acbw | ||||||
|  | 	acbw.ac.mu.Unlock() | ||||||
|  | 	ccb.subConns[acbw] = struct{}{} | ||||||
| 	return acbw, nil | 	return acbw, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -181,6 +202,7 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { | |||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	delete(ccb.subConns, acbw) | ||||||
| 	ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) | 	ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										133
									
								
								balancer_switching_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								balancer_switching_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,133 @@ | |||||||
|  | /* | ||||||
|  |  * | ||||||
|  |  * Copyright 2017 gRPC authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  * | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | package grpc | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  | 	_ "google.golang.org/grpc/grpclog/glogger" | ||||||
|  | 	"google.golang.org/grpc/resolver" | ||||||
|  | 	"google.golang.org/grpc/resolver/manual" | ||||||
|  | 	"google.golang.org/grpc/test/leakcheck" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func checkPickFirst(cc *ClientConn, servers []*server) error { | ||||||
|  | 	var ( | ||||||
|  | 		req   = "port" | ||||||
|  | 		reply string | ||||||
|  | 		err   error | ||||||
|  | 	) | ||||||
|  | 	connected := false | ||||||
|  | 	for i := 0; i < 1000; i++ { | ||||||
|  | 		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == servers[0].port { | ||||||
|  | 			if connected { | ||||||
|  | 				// connected is set to false if peer is not server[0]. So if | ||||||
|  | 				// connected is true here, this is the second time we saw | ||||||
|  | 				// server[0] in a row. Break because pickfirst is in effect. | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			connected = true | ||||||
|  | 		} else { | ||||||
|  | 			connected = false | ||||||
|  | 		} | ||||||
|  | 		time.Sleep(time.Millisecond) | ||||||
|  | 	} | ||||||
|  | 	if !connected { | ||||||
|  | 		return fmt.Errorf("pickfirst is not in effect after 1 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port) | ||||||
|  | 	} | ||||||
|  | 	// The following RPCs should all succeed with the first server. | ||||||
|  | 	for i := 0; i < 3; i++ { | ||||||
|  | 		err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) | ||||||
|  | 		if ErrorDesc(err) != servers[0].port { | ||||||
|  | 			return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func checkRoundRobin(cc *ClientConn, servers []*server) error { | ||||||
|  | 	var ( | ||||||
|  | 		req   = "port" | ||||||
|  | 		reply string | ||||||
|  | 		err   error | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// Make sure connections to all servers are up. | ||||||
|  | 	for i := 0; i < 2; i++ { | ||||||
|  | 		// Do this check twice, otherwise the first RPC's transport may still be | ||||||
|  | 		// picked by the closing pickfirst balancer, and the test becomes flaky. | ||||||
|  | 		for _, s := range servers { | ||||||
|  | 			var up bool | ||||||
|  | 			for i := 0; i < 1000; i++ { | ||||||
|  | 				if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == s.port { | ||||||
|  | 					up = true | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
|  | 				time.Sleep(time.Millisecond) | ||||||
|  | 			} | ||||||
|  | 			if !up { | ||||||
|  | 				return fmt.Errorf("server %v is not up within 1 second", s.port) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	serverCount := len(servers) | ||||||
|  | 	for i := 0; i < 3*serverCount; i++ { | ||||||
|  | 		err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) | ||||||
|  | 		if ErrorDesc(err) != servers[i%serverCount].port { | ||||||
|  | 			return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestSwitchBalancer(t *testing.T) { | ||||||
|  | 	defer leakcheck.Check(t) | ||||||
|  | 	r, rcleanup := manual.GenerateAndRegisterManualResolver() | ||||||
|  | 	defer rcleanup() | ||||||
|  |  | ||||||
|  | 	numServers := 2 | ||||||
|  | 	servers, _, scleanup := startServers(t, numServers, math.MaxInt32) | ||||||
|  | 	defer scleanup() | ||||||
|  |  | ||||||
|  | 	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to dial: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer cc.Close() | ||||||
|  | 	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) | ||||||
|  | 	// The default balancer is pickfirst. | ||||||
|  | 	if err := checkPickFirst(cc, servers); err != nil { | ||||||
|  | 		t.Fatalf("check pickfirst returned non-nil error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	// Switch to roundrobin. | ||||||
|  | 	cc.switchBalancer("roundrobin") | ||||||
|  | 	if err := checkRoundRobin(cc, servers); err != nil { | ||||||
|  | 		t.Fatalf("check roundrobin returned non-nil error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	// Switch to pickfirst. | ||||||
|  | 	cc.switchBalancer("pickfirst") | ||||||
|  | 	if err := checkPickFirst(cc, servers); err != nil { | ||||||
|  | 		t.Fatalf("check pickfirst returned non-nil error: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -31,6 +31,10 @@ import ( | |||||||
| 	_ "google.golang.org/grpc/grpclog/glogger" | 	_ "google.golang.org/grpc/grpclog/glogger" | ||||||
| 	"google.golang.org/grpc/naming" | 	"google.golang.org/grpc/naming" | ||||||
| 	"google.golang.org/grpc/test/leakcheck" | 	"google.golang.org/grpc/test/leakcheck" | ||||||
|  |  | ||||||
|  | 	// V1 balancer tests use passthrough resolver instead of dns. | ||||||
|  | 	// TODO(bar) remove this when removing v1 balaner entirely. | ||||||
|  | 	_ "google.golang.org/grpc/resolver/passthrough" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type testWatcher struct { | type testWatcher struct { | ||||||
| @ -117,7 +121,7 @@ func TestNameDiscovery(t *testing.T) { | |||||||
| 	numServers := 2 | 	numServers := 2 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -151,7 +155,7 @@ func TestEmptyAddrs(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -185,7 +189,7 @@ func TestRoundRobin(t *testing.T) { | |||||||
| 	numServers := 3 | 	numServers := 3 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -230,7 +234,7 @@ func TestCloseWithPendingRPC(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -282,7 +286,7 @@ func TestGetOnWaitChannel(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -328,7 +332,7 @@ func TestOneServerDown(t *testing.T) { | |||||||
| 	numServers := 2 | 	numServers := 2 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -381,7 +385,7 @@ func TestOneAddressRemoval(t *testing.T) { | |||||||
| 	numServers := 2 | 	numServers := 2 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -439,7 +443,7 @@ func TestOneAddressRemoval(t *testing.T) { | |||||||
| func checkServerUp(t *testing.T, currentServer *server) { | func checkServerUp(t *testing.T, currentServer *server) { | ||||||
| 	req := "port" | 	req := "port" | ||||||
| 	port := currentServer.port | 	port := currentServer.port | ||||||
| 	cc, err := Dial("localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -457,7 +461,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -489,7 +493,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | 	servers, r, cleanup := startServers(t, 1, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -543,7 +547,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { | |||||||
| 	numServers := 3 | 	numServers := 3 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -656,7 +660,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { | |||||||
| 	numServers := 3 | 	numServers := 3 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
| @ -747,7 +751,7 @@ func TestPickFirstOneAddressRemoval(t *testing.T) { | |||||||
| 	numServers := 2 | 	numServers := 2 | ||||||
| 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | 	cc, err := Dial("passthrough:///localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create ClientConn: %v", err) | 		t.Fatalf("Failed to create ClientConn: %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ | |||||||
| package grpc | package grpc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| @ -34,7 +35,13 @@ type balancerWrapperBuilder struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { | func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { | ||||||
| 	bwb.b.Start(cc.Target(), BalancerConfig{ | 	targetAddr := cc.Target() | ||||||
|  | 	targetSplitted := strings.Split(targetAddr, ":///") | ||||||
|  | 	if len(targetSplitted) >= 2 { | ||||||
|  | 		targetAddr = targetSplitted[1] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	bwb.b.Start(targetAddr, BalancerConfig{ | ||||||
| 		DialCreds: opts.DialCreds, | 		DialCreds: opts.DialCreds, | ||||||
| 		Dialer:    opts.Dialer, | 		Dialer:    opts.Dialer, | ||||||
| 	}) | 	}) | ||||||
| @ -43,6 +50,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B | |||||||
| 		balancer:   bwb.b, | 		balancer:   bwb.b, | ||||||
| 		pickfirst:  pickfirst, | 		pickfirst:  pickfirst, | ||||||
| 		cc:         cc, | 		cc:         cc, | ||||||
|  | 		targetAddr: targetAddr, | ||||||
| 		startCh:    make(chan struct{}), | 		startCh:    make(chan struct{}), | ||||||
| 		conns:      make(map[resolver.Address]balancer.SubConn), | 		conns:      make(map[resolver.Address]balancer.SubConn), | ||||||
| 		connSt:     make(map[balancer.SubConn]*scState), | 		connSt:     make(map[balancer.SubConn]*scState), | ||||||
| @ -69,6 +77,7 @@ type balancerWrapper struct { | |||||||
| 	pickfirst bool | 	pickfirst bool | ||||||
|  |  | ||||||
| 	cc         balancer.ClientConn | 	cc         balancer.ClientConn | ||||||
|  | 	targetAddr string // Target without the scheme. | ||||||
|  |  | ||||||
| 	// To aggregate the connectivity state. | 	// To aggregate the connectivity state. | ||||||
| 	csEvltr *connectivityStateEvaluator | 	csEvltr *connectivityStateEvaluator | ||||||
| @ -93,7 +102,7 @@ func (bw *balancerWrapper) lbWatcher() { | |||||||
| 	if notifyCh == nil { | 	if notifyCh == nil { | ||||||
| 		// There's no resolver in the balancer. Connect directly. | 		// There's no resolver in the balancer. Connect directly. | ||||||
| 		a := resolver.Address{ | 		a := resolver.Address{ | ||||||
| 			Addr: bw.cc.Target(), | 			Addr: bw.targetAddr, | ||||||
| 			Type: resolver.Backend, | 			Type: resolver.Backend, | ||||||
| 		} | 		} | ||||||
| 		sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) | 		sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) | ||||||
| @ -103,7 +112,7 @@ func (bw *balancerWrapper) lbWatcher() { | |||||||
| 			bw.mu.Lock() | 			bw.mu.Lock() | ||||||
| 			bw.conns[a] = sc | 			bw.conns[a] = sc | ||||||
| 			bw.connSt[sc] = &scState{ | 			bw.connSt[sc] = &scState{ | ||||||
| 				addr: Address{Addr: bw.cc.Target()}, | 				addr: Address{Addr: bw.targetAddr}, | ||||||
| 				s:    connectivity.Idle, | 				s:    connectivity.Idle, | ||||||
| 			} | 			} | ||||||
| 			bw.mu.Unlock() | 			bw.mu.Unlock() | ||||||
|  | |||||||
							
								
								
									
										208
									
								
								clientconn.go
									
									
									
									
									
								
							
							
						
						
									
										208
									
								
								clientconn.go
									
									
									
									
									
								
							| @ -31,11 +31,13 @@ import ( | |||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"golang.org/x/net/trace" | 	"golang.org/x/net/trace" | ||||||
| 	"google.golang.org/grpc/balancer" | 	"google.golang.org/grpc/balancer" | ||||||
|  | 	_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. | ||||||
| 	"google.golang.org/grpc/connectivity" | 	"google.golang.org/grpc/connectivity" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
| 	"google.golang.org/grpc/grpclog" | 	"google.golang.org/grpc/grpclog" | ||||||
| 	"google.golang.org/grpc/keepalive" | 	"google.golang.org/grpc/keepalive" | ||||||
| 	"google.golang.org/grpc/resolver" | 	"google.golang.org/grpc/resolver" | ||||||
|  | 	_ "google.golang.org/grpc/resolver/dns" // To register dns resolver. | ||||||
| 	"google.golang.org/grpc/stats" | 	"google.golang.org/grpc/stats" | ||||||
| 	"google.golang.org/grpc/transport" | 	"google.golang.org/grpc/transport" | ||||||
| ) | ) | ||||||
| @ -435,42 +437,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * | |||||||
| 		cc.authority = target | 		cc.authority = target | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cc.dopts.balancerBuilder != nil { |  | ||||||
| 		var credsClone credentials.TransportCredentials |  | ||||||
| 		if creds != nil { |  | ||||||
| 			credsClone = creds.Clone() |  | ||||||
| 		} |  | ||||||
| 		buildOpts := balancer.BuildOptions{ |  | ||||||
| 			DialCreds: credsClone, |  | ||||||
| 			Dialer:    cc.dopts.copts.Dialer, |  | ||||||
| 		} |  | ||||||
| 		// Build should not take long time. So it's ok to not have a goroutine for it. |  | ||||||
| 		// TODO(bar) init balancer after first resolver result to support service config balancer. |  | ||||||
| 		cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, buildOpts) |  | ||||||
| 	} else { |  | ||||||
| 		waitC := make(chan error, 1) |  | ||||||
| 		go func() { |  | ||||||
| 			defer close(waitC) |  | ||||||
| 			// No balancer, or no resolver within the balancer.  Connect directly. |  | ||||||
| 			ac, err := cc.newAddrConn([]resolver.Address{{Addr: target}}) |  | ||||||
| 			if err != nil { |  | ||||||
| 				waitC <- err |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			if err := ac.connect(cc.dopts.block); err != nil { |  | ||||||
| 				waitC <- err |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 		select { |  | ||||||
| 		case <-ctx.Done(): |  | ||||||
| 			return nil, ctx.Err() |  | ||||||
| 		case err := <-waitC: |  | ||||||
| 			if err != nil { |  | ||||||
| 				return nil, err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if cc.dopts.scChan != nil && !scSet { | 	if cc.dopts.scChan != nil && !scSet { | ||||||
| 		// Blocking wait for the initial service config. | 		// Blocking wait for the initial service config. | ||||||
| 		select { | 		select { | ||||||
| @ -486,20 +452,27 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * | |||||||
| 		go cc.scWatcher() | 		go cc.scWatcher() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var credsClone credentials.TransportCredentials | ||||||
|  | 	if creds := cc.dopts.copts.TransportCredentials; creds != nil { | ||||||
|  | 		credsClone = creds.Clone() | ||||||
|  | 	} | ||||||
|  | 	cc.balancerBuildOpts = balancer.BuildOptions{ | ||||||
|  | 		DialCreds: credsClone, | ||||||
|  | 		Dialer:    cc.dopts.copts.Dialer, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if cc.dopts.balancerBuilder != nil { | ||||||
|  | 		cc.customBalancer = true | ||||||
|  | 		// Build should not take long time. So it's ok to not have a goroutine for it. | ||||||
|  | 		cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Build the resolver. | 	// Build the resolver. | ||||||
| 	cc.resolverWrapper, err = newCCResolverWrapper(cc) | 	cc.resolverWrapper, err = newCCResolverWrapper(cc) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to build resolver: %v", err) | 		return nil, fmt.Errorf("failed to build resolver: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cc.balancerWrapper != nil && cc.resolverWrapper == nil { |  | ||||||
| 		// TODO(bar) there should always be a resolver (DNS as the default). |  | ||||||
| 		// Unblock balancer initialization with a fake resolver update if there's no resolver. |  | ||||||
| 		// The balancer wrapper will not read the addresses, so an empty list works. |  | ||||||
| 		// TODO(bar) remove this after the real resolver is started. |  | ||||||
| 		cc.balancerWrapper.handleResolvedAddrs([]resolver.Address{}, nil) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// A blocking dial blocks until the clientConn is ready. | 	// A blocking dial blocks until the clientConn is ready. | ||||||
| 	if cc.dopts.block { | 	if cc.dopts.block { | ||||||
| 		for { | 		for { | ||||||
| @ -570,9 +543,9 @@ type ClientConn struct { | |||||||
| 	dopts     dialOptions | 	dopts     dialOptions | ||||||
| 	csMgr     *connectivityStateManager | 	csMgr     *connectivityStateManager | ||||||
|  |  | ||||||
| 	balancerWrapper *ccBalancerWrapper | 	customBalancer    bool // If this is true, switching balancer will be disabled. | ||||||
|  | 	balancerBuildOpts balancer.BuildOptions | ||||||
| 	resolverWrapper   *ccResolverWrapper | 	resolverWrapper   *ccResolverWrapper | ||||||
|  |  | ||||||
| 	blockingpicker    *pickerWrapper | 	blockingpicker    *pickerWrapper | ||||||
|  |  | ||||||
| 	mu    sync.RWMutex | 	mu    sync.RWMutex | ||||||
| @ -580,6 +553,9 @@ type ClientConn struct { | |||||||
| 	conns map[*addrConn]struct{} | 	conns map[*addrConn]struct{} | ||||||
| 	// Keepalive parameter can be updated if a GoAway is received. | 	// Keepalive parameter can be updated if a GoAway is received. | ||||||
| 	mkp             keepalive.ClientParameters | 	mkp             keepalive.ClientParameters | ||||||
|  | 	curBalancerName string | ||||||
|  | 	curAddresses    []resolver.Address | ||||||
|  | 	balancerWrapper *ccBalancerWrapper | ||||||
| } | } | ||||||
|  |  | ||||||
| // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or | // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or | ||||||
| @ -622,6 +598,71 @@ func (cc *ClientConn) scWatcher() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { | ||||||
|  | 	cc.mu.Lock() | ||||||
|  | 	defer cc.mu.Unlock() | ||||||
|  | 	if cc.conns == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// TODO(bar switching) when grpclb is submitted, check address type and start grpclb. | ||||||
|  | 	if !cc.customBalancer && cc.balancerWrapper == nil { | ||||||
|  | 		// No customBalancer was specified by DialOption, and this is the first | ||||||
|  | 		// time handling resolved addresses, create a pickfirst balancer. | ||||||
|  | 		builder := newPickfirstBuilder() | ||||||
|  | 		cc.curBalancerName = builder.Name() | ||||||
|  | 		cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// TODO(bar switching) compare addresses, if there's no update, don't notify balancer. | ||||||
|  | 	cc.curAddresses = addrs | ||||||
|  | 	cc.balancerWrapper.handleResolvedAddrs(addrs, nil) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // switchBalancer starts the switching from current balancer to the balancer with name. | ||||||
|  | func (cc *ClientConn) switchBalancer(name string) { | ||||||
|  | 	cc.mu.Lock() | ||||||
|  | 	defer cc.mu.Unlock() | ||||||
|  | 	if cc.conns == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	grpclog.Infof("ClientConn switching balancer to %q", name) | ||||||
|  |  | ||||||
|  | 	if cc.customBalancer { | ||||||
|  | 		grpclog.Infoln("ignoring service config balancer configuration: WithBalancer DialOption used instead") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if cc.curBalancerName == name { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// TODO(bar switching) change this to two steps: drain and close. | ||||||
|  | 	// Keep track of sc in wrapper. | ||||||
|  | 	cc.balancerWrapper.close() | ||||||
|  |  | ||||||
|  | 	builder := balancer.Get(name) | ||||||
|  | 	if builder == nil { | ||||||
|  | 		grpclog.Infof("failed to get balancer builder for: %v (this should never happen...)", name) | ||||||
|  | 		builder = newPickfirstBuilder() | ||||||
|  | 	} | ||||||
|  | 	cc.curBalancerName = builder.Name() | ||||||
|  | 	cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) | ||||||
|  | 	cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { | ||||||
|  | 	cc.mu.Lock() | ||||||
|  | 	if cc.conns == nil { | ||||||
|  | 		cc.mu.Unlock() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	// TODO(bar switching) send updates to all balancer wrappers when balancer | ||||||
|  | 	// gracefully switching is supported. | ||||||
|  | 	cc.balancerWrapper.handleSubConnStateChange(sc, s) | ||||||
|  | 	cc.mu.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
| // newAddrConn creates an addrConn for addrs and adds it to cc.conns. | // newAddrConn creates an addrConn for addrs and adds it to cc.conns. | ||||||
| func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { | func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { | ||||||
| 	ac := &addrConn{ | 	ac := &addrConn{ | ||||||
| @ -670,11 +711,7 @@ func (ac *addrConn) connect(block bool) error { | |||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	ac.state = connectivity.Connecting | 	ac.state = connectivity.Connecting | ||||||
| 	if ac.cc.balancerWrapper != nil { | 	ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 	} else { |  | ||||||
| 		ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 	} |  | ||||||
| 	ac.mu.Unlock() | 	ac.mu.Unlock() | ||||||
|  |  | ||||||
| 	if block { | 	if block { | ||||||
| @ -756,31 +793,6 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) { | func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) { | ||||||
| 	if cc.balancerWrapper == nil { |  | ||||||
| 		// If balancer is nil, there should be only one addrConn available. |  | ||||||
| 		cc.mu.RLock() |  | ||||||
| 		if cc.conns == nil { |  | ||||||
| 			cc.mu.RUnlock() |  | ||||||
| 			// TODO this function returns toRPCErr and non-toRPCErr. Clean up |  | ||||||
| 			// the errors in ClientConn. |  | ||||||
| 			return nil, nil, toRPCErr(ErrClientConnClosing) |  | ||||||
| 		} |  | ||||||
| 		var ac *addrConn |  | ||||||
| 		for ac = range cc.conns { |  | ||||||
| 			// Break after the first iteration to get the first addrConn. |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 		cc.mu.RUnlock() |  | ||||||
| 		if ac == nil { |  | ||||||
| 			return nil, nil, errConnClosing |  | ||||||
| 		} |  | ||||||
| 		t, err := ac.wait(ctx, false /*hasBalancer*/, failfast) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, nil, err |  | ||||||
| 		} |  | ||||||
| 		return t, nil, nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{}) | 	t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, toRPCErr(err) | 		return nil, nil, toRPCErr(err) | ||||||
| @ -800,13 +812,18 @@ func (cc *ClientConn) Close() error { | |||||||
| 	conns := cc.conns | 	conns := cc.conns | ||||||
| 	cc.conns = nil | 	cc.conns = nil | ||||||
| 	cc.csMgr.updateState(connectivity.Shutdown) | 	cc.csMgr.updateState(connectivity.Shutdown) | ||||||
|  |  | ||||||
|  | 	rWrapper := cc.resolverWrapper | ||||||
|  | 	cc.resolverWrapper = nil | ||||||
|  | 	bWrapper := cc.balancerWrapper | ||||||
|  | 	cc.balancerWrapper = nil | ||||||
| 	cc.mu.Unlock() | 	cc.mu.Unlock() | ||||||
| 	cc.blockingpicker.close() | 	cc.blockingpicker.close() | ||||||
| 	if cc.resolverWrapper != nil { | 	if rWrapper != nil { | ||||||
| 		cc.resolverWrapper.close() | 		rWrapper.close() | ||||||
| 	} | 	} | ||||||
| 	if cc.balancerWrapper != nil { | 	if bWrapper != nil { | ||||||
| 		cc.balancerWrapper.close() | 		bWrapper.close() | ||||||
| 	} | 	} | ||||||
| 	for ac := range conns { | 	for ac := range conns { | ||||||
| 		ac.tearDown(ErrClientConnClosing) | 		ac.tearDown(ErrClientConnClosing) | ||||||
| @ -877,11 +894,7 @@ func (ac *addrConn) resetTransport() error { | |||||||
| 		return errConnClosing | 		return errConnClosing | ||||||
| 	} | 	} | ||||||
| 	ac.state = connectivity.TransientFailure | 	ac.state = connectivity.TransientFailure | ||||||
| 	if ac.cc.balancerWrapper != nil { | 	ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 	} else { |  | ||||||
| 		ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 	} |  | ||||||
| 	if ac.ready != nil { | 	if ac.ready != nil { | ||||||
| 		close(ac.ready) | 		close(ac.ready) | ||||||
| 		ac.ready = nil | 		ac.ready = nil | ||||||
| @ -906,12 +919,7 @@ func (ac *addrConn) resetTransport() error { | |||||||
| 		} | 		} | ||||||
| 		ac.printf("connecting") | 		ac.printf("connecting") | ||||||
| 		ac.state = connectivity.Connecting | 		ac.state = connectivity.Connecting | ||||||
| 		// TODO(bar) remove condition once we always have a balancer. | 		ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 		if ac.cc.balancerWrapper != nil { |  | ||||||
| 			ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 		} else { |  | ||||||
| 			ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 		} |  | ||||||
| 		// copy ac.addrs in case of race | 		// copy ac.addrs in case of race | ||||||
| 		addrsIter := make([]resolver.Address, len(ac.addrs)) | 		addrsIter := make([]resolver.Address, len(ac.addrs)) | ||||||
| 		copy(addrsIter, ac.addrs) | 		copy(addrsIter, ac.addrs) | ||||||
| @ -953,11 +961,7 @@ func (ac *addrConn) resetTransport() error { | |||||||
| 				return errConnClosing | 				return errConnClosing | ||||||
| 			} | 			} | ||||||
| 			ac.state = connectivity.Ready | 			ac.state = connectivity.Ready | ||||||
| 			if ac.cc.balancerWrapper != nil { | 			ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 				ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 			} else { |  | ||||||
| 				ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 			} |  | ||||||
| 			t := ac.transport | 			t := ac.transport | ||||||
| 			ac.transport = newTransport | 			ac.transport = newTransport | ||||||
| 			if t != nil { | 			if t != nil { | ||||||
| @ -973,11 +977,7 @@ func (ac *addrConn) resetTransport() error { | |||||||
| 		} | 		} | ||||||
| 		ac.mu.Lock() | 		ac.mu.Lock() | ||||||
| 		ac.state = connectivity.TransientFailure | 		ac.state = connectivity.TransientFailure | ||||||
| 		if ac.cc.balancerWrapper != nil { | 		ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 			ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 		} else { |  | ||||||
| 			ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 		} |  | ||||||
| 		if ac.ready != nil { | 		if ac.ready != nil { | ||||||
| 			close(ac.ready) | 			close(ac.ready) | ||||||
| 			ac.ready = nil | 			ac.ready = nil | ||||||
| @ -1111,11 +1111,7 @@ func (ac *addrConn) tearDown(err error) { | |||||||
| 	} | 	} | ||||||
| 	ac.state = connectivity.Shutdown | 	ac.state = connectivity.Shutdown | ||||||
| 	ac.tearDownErr = err | 	ac.tearDownErr = err | ||||||
| 	if ac.cc.balancerWrapper != nil { | 	ac.cc.handleSubConnStateChange(ac.acbw, ac.state) | ||||||
| 		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) |  | ||||||
| 	} else { |  | ||||||
| 		ac.cc.csMgr.updateState(ac.state) |  | ||||||
| 	} |  | ||||||
| 	if ac.events != nil { | 	if ac.events != nil { | ||||||
| 		ac.events.Finish() | 		ac.events.Finish() | ||||||
| 		ac.events = nil | 		ac.events = nil | ||||||
|  | |||||||
| @ -30,6 +30,7 @@ import ( | |||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
| 	"google.golang.org/grpc/keepalive" | 	"google.golang.org/grpc/keepalive" | ||||||
| 	"google.golang.org/grpc/naming" | 	"google.golang.org/grpc/naming" | ||||||
|  | 	_ "google.golang.org/grpc/resolver/passthrough" | ||||||
| 	"google.golang.org/grpc/test/leakcheck" | 	"google.golang.org/grpc/test/leakcheck" | ||||||
| 	"google.golang.org/grpc/testdata" | 	"google.golang.org/grpc/testdata" | ||||||
| ) | ) | ||||||
| @ -47,7 +48,7 @@ func TestConnectivityStates(t *testing.T) { | |||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	servers, resolver, cleanup := startServers(t, 2, math.MaxUint32) | 	servers, resolver, cleanup := startServers(t, 2, math.MaxUint32) | ||||||
| 	defer cleanup() | 	defer cleanup() | ||||||
| 	cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure()) | 	cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Dial(\"foo.bar.com\", WithBalancer(_)) = _, %v, want _ <nil>", err) | 		t.Fatalf("Dial(\"foo.bar.com\", WithBalancer(_)) = _, %v, want _ <nil>", err) | ||||||
| 	} | 	} | ||||||
| @ -82,7 +83,7 @@ func TestConnectivityStates(t *testing.T) { | |||||||
|  |  | ||||||
| func TestDialTimeout(t *testing.T) { | func TestDialTimeout(t *testing.T) { | ||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure()) | 	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure()) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		conn.Close() | 		conn.Close() | ||||||
| 	} | 	} | ||||||
| @ -97,7 +98,7 @@ func TestTLSDialTimeout(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create credentials %v", err) | 		t.Fatalf("Failed to create credentials %v", err) | ||||||
| 	} | 	} | ||||||
| 	conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock()) | 	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock()) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		conn.Close() | 		conn.Close() | ||||||
| 	} | 	} | ||||||
| @ -113,7 +114,7 @@ func TestDefaultAuthority(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | ||||||
| 	} | 	} | ||||||
| 	conn.Close() | 	defer conn.Close() | ||||||
| 	if conn.authority != target { | 	if conn.authority != target { | ||||||
| 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target) | 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target) | ||||||
| 	} | 	} | ||||||
| @ -126,11 +127,11 @@ func TestTLSServerNameOverwrite(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create credentials %v", err) | 		t.Fatalf("Failed to create credentials %v", err) | ||||||
| 	} | 	} | ||||||
| 	conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds)) | 	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | ||||||
| 	} | 	} | ||||||
| 	conn.Close() | 	defer conn.Close() | ||||||
| 	if conn.authority != overwriteServerName { | 	if conn.authority != overwriteServerName { | ||||||
| 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | ||||||
| 	} | 	} | ||||||
| @ -139,11 +140,11 @@ func TestTLSServerNameOverwrite(t *testing.T) { | |||||||
| func TestWithAuthority(t *testing.T) { | func TestWithAuthority(t *testing.T) { | ||||||
| 	defer leakcheck.Check(t) | 	defer leakcheck.Check(t) | ||||||
| 	overwriteServerName := "over.write.server.name" | 	overwriteServerName := "over.write.server.name" | ||||||
| 	conn, err := Dial("Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) | 	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | ||||||
| 	} | 	} | ||||||
| 	conn.Close() | 	defer conn.Close() | ||||||
| 	if conn.authority != overwriteServerName { | 	if conn.authority != overwriteServerName { | ||||||
| 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | ||||||
| 	} | 	} | ||||||
| @ -156,11 +157,11 @@ func TestWithAuthorityAndTLS(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to create credentials %v", err) | 		t.Fatalf("Failed to create credentials %v", err) | ||||||
| 	} | 	} | ||||||
| 	conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) | 	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | 		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err) | ||||||
| 	} | 	} | ||||||
| 	conn.Close() | 	defer conn.Close() | ||||||
| 	if conn.authority != overwriteServerName { | 	if conn.authority != overwriteServerName { | ||||||
| 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | 		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) | ||||||
| 	} | 	} | ||||||
| @ -231,11 +232,11 @@ func TestCredentialsMisuse(t *testing.T) { | |||||||
| 		t.Fatalf("Failed to create authenticator %v", err) | 		t.Fatalf("Failed to create authenticator %v", err) | ||||||
| 	} | 	} | ||||||
| 	// Two conflicting credential configurations | 	// Two conflicting credential configurations | ||||||
| 	if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict { | 	if _, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) | 		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) | ||||||
| 	} | 	} | ||||||
| 	// security info on insecure connection | 	// security info on insecure connection | ||||||
| 	if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { | 	if _, err := Dial("passthrough:///Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { | ||||||
| 		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) | 		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -263,10 +264,11 @@ func TestWithBackoffMaxDelay(t *testing.T) { | |||||||
|  |  | ||||||
| func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOption) { | func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOption) { | ||||||
| 	opts = append(opts, WithInsecure()) | 	opts = append(opts, WithInsecure()) | ||||||
| 	conn, err := Dial("foo:80", opts...) | 	conn, err := Dial("passthrough:///foo:80", opts...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error dialing connection: %v", err) | 		t.Fatalf("unexpected error dialing connection: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
| 	if conn.dopts.bs == nil { | 	if conn.dopts.bs == nil { | ||||||
| 		t.Fatalf("backoff config not set") | 		t.Fatalf("backoff config not set") | ||||||
| @ -280,39 +282,6 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt | |||||||
| 	if actual != *expected { | 	if actual != *expected { | ||||||
| 		t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected) | 		t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected) | ||||||
| 	} | 	} | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type testErr struct { |  | ||||||
| 	temp bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (e *testErr) Error() string { |  | ||||||
| 	return "test error" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (e *testErr) Temporary() bool { |  | ||||||
| 	return e.temp |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var nonTemporaryError = &testErr{false} |  | ||||||
|  |  | ||||||
| func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, error) { |  | ||||||
| 	return nil, nonTemporaryError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) { |  | ||||||
| 	defer leakcheck.Check(t) |  | ||||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) |  | ||||||
| 	defer cancel() |  | ||||||
| 	if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock(), FailOnNonTempDialError(true)); err != nonTemporaryError { |  | ||||||
| 		t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Without FailOnNonTempDialError, gRPC will retry to connect, and dial should exit with time out error. |  | ||||||
| 	if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock()); err != context.DeadlineExceeded { |  | ||||||
| 		t.Fatalf("Dial(%q) = %v, want %v", "", err, context.DeadlineExceeded) |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // emptyBalancer returns an empty set of servers. | // emptyBalancer returns an empty set of servers. | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								pickfirst.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								pickfirst.go
									
									
									
									
									
								
							| @ -57,14 +57,20 @@ func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err er | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc}) | 		b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc}) | ||||||
|  | 		b.sc.Connect() | ||||||
| 	} else { | 	} else { | ||||||
| 		b.sc.UpdateAddresses(addrs) | 		b.sc.UpdateAddresses(addrs) | ||||||
|  | 		b.sc.Connect() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { | func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { | ||||||
| 	grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s) | 	grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s) | ||||||
| 	if b.sc != sc || s == connectivity.Shutdown { | 	if b.sc != sc { | ||||||
|  | 		grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if s == connectivity.Shutdown { | ||||||
| 		b.sc = nil | 		b.sc = nil | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @ -93,3 +99,7 @@ func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer. | |||||||
| 	} | 	} | ||||||
| 	return p.sc, nil, nil | 	return p.sc, nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	balancer.Register(newPickfirstBuilder()) | ||||||
|  | } | ||||||
|  | |||||||
| @ -24,7 +24,7 @@ var ( | |||||||
| 	// m is a map from scheme to resolver builder. | 	// m is a map from scheme to resolver builder. | ||||||
| 	m = make(map[string]Builder) | 	m = make(map[string]Builder) | ||||||
| 	// defaultScheme is the default scheme to use. | 	// defaultScheme is the default scheme to use. | ||||||
| 	defaultScheme string | 	defaultScheme = "dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // TODO(bar) install dns resolver in init(){}. | // TODO(bar) install dns resolver in init(){}. | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ | |||||||
| package grpc | package grpc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"google.golang.org/grpc/grpclog" | 	"google.golang.org/grpc/grpclog" | ||||||
| @ -56,19 +57,13 @@ func parseTarget(target string) (ret resolver.Target) { | |||||||
| // newCCResolverWrapper parses cc.target for scheme and gets the resolver | // newCCResolverWrapper parses cc.target for scheme and gets the resolver | ||||||
| // builder for this scheme. It then builds the resolver and starts the | // builder for this scheme. It then builds the resolver and starts the | ||||||
| // monitoring goroutine for it. | // monitoring goroutine for it. | ||||||
| // |  | ||||||
| // This function could return nil, nil, in tests for old behaviors. |  | ||||||
| // TODO(bar) never return nil, nil when DNS becomes the default resolver. |  | ||||||
| func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { | func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { | ||||||
| 	target := parseTarget(cc.target) | 	target := parseTarget(cc.target) | ||||||
| 	grpclog.Infof("dialing to target with scheme: %q", target.Scheme) | 	grpclog.Infof("dialing to target with scheme: %q", target.Scheme) | ||||||
|  |  | ||||||
| 	rb := resolver.Get(target.Scheme) | 	rb := resolver.Get(target.Scheme) | ||||||
| 	if rb == nil { | 	if rb == nil { | ||||||
| 		// TODO(bar) return error when DNS becomes the default (implemented and | 		return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme) | ||||||
| 		// registered by DNS package). |  | ||||||
| 		grpclog.Infof("could not get resolver for scheme: %q", target.Scheme) |  | ||||||
| 		return nil, nil |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ccr := &ccResolverWrapper{ | 	ccr := &ccResolverWrapper{ | ||||||
| @ -100,13 +95,19 @@ func (ccr *ccResolverWrapper) watcher() { | |||||||
|  |  | ||||||
| 		select { | 		select { | ||||||
| 		case addrs := <-ccr.addrCh: | 		case addrs := <-ccr.addrCh: | ||||||
| 			grpclog.Infof("ccResolverWrapper: sending new addresses to balancer wrapper: %v", addrs) | 			select { | ||||||
| 			// TODO(bar switching) this should never be nil. Pickfirst should be default. | 			case <-ccr.done: | ||||||
| 			if ccr.cc.balancerWrapper != nil { | 				return | ||||||
| 				// TODO(bar switching) create balancer if it's nil? | 			default: | ||||||
| 				ccr.cc.balancerWrapper.handleResolvedAddrs(addrs, nil) |  | ||||||
| 			} | 			} | ||||||
|  | 			grpclog.Infof("ccResolverWrapper: sending new addresses to cc: %v", addrs) | ||||||
|  | 			ccr.cc.handleResolvedAddrs(addrs, nil) | ||||||
| 		case sc := <-ccr.scCh: | 		case sc := <-ccr.scCh: | ||||||
|  | 			select { | ||||||
|  | 			case <-ccr.done: | ||||||
|  | 				return | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
| 			grpclog.Infof("ccResolverWrapper: got new service config: %v", sc) | 			grpclog.Infof("ccResolverWrapper: got new service config: %v", sc) | ||||||
| 		case <-ccr.done: | 		case <-ccr.done: | ||||||
| 			return | 			return | ||||||
|  | |||||||
| @ -559,8 +559,6 @@ func (te *test) startServer(ts testpb.TestServiceServer) { | |||||||
| 			te.t.Fatalf("Failed to generate credentials %v", err) | 			te.t.Fatalf("Failed to generate credentials %v", err) | ||||||
| 		} | 		} | ||||||
| 		sopts = append(sopts, grpc.Creds(creds)) | 		sopts = append(sopts, grpc.Creds(creds)) | ||||||
| 	case "clientAlwaysFailCred": |  | ||||||
| 		sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) |  | ||||||
| 	case "clientTimeoutCreds": | 	case "clientTimeoutCreds": | ||||||
| 		sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) | 		sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) | ||||||
| 	} | 	} | ||||||
| @ -634,15 +632,13 @@ func (te *test) clientConn() *grpc.ClientConn { | |||||||
| 			te.t.Fatalf("Failed to load credentials: %v", err) | 			te.t.Fatalf("Failed to load credentials: %v", err) | ||||||
| 		} | 		} | ||||||
| 		opts = append(opts, grpc.WithTransportCredentials(creds)) | 		opts = append(opts, grpc.WithTransportCredentials(creds)) | ||||||
| 	case "clientAlwaysFailCred": |  | ||||||
| 		opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{})) |  | ||||||
| 	case "clientTimeoutCreds": | 	case "clientTimeoutCreds": | ||||||
| 		opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) | 		opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) | ||||||
| 	default: | 	default: | ||||||
| 		opts = append(opts, grpc.WithInsecure()) | 		opts = append(opts, grpc.WithInsecure()) | ||||||
| 	} | 	} | ||||||
| 	// TODO(bar) switch balancer case "pickfirst". | 	// TODO(bar) switch balancer case "pickfirst". | ||||||
| 	var scheme string | 	scheme := "passthrough:///" | ||||||
| 	switch te.e.balancer { | 	switch te.e.balancer { | ||||||
| 	case "v1": | 	case "v1": | ||||||
| 		opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil))) | 		opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil))) | ||||||
| @ -652,7 +648,6 @@ func (te *test) clientConn() *grpc.ClientConn { | |||||||
| 			te.t.Fatalf("got nil when trying to get roundrobin balancer builder") | 			te.t.Fatalf("got nil when trying to get roundrobin balancer builder") | ||||||
| 		} | 		} | ||||||
| 		opts = append(opts, grpc.WithBalancerBuilder(rr)) | 		opts = append(opts, grpc.WithBalancerBuilder(rr)) | ||||||
| 		scheme = "passthrough:///" |  | ||||||
| 	} | 	} | ||||||
| 	if te.clientInitialWindowSize > 0 { | 	if te.clientInitialWindowSize > 0 { | ||||||
| 		opts = append(opts, grpc.WithInitialWindowSize(te.clientInitialWindowSize)) | 		opts = append(opts, grpc.WithInitialWindowSize(te.clientInitialWindowSize)) | ||||||
| @ -670,6 +665,9 @@ func (te *test) clientConn() *grpc.ClientConn { | |||||||
| 		// Only do a blocking dial if server is up. | 		// Only do a blocking dial if server is up. | ||||||
| 		opts = append(opts, grpc.WithBlock()) | 		opts = append(opts, grpc.WithBlock()) | ||||||
| 	} | 	} | ||||||
|  | 	if te.srvAddr == "" { | ||||||
|  | 		te.srvAddr = "client.side.only.test" | ||||||
|  | 	} | ||||||
| 	var err error | 	var err error | ||||||
| 	te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...) | 	te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -4068,44 +4066,6 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" |  | ||||||
|  |  | ||||||
| var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg) |  | ||||||
|  |  | ||||||
| type clientAlwaysFailCred struct{} |  | ||||||
|  |  | ||||||
| func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { |  | ||||||
| 	return nil, nil, errClientAlwaysFailCred |  | ||||||
| } |  | ||||||
| func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { |  | ||||||
| 	return rawConn, nil, nil |  | ||||||
| } |  | ||||||
| func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { |  | ||||||
| 	return credentials.ProtocolInfo{} |  | ||||||
| } |  | ||||||
| func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| func (c clientAlwaysFailCred) OverrideServerName(s string) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { |  | ||||||
| 	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "v1"}) |  | ||||||
| 	te.startServer(&testServer{security: te.e.security}) |  | ||||||
| 	defer te.tearDown() |  | ||||||
|  |  | ||||||
| 	var ( |  | ||||||
| 		err  error |  | ||||||
| 		opts []grpc.DialOption |  | ||||||
| 	) |  | ||||||
| 	opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) |  | ||||||
| 	te.cc, err = grpc.Dial(te.srvAddr, opts...) |  | ||||||
| 	if err != errClientAlwaysFailCred { |  | ||||||
| 		te.t.Fatalf("Dial(%q) = %v, want %v", te.srvAddr, err, errClientAlwaysFailCred) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type clientTimeoutCreds struct { | type clientTimeoutCreds struct { | ||||||
| 	timeoutReturned bool | 	timeoutReturned bool | ||||||
| } | } | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Menghan Li
					Menghan Li