diff --git a/balancer.go b/balancer.go
index e217a207..9d943fba 100644
--- a/balancer.go
+++ b/balancer.go
@@ -38,6 +38,7 @@ import (
 	"sync"
 
 	"golang.org/x/net/context"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/naming"
@@ -315,7 +316,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
 	if !opts.BlockingWait {
 		if len(rr.addrs) == 0 {
 			rr.mu.Unlock()
-			err = fmt.Errorf("there is no address available")
+			err = Errorf(codes.Unavailable, "there is no address available")
 			return
 		}
 		// Returns the next addr on rr.addrs for failfast RPCs.
diff --git a/clientconn.go b/clientconn.go
index c81e3892..61674729 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -807,7 +807,7 @@ func (ac *addrConn) transportMonitor() {
 }
 
 // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
-// iv) transport is in TransientFailure and there's no balancer/failfast is true.
+// iv) transport is in TransientFailure and there is a balancer/failfast is true.
 func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
 	for {
 		ac.mu.Lock()
diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index 78137453..996d27ae 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -43,6 +43,7 @@ import (
 
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/metadata"
@@ -84,8 +85,10 @@ type remoteBalancerInfo struct {
 
 // addrInfo consists of the information of a backend server.
 type addrInfo struct {
-	addr        grpc.Address
-	connected   bool
+	addr      grpc.Address
+	connected bool
+	// dropRequest indicates whether a particular RPC which chooses this address
+	// should be dropped.
 	dropRequest bool
 }
 
@@ -96,7 +99,7 @@ type balancer struct {
 	w      naming.Watcher
 	addrCh chan []grpc.Address
 	rbs    []remoteBalancerInfo
-	addrs  []addrInfo
+	addrs  []*addrInfo
 	next   int
 	waitCh chan struct{}
 	done   bool
@@ -180,7 +183,7 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
 func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
 	servers := l.GetServers()
 	var (
-		sl    []addrInfo
+		sl    []*addrInfo
 		addrs []grpc.Address
 	)
 	for _, s := range servers {
@@ -190,9 +193,9 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
 			Addr:     fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
 			Metadata: &md,
 		}
-		sl = append(sl, addrInfo{
-			addr: addr,
-			// TODO: Support dropRequest feature.
+		sl = append(sl, &addrInfo{
+			addr:        addr,
+			dropRequest: s.DropRequest,
 		})
 		addrs = append(addrs, addr)
 	}
@@ -306,8 +309,6 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
 				return
 			}
 			// Talk to the remote load balancer to get the server list.
-			//
-			// TODO: override the server name in creds using Metadata in addr.
 			var err error
 			creds := config.DialCreds
 			if creds == nil {
@@ -364,7 +365,7 @@ func (b *balancer) Up(addr grpc.Address) func(error) {
 			}
 			a.connected = true
 		}
-		if a.connected {
+		if a.connected && !a.dropRequest {
 			cnt++
 		}
 	}
@@ -396,10 +397,18 @@ func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr
 			a := b.addrs[next]
 			next = (next + 1) % len(b.addrs)
 			if a.connected {
-				addr = a.addr
-				b.next = next
-				b.mu.Unlock()
-				return
+				if !a.dropRequest {
+					addr = a.addr
+					b.next = next
+					b.mu.Unlock()
+					return
+				}
+				if !opts.BlockingWait {
+					b.next = next
+					b.mu.Unlock()
+					err = grpc.Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
+					return
+				}
 			}
 			if next == b.next {
 				// Has iterated all the possible address but none is connected.
@@ -410,7 +419,7 @@ func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr
 	if !opts.BlockingWait {
 		if len(b.addrs) == 0 {
 			b.mu.Unlock()
-			err = fmt.Errorf("there is no address available")
+			err = grpc.Errorf(codes.Unavailable, "there is no address available")
 			return
 		}
 		// Returns the next addr on b.addrs for a failfast RPC.
@@ -449,10 +458,18 @@ func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr
 				a := b.addrs[next]
 				next = (next + 1) % len(b.addrs)
 				if a.connected {
-					addr = a.addr
-					b.next = next
-					b.mu.Unlock()
-					return
+					if !a.dropRequest {
+						addr = a.addr
+						b.next = next
+						b.mu.Unlock()
+						return
+					}
+					if !opts.BlockingWait {
+						b.next = next
+						b.mu.Unlock()
+						err = grpc.Errorf(codes.Unavailable, "drop requests for the address %s", a.addr.Addr)
+						return
+					}
 				}
 				if next == b.next {
 					// Has iterated all the possible address but none is connected.
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 658e7225..3215beaf 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -293,3 +293,149 @@ func TestGRPCLB(t *testing.T) {
 	}
 	cc.Close()
 }
+
+func TestDropRequest(t *testing.T) {
+	// Start 2 backends.
+	beLis1, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	beAddr1 := strings.Split(beLis1.Addr().String(), ":")
+	bePort1, err := strconv.Atoi(beAddr1[1])
+
+	beLis2, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	beAddr2 := strings.Split(beLis2.Addr().String(), ":")
+	bePort2, err := strconv.Atoi(beAddr2[1])
+
+	backends := startBackends(t, besn, beLis1, beLis2)
+	defer stopBackends(backends)
+
+	// Start a load balancer.
+	lbLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to create the listener for the load balancer %v", err)
+	}
+	lbCreds := &serverNameCheckCreds{
+		sn: lbsn,
+	}
+	lb := grpc.NewServer(grpc.Creds(lbCreds))
+	if err != nil {
+		t.Fatalf("Failed to generate the port number %v", err)
+	}
+	var bes []*lbpb.Server
+	be := &lbpb.Server{
+		IpAddress:        []byte(beAddr1[0]),
+		Port:             int32(bePort1),
+		LoadBalanceToken: lbToken,
+		DropRequest:      true,
+	}
+	bes = append(bes, be)
+	be = &lbpb.Server{
+		IpAddress:        []byte(beAddr2[0]),
+		Port:             int32(bePort2),
+		LoadBalanceToken: lbToken,
+		DropRequest:      false,
+	}
+	bes = append(bes, be)
+	sl := &lbpb.ServerList{
+		Servers: bes,
+	}
+	ls := newRemoteBalancer(sl)
+	lbpb.RegisterLoadBalancerServer(lb, ls)
+	go func() {
+		lb.Serve(lbLis)
+	}()
+	defer func() {
+		ls.stop()
+		lb.Stop()
+	}()
+	creds := serverNameCheckCreds{
+		expected: besn,
+	}
+	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
+		addr: lbLis.Addr().String(),
+	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	if err != nil {
+		t.Fatalf("Failed to dial to the backend %v", err)
+	}
+	// The 1st fail-fast RPC should fail because the 1st backend has DropRequest set to true.
+	helloC := hwpb.NewGreeterClient(cc)
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
+	}
+	// The 2nd fail-fast RPC should succeed since it chooses the non-drop-request backend according
+	// to the round robin policy.
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err)
+	}
+	// The 3rd non-fail-fast RPC should succeed.
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err)
+	}
+	cc.Close()
+}
+
+func TestDropRequestFailedNonFailFast(t *testing.T) {
+	// Start a backend.
+	beLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	beAddr := strings.Split(beLis.Addr().String(), ":")
+	bePort, err := strconv.Atoi(beAddr[1])
+	backends := startBackends(t, besn, beLis)
+	defer stopBackends(backends)
+
+	// Start a load balancer.
+	lbLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to create the listener for the load balancer %v", err)
+	}
+	lbCreds := &serverNameCheckCreds{
+		sn: lbsn,
+	}
+	lb := grpc.NewServer(grpc.Creds(lbCreds))
+	if err != nil {
+		t.Fatalf("Failed to generate the port number %v", err)
+	}
+	be := &lbpb.Server{
+		IpAddress:        []byte(beAddr[0]),
+		Port:             int32(bePort),
+		LoadBalanceToken: lbToken,
+		DropRequest:      true,
+	}
+	var bes []*lbpb.Server
+	bes = append(bes, be)
+	sl := &lbpb.ServerList{
+		Servers: bes,
+	}
+	ls := newRemoteBalancer(sl)
+	lbpb.RegisterLoadBalancerServer(lb, ls)
+	go func() {
+		lb.Serve(lbLis)
+	}()
+	defer func() {
+		ls.stop()
+		lb.Stop()
+	}()
+	creds := serverNameCheckCreds{
+		expected: besn,
+	}
+	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
+		addr: lbLis.Addr().String(),
+	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	if err != nil {
+		t.Fatalf("Failed to dial to the backend %v", err)
+	}
+	helloC := hwpb.NewGreeterClient(cc)
+	ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond)
+	if _, err := helloC.SayHello(ctx, &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.DeadlineExceeded)
+	}
+	cc.Close()
+}
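Reviewer note: below is a minimal, self-contained Go sketch (not the grpclb implementation itself) of the drop-aware round-robin pick that this patch adds to (*balancer).Get. The names serverEntry, pickAddr, and errDropped are hypothetical and exist only for illustration, and, unlike the real balancer, the sketch returns an error for a non-fail-fast RPC instead of blocking on waitCh when no usable address is found.

package main

import (
	"errors"
	"fmt"
)

// serverEntry is a hypothetical stand-in for grpclb's addrInfo.
type serverEntry struct {
	addr        string
	connected   bool
	dropRequest bool // when true, fail-fast RPCs picking this entry are rejected immediately
}

var errDropped = errors.New("request dropped by load balancer policy")

// pickAddr walks the list round-robin starting at next. A connected entry with
// dropRequest unset is returned; a connected entry with dropRequest set makes a
// fail-fast RPC fail right away instead of being sent to that backend.
func pickAddr(servers []*serverEntry, next int, failfast bool) (string, int, error) {
	for i := 0; i < len(servers); i++ {
		s := servers[(next+i)%len(servers)]
		if !s.connected {
			continue
		}
		advanced := (next + i + 1) % len(servers)
		if !s.dropRequest {
			return s.addr, advanced, nil
		}
		if failfast {
			return "", advanced, errDropped
		}
	}
	// The real balancer would block here for non-fail-fast RPCs; the sketch just errors.
	return "", next, errors.New("there is no address available")
}

func main() {
	servers := []*serverEntry{
		{addr: "10.0.0.1:443", connected: true, dropRequest: true},
		{addr: "10.0.0.2:443", connected: true},
	}
	// The first fail-fast pick lands on the drop entry and fails; the retry succeeds,
	// mirroring the 1st and 2nd RPCs in TestDropRequest above.
	_, next, err := pickAddr(servers, 0, true)
	fmt.Println(err) // request dropped by load balancer policy
	addr, _, _ := pickAddr(servers, next, true)
	fmt.Println(addr) // 10.0.0.2:443
}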