Client load report for grpclb. (#1200)

Menghan Li
2017-04-27 10:43:38 -07:00
committed by GitHub
parent a7fee9febf
commit 277e90a432
8 changed files with 437 additions and 66 deletions

call.go

@ -93,11 +93,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}
// sendRequest writes out various information of an RPC, such as its Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
}
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
defer func() {
if err != nil {
// If err is a connection error, t will be closed; no need to close the stream here.
@ -120,7 +116,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
}
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
return Errorf(codes.Internal, "grpc: %v", err)
}
err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil {
@ -131,10 +127,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err
return err
}
// Sent successfully.
return stream, nil
return nil
}
// Invoke sends the RPC request on the wire and returns after response is received.
@ -183,6 +179,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
}()
}
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
@ -246,19 +243,35 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts)
stream, err = t.NewStream(ctx, callHdr)
if err != nil {
if put != nil {
if _, ok := err.(transport.ConnectionError); ok {
// If the error is a connection error, the transport was sending data on the wire,
// and we are not sure whether anything has been sent on the wire.
// If the error is not a connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
put()
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, stream, t, args, topts)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
// Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
@ -266,13 +279,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
@ -282,8 +295,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
t.CloseStream(stream, nil)
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
return stream.Status().Err()
}
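
The bytesSent/bytesReceived pair recorded here via updateRPCInfoInContext is what grpclb's put callback later folds into its ClientStats counters (see the grpclb.go changes below). A minimal, runnable sketch of that classification; rpcOutcome and classify are illustrative names, not part of this commit:

package main

import "fmt"

// rpcOutcome mirrors the per-call state recorded in rpcInfo before put() is
// called: whether any bytes made it onto the wire in either direction.
type rpcOutcome struct {
    bytesSent     bool
    bytesReceived bool
}

// classify maps an outcome onto the ClientStats bucket grpclb's put callback
// increments (illustrative; the real counters live on lbpb.ClientStats).
func classify(o rpcOutcome) string {
    switch {
    case !o.bytesSent:
        return "NumCallsFinishedWithClientFailedToSend"
    case o.bytesReceived:
        return "NumCallsFinishedKnownReceived"
    default:
        return "NumCallsFinished only"
    }
}

func main() {
    fmt.Println(classify(rpcOutcome{}))                                     // request never left the client
    fmt.Println(classify(rpcOutcome{bytesSent: true, bytesReceived: true})) // server demonstrably got it
    fmt.Println(classify(rpcOutcome{bytesSent: true}))                      // sent, but nothing came back
}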

clientconn.go

@ -669,6 +669,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
}
if !ok {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, errConnClosing
@ -676,6 +677,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, err
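
getTransport marks the RPC as never sent (bytesSent: false, bytesReceived: false) before releasing the pick, so the balancer can count it under NumCallsFinishedWithClientFailedToSend. A small sketch of the Get/put handshake this relies on; picker and putFunc are illustrative stand-ins, and unlike the real put (which takes no arguments and reads rpcInfo from the context) this one is handed the outcome directly to keep the sketch short:

package main

import "fmt"

// putFunc stands in for the callback Balancer.Get returns; the RPC path must
// call it exactly once when the call finishes, successfully or not.
type putFunc func(sent, received bool)

// picker is an illustrative load-reporting balancer: Get counts a started
// call and hands back a put that counts it as finished.
type picker struct{ started, finished, failedToSend int }

func (p *picker) Get() (addr string, put putFunc) {
    p.started++
    return "backend-1:443", func(sent, received bool) {
        p.finished++
        if !sent {
            p.failedToSend++
        }
    }
}

func main() {
    var p picker
    _, put := p.Get()
    // getTransport failed before anything hit the wire, so report (false, false).
    put(false, false)
    fmt.Printf("%+v\n", p) // {started:1 finished:1 failedToSend:1}
}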

grpclb.go

@ -145,6 +145,8 @@ type balancer struct {
done bool
expTimer *time.Timer
rand *rand.Rand
clientStats lbpb.ClientStats
}
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
@ -281,6 +283,34 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
return
}
func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-done:
return
}
b.mu.Lock()
stats := b.clientStats
b.clientStats = lbpb.ClientStats{} // Clear the stats.
b.mu.Unlock()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: &stats,
},
}); err != nil {
return
}
}
}
func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -322,6 +352,14 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b
grpclog.Println("TODO: Delegation is not supported yet.")
return
}
streamDone := make(chan struct{})
defer close(streamDone)
b.mu.Lock()
b.clientStats = lbpb.ClientStats{} // Clear client stats.
b.mu.Unlock()
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
go b.sendLoadReport(stream, d, streamDone)
}
// Retrieve the server list.
for {
reply, err := stream.Recv()
@ -538,7 +576,32 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
err = ErrClientConnClosing
return
}
seq := b.seq
defer func() {
if err != nil {
return
}
put = func() {
s, ok := rpcInfoFromContext(ctx)
if !ok {
return
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done || seq < b.seq {
return
}
b.clientStats.NumCallsFinished++
if !s.bytesSent {
b.clientStats.NumCallsFinishedWithClientFailedToSend++
} else if s.bytesReceived {
b.clientStats.NumCallsFinishedKnownReceived++
}
}
}()
b.clientStats.NumCallsStarted++
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0
@ -556,6 +619,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
return
@ -569,6 +639,8 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
if len(b.addrs) == 0 {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = Errorf(codes.Unavailable, "there is no address available")
return
@ -590,11 +662,17 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
for {
select {
case <-ctx.Done():
b.mu.Lock()
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ctx.Err()
return
case <-ch:
b.mu.Lock()
if b.done {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ErrClientConnClosing
return
@ -617,6 +695,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "drop requests for the address %s", a.addr.Addr)
return
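
sendLoadReport above snapshots b.clientStats and clears it under the lock on every tick, so each LoadBalanceRequest carries only the delta since the previous report and the remote balancer can aggregate by simple addition (as the test's remoteBalancer below does). A runnable sketch of that snapshot-and-reset pattern; reporter, record and snapshot are illustrative names:

package main

import (
    "fmt"
    "sync"
)

// stats mirrors the two ClientStats counters that matter for the sketch.
type stats struct {
    started, finished int64
}

// reporter shows the pattern used by sendLoadReport: each tick sends only
// the delta accumulated since the previous report.
type reporter struct {
    mu sync.Mutex
    s  stats
}

func (r *reporter) record(started, finished int64) {
    r.mu.Lock()
    r.s.started += started
    r.s.finished += finished
    r.mu.Unlock()
}

func (r *reporter) snapshot() stats {
    r.mu.Lock()
    defer r.mu.Unlock()
    out := r.s
    r.s = stats{} // clear, just like b.clientStats = lbpb.ClientStats{}
    return out
}

func main() {
    var r reporter
    r.record(3, 2)
    fmt.Println(r.snapshot()) // {3 2}
    r.record(1, 1)
    fmt.Println(r.snapshot()) // {1 1}: deltas, not cumulative totals
    fmt.Println(r.snapshot()) // {0 0}
}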

grpclb_test.go

@ -40,6 +40,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -47,10 +48,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
hwpb "google.golang.org/grpc/examples/helloworld/helloworld"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/naming"
testpb "google.golang.org/grpc/test/grpc_testing"
)
var (
@ -172,7 +173,10 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
type remoteBalancer struct {
sls []*lbpb.ServerList
intervals []time.Duration
statsDura time.Duration
done chan struct{}
mu sync.Mutex
stats lbpb.ClientStats
}
func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
@ -198,12 +202,36 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro
}
resp := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: new(lbpb.InitialLoadBalanceResponse),
InitialResponse: &lbpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &lbpb.Duration{
Seconds: int64(b.statsDura.Seconds()),
Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
},
},
},
}
if err := stream.Send(resp); err != nil {
return err
}
go func() {
for {
var (
req *lbpb.LoadBalanceRequest
err error
)
if req, err = stream.Recv(); err != nil {
return
}
b.mu.Lock()
b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
b.mu.Unlock()
}
}()
for k, v := range b.sls {
time.Sleep(b.intervals[k])
resp = &lbpb.LoadBalanceResponse{
@ -219,11 +247,15 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro
return nil
}
type helloServer struct {
type testServer struct {
testpb.TestServiceServer
addr string
}
func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
const testmdkey = "testmd"
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
@ -231,9 +263,12 @@ func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwp
if md == nil || md["lb-token"][0] != lbToken {
return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
return &hwpb.HelloReply{
Message: "Hello " + in.Name + " for " + s.addr,
}, nil
grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}
func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
@ -242,7 +277,7 @@ func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()})
testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
@ -356,9 +391,9 @@ func TestGRPCLB(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
cc.Close()
}
@ -393,22 +428,22 @@ func TestDropRequest(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
testC := testpb.NewTestServiceClient(cc)
// The 1st, non-fail-fast RPC should succeed. This ensures both server
// connections are made, because the first one has DropForLoadBalancing set to true.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < 3; i++ {
// Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
// set to true.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Even fail-fast RPCs should succeed since they choose the
// non-drop-request backend according to the round robin policy.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
cc.Close()
@ -443,10 +478,10 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
testC := testpb.NewTestServiceClient(cc)
ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := helloC.SayHello(ctx, &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.DeadlineExceeded)
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
}
cc.Close()
}
@ -493,19 +528,19 @@ func TestServerExpiration(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
// Sleep and wake up when the first server list expires.
time.Sleep(150 * time.Millisecond)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// A non-failfast RPC should succeed after the second server list is received from
// the remote load balancer.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
cc.Close()
}
@ -551,23 +586,24 @@ func TestBalancerDisconnects(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
var message string
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
var previousTrailer string
trailer := metadata.MD{}
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else {
message = resp.Message
previousTrailer = trailer[testmdkey][0]
}
// The initial resolver update contains lbs[0] and lbs[1].
// When lbs[0] is stopped, lbs[1] should be used.
lbs[0].Stop()
for {
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
} else if resp.Message != message {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The response contains the backend address, so the message should be different from the previous one.
message = resp.Message
// The trailer contains the backend address, so the trailer should be different from the previous one.
previousTrailer = trailer[testmdkey][0]
break
}
time.Sleep(100 * time.Millisecond)
@ -585,14 +621,194 @@ func TestBalancerDisconnects(t *testing.T) {
// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
lbs[1].Stop()
for {
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
} else if resp.Message != message {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The response contains the backend address, so the message should be different from the previous one.
// The trailer contains the backend address, so the trailer should be different from the previous one.
break
}
time.Sleep(100 * time.Millisecond)
}
cc.Close()
}
type failPreRPCCred struct{}
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], "failtosend") {
return nil, fmt.Errorf("rpc should fail to send")
}
return nil, nil
}
func (failPreRPCCred) RequireTransportSecurity() bool {
return false
}
func TestGRPCLBStatsUnary(t *testing.T) {
var (
countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
)
tss, cleanup, err := newLoadBalancer(3)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls = []*lbpb.ServerList{{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: true,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
DropForRateLimiting: true,
}, {
IpAddress: tss.beIPs[2],
Port: int32(tss.bePorts[2]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}},
}}
tss.ls.intervals = []time.Duration{0}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, ensuring all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countNormalRPC-1; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
for i := 0; i < countFailedToSend; i++ {
grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
}
cc.Close()
time.Sleep(1 * time.Second)
tss.ls.mu.Lock()
if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
}
tss.ls.mu.Unlock()
}
func TestGRPCLBStatsStreaming(t *testing.T) {
var (
countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
)
tss, cleanup, err := newLoadBalancer(3)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls = []*lbpb.ServerList{{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: true,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
DropForRateLimiting: true,
}, {
IpAddress: tss.beIPs[2],
Port: int32(tss.bePorts[2]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}},
}}
tss.ls.intervals = []time.Duration{0}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, ensuring all connections are up.
var stream testpb.TestService_FullDuplexCallClient
stream, err = testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countNormalRPC-1; i++ {
stream, err = testC.FullDuplexCall(context.Background())
if err == nil {
// Wait for stream to end if err is nil.
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
}
}
for i := 0; i < countFailedToSend; i++ {
grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
}
cc.Close()
time.Sleep(1 * time.Second)
tss.ls.mu.Lock()
if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
}
tss.ls.mu.Unlock()
}
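
The assertions in both stats tests follow from the 1/3 / 1/3 / 1/3 split the test comments describe: drops for load balancing, drops for rate limiting, and RPCs that reach the real backend (where the failtosend ones then fail in the per-RPC credentials). A small sketch of the arithmetic behind the expected counters, using the counts defined above:

package main

import "fmt"

func main() {
    const normal, failToSend = 66, 30 // countNormalRPC and countFailedToSend above
    total := normal + failToSend

    // Per the test comments, each kind of RPC splits 1/3 / 1/3 / 1/3 across
    // the three server-list entries, which gives:
    fmt.Println("started = finished      =", total)       // 96
    fmt.Println("drop for rate limiting  =", total/3)      // 32
    fmt.Println("drop for load balancing =", total/3)      // 32
    // Only RPCs that reach the real backend can fail to send or be
    // known-received, so those buckets split by RPC kind.
    fmt.Println("client failed to send   =", failToSend/3) // 10
    fmt.Println("known received          =", normal/3)     // 22
}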

rpc_util.go

@ -345,6 +345,29 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
return nil
}
type rpcInfo struct {
bytesSent bool
bytesReceived bool
}
type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{})
}
func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
s, ok = ctx.Value(rpcInfoContextKey{}).(*rpcInfo)
return
}
func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
if ss, ok := rpcInfoFromContext(ctx); ok {
*ss = s
}
return
}
// Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown.
//
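
rpcInfo works because the context stores a pointer once and the call path then mutates through it, so a put closure that captured the context before the RPC ran still observes the final values. A runnable sketch of that pattern; callInfo and its helpers are illustrative, not the real rpcInfo API:

package main

import (
    "context"
    "fmt"
)

// callInfo plays the role of rpcInfo: stored in the context as a pointer,
// then updated in place as the RPC progresses.
type callInfo struct{ bytesSent, bytesReceived bool }

type callInfoKey struct{}

func withCallInfo(ctx context.Context) context.Context {
    return context.WithValue(ctx, callInfoKey{}, &callInfo{})
}

func updateCallInfo(ctx context.Context, ci callInfo) {
    if p, ok := ctx.Value(callInfoKey{}).(*callInfo); ok {
        *p = ci
    }
}

func main() {
    ctx := withCallInfo(context.Background())
    // The balancer captures ctx in its put closure before the RPC runs...
    readLater := func() {
        p, _ := ctx.Value(callInfoKey{}).(*callInfo)
        fmt.Printf("%+v\n", *p)
    }
    // ...the call path fills it in as the RPC progresses...
    updateCallInfo(ctx, callInfo{bytesSent: true, bytesReceived: true})
    // ...and put still observes the final values through the shared pointer.
    readLater() // {bytesSent:true bytesReceived:true}
}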

stream.go

@ -151,6 +151,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
@ -193,14 +194,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok && put != nil {
// If the error is a connection error, the transport was sending data on the wire,
// and we are not sure whether anything has been sent on the wire.
// If the error is not a connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return nil, toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return nil, toRPCErr(err)
@ -463,6 +467,10 @@ func (cs *clientStream) finish(err error) {
o.after(&cs.c)
}
if cs.put != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{
bytesSent: cs.s.BytesSent(),
bytesReceived: cs.s.BytesReceived(),
})
cs.put()
cs.put = nil
}

http2_client.go

@ -493,6 +493,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, connectionErrorf(true, err, "transport: %v", err)
}
}
s.bytesSent = true
if t.statsHandler != nil {
outHeader := &stats.OutHeader{
Client: true,
@ -958,6 +960,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if !ok {
return
}
s.bytesReceived = true
var state decodeState
for _, hf := range frame.Fields {
if err := state.processHeaderField(hf); err != nil {

transport.go

@ -220,6 +220,10 @@ type Stream struct {
rstStream bool
// rstError is the error that needs to be sent along with the RST_STREAM frame.
rstError http2.ErrCode
// bytesSent and bytesReceived indicate whether any bytes have been sent or
// received on this stream.
bytesSent bool
bytesReceived bool
}
// RecvCompress returns the compression algorithm applied to the inbound
@ -341,6 +345,20 @@ func (s *Stream) finish(st *status.Status) {
close(s.done)
}
// BytesSent indicates whether any bytes have been sent on this stream.
func (s *Stream) BytesSent() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.bytesSent
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.bytesReceived
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {