Client load report for grpclb. (#1200)
call.go (56 changed lines)

@@ -93,11 +93,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}

// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
}
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.

@@ -120,7 +116,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
}
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
return Errorf(codes.Internal, "grpc: %v", err)
}
err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil {

@@ -131,10 +127,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err
return err
}
// Sent successfully.
return stream, nil
return nil
}

// Invoke sends the RPC request on the wire and returns after response is received.

@@ -183,6 +179,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
}()
}
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})

@@ -246,19 +243,35 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts)
stream, err = t.NewStream(ctx, callHdr)
if err != nil {
if put != nil {
if _, ok := err.(transport.ConnectionError); ok {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
put()
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, stream, t, args, topts)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
// Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)

@@ -266,13 +279,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)

@@ -282,8 +295,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
t.CloseStream(stream, nil)
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put()
put = nil
}
return stream.Status().Err()
}
@@ -669,6 +669,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
}
if !ok {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, errConnClosing

@@ -676,6 +677,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put()
}
return nil, nil, err
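Aside: the retry rule that recurs in the invoke and newClientStream hunks of this commit is that a non-fail-fast RPC is retried only when the failure is a connection error or the server started draining. A minimal standalone sketch of that condition, using local stand-in error types rather than the transport package's:

package main

import (
	"errors"
	"fmt"
)

// connectionError is a local stand-in for transport.ConnectionError.
type connectionError struct{ msg string }

func (e connectionError) Error() string { return e.msg }

// errStreamDrain is a local stand-in for transport.ErrStreamDrain.
var errStreamDrain = errors.New("the connection is draining")

// shouldRetry mirrors the condition in the hunks above:
// (connection error OR drain) AND not fail-fast.
func shouldRetry(err error, failFast bool) bool {
	_, isConnErr := err.(connectionError)
	return (isConnErr || err == errStreamDrain) && !failFast
}

func main() {
	fmt.Println(shouldRetry(connectionError{"connection reset"}, false)) // true: retried
	fmt.Println(shouldRetry(errStreamDrain, true))                       // false: fail-fast returns the error
	fmt.Println(shouldRetry(errors.New("deadline exceeded"), false))     // false: not retriable
}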
grpclb.go (85 changed lines)

@@ -145,6 +145,8 @@ type balancer struct {
done bool
expTimer *time.Timer
rand *rand.Rand

clientStats lbpb.ClientStats
}

func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {

@@ -281,6 +283,34 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
return
}

func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-done:
return
}
b.mu.Lock()
stats := b.clientStats
b.clientStats = lbpb.ClientStats{} // Clear the stats.
b.mu.Unlock()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: &stats,
},
}); err != nil {
return
}
}
}

func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

@@ -322,6 +352,14 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b
grpclog.Println("TODO: Delegation is not supported yet.")
return
}
streamDone := make(chan struct{})
defer close(streamDone)
b.mu.Lock()
b.clientStats = lbpb.ClientStats{} // Clear client stats.
b.mu.Unlock()
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
go b.sendLoadReport(stream, d, streamDone)
}
// Retrieve the server list.
for {
reply, err := stream.Recv()
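The sendLoadReport loop added above snapshots the accumulated ClientStats under the mutex, resets them, and sends only that delta on each tick, so a slow or failed send never blocks callers that are updating the counters. A minimal standalone sketch of the snapshot-and-reset pattern; the reporter type and counter fields below are illustrative, not the balancer's:

package main

import (
	"fmt"
	"sync"
	"time"
)

// counters is a stand-in for lbpb.ClientStats.
type counters struct{ started, finished int64 }

type reporter struct {
	mu sync.Mutex
	c  counters
}

// report snapshots and clears the counters while holding the lock,
// then "sends" the snapshot outside the critical section.
func (r *reporter) report() {
	r.mu.Lock()
	snap := r.c
	r.c = counters{} // clear, so each report carries only new deltas
	r.mu.Unlock()
	fmt.Printf("load report at %d: %+v\n", time.Now().Unix(), snap)
}

func (r *reporter) run(interval time.Duration, done <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
		case <-done:
			return
		}
		r.report()
	}
}

func main() {
	r := &reporter{}
	done := make(chan struct{})
	go r.run(20*time.Millisecond, done)
	r.mu.Lock()
	r.c.started++
	r.mu.Unlock()
	time.Sleep(50 * time.Millisecond)
	close(done)
}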
@@ -538,7 +576,32 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
err = ErrClientConnClosing
return
}
seq := b.seq

defer func() {
if err != nil {
return
}
put = func() {
s, ok := rpcInfoFromContext(ctx)
if !ok {
return
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done || seq < b.seq {
return
}
b.clientStats.NumCallsFinished++
if !s.bytesSent {
b.clientStats.NumCallsFinishedWithClientFailedToSend++
} else if s.bytesReceived {
b.clientStats.NumCallsFinishedKnownReceived++
}
}
}()

b.clientStats.NumCallsStarted++
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0

@@ -556,6 +619,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
return

@@ -569,6 +639,8 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
if len(b.addrs) == 0 {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = Errorf(codes.Unavailable, "there is no address available")
return

@@ -590,11 +662,17 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
for {
select {
case <-ctx.Done():
b.mu.Lock()
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ctx.Err()
return
case <-ch:
b.mu.Lock()
if b.done {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ErrClientConnClosing
return

@@ -617,6 +695,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr)
return
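The put callback installed in Get above classifies every finished call from the two byte flags recorded in rpcInfo: no bytes sent means the client failed to send, and a received response means the result is known. The same rule in isolation, as a hedged sketch with shortened counter names standing in for the lbpb.ClientStats fields:

package main

import "fmt"

// stats holds shortened stand-ins for the ClientStats counters used by the balancer.
type stats struct {
	finished, clientFailedToSend, knownReceived int64
}

// finishCall applies the same classification as the put closure above:
// every call is finished; if nothing was sent it is a client send failure;
// if a response was received it counts as known received.
func finishCall(s *stats, bytesSent, bytesReceived bool) {
	s.finished++
	if !bytesSent {
		s.clientFailedToSend++
	} else if bytesReceived {
		s.knownReceived++
	}
}

func main() {
	var s stats
	finishCall(&s, true, true)   // normal, completed RPC
	finishCall(&s, false, false) // failed before anything was written
	finishCall(&s, true, false)  // sent, but outcome unknown
	fmt.Printf("%+v\n", s)
}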
@@ -40,6 +40,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"testing"
"time"

@@ -47,10 +48,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
hwpb "google.golang.org/grpc/examples/helloworld/helloworld"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/naming"
testpb "google.golang.org/grpc/test/grpc_testing"
)

var (

@@ -172,7 +173,10 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
type remoteBalancer struct {
sls []*lbpb.ServerList
intervals []time.Duration
statsDura time.Duration
done chan struct{}
mu sync.Mutex
stats lbpb.ClientStats
}

func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {

@@ -198,12 +202,36 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro
}
resp := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: new(lbpb.InitialLoadBalanceResponse),
InitialResponse: &lbpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &lbpb.Duration{
Seconds: int64(b.statsDura.Seconds()),
Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
},
},
},
}
if err := stream.Send(resp); err != nil {
return err
}
go func() {
for {
var (
req *lbpb.LoadBalanceRequest
err error
)
if req, err = stream.Recv(); err != nil {
return
}
b.mu.Lock()
b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
b.mu.Unlock()
}
}()
for k, v := range b.sls {
time.Sleep(b.intervals[k])
resp = &lbpb.LoadBalanceResponse{

@@ -219,11 +247,15 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro
return nil
}

type helloServer struct {
type testServer struct {
testpb.TestServiceServer

addr string
}

func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
const testmdkey = "testmd"

func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")

@@ -231,9 +263,12 @@ func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwp
if md == nil || md["lb-token"][0] != lbToken {
return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
return &hwpb.HelloReply{
Message: "Hello " + in.Name + " for " + s.addr,
}, nil
grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
return &testpb.Empty{}, nil
}

func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}

func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {

@@ -242,7 +277,7 @@ func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()})
testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)

@@ -356,9 +391,9 @@ func TestGRPCLB(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
cc.Close()
}

@@ -393,22 +428,22 @@ func TestDropRequest(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
testC := testpb.NewTestServiceClient(cc)
// The 1st, non-fail-fast RPC should succeed. This ensures both server
// connections are made, because the first one has DropForLoadBalancing set to true.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < 3; i++ {
// Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
// set to true.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Even fail-fast RPCs should succeed since they choose the
// non-drop-request backend according to the round robin policy.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
cc.Close()

@@ -443,10 +478,10 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
testC := testpb.NewTestServiceClient(cc)
ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := helloC.SayHello(ctx, &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.DeadlineExceeded)
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
}
cc.Close()
}

@@ -493,19 +528,19 @@ func TestServerExpiration(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
// Sleep and wake up when the first server list gets expired.
time.Sleep(150 * time.Millisecond)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// A non-failfast rpc should be succeeded after the second server list is received from
// the remote load balancer.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
cc.Close()
}

@@ -551,23 +586,24 @@ func TestBalancerDisconnects(t *testing.T) {
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
var message string
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
testC := testpb.NewTestServiceClient(cc)
var previousTrailer string
trailer := metadata.MD{}
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else {
message = resp.Message
previousTrailer = trailer[testmdkey][0]
}
// The initial resolver update contains lbs[0] and lbs[1].
// When lbs[0] is stopped, lbs[1] should be used.
lbs[0].Stop()
for {
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
} else if resp.Message != message {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The response contains the backend address, so the message should be different from the previous one.
message = resp.Message
// The trailer contains the backend address, so the trailer should be different from the previous one.
previousTrailer = trailer[testmdkey][0]
break
}
time.Sleep(100 * time.Millisecond)

@@ -585,14 +621,194 @@ func TestBalancerDisconnects(t *testing.T) {
// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
lbs[1].Stop()
for {
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
} else if resp.Message != message {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
} else if trailer[testmdkey][0] != previousTrailer {
// A new backend server should receive the request.
// The response contains the backend address, so the message should be different from the previous one.
// The trailer contains the backend address, so the trailer should be different from the previous one.
break
}
time.Sleep(100 * time.Millisecond)
}
cc.Close()
}
type failPreRPCCred struct{}

func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], "failtosend") {
return nil, fmt.Errorf("rpc should fail to send")
}
return nil, nil
}

func (failPreRPCCred) RequireTransportSecurity() bool {
return false
}

func TestGRPCLBStatsUnary(t *testing.T) {
var (
countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
)
tss, cleanup, err := newLoadBalancer(3)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls = []*lbpb.ServerList{{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: true,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
DropForRateLimiting: true,
}, {
IpAddress: tss.beIPs[2],
Port: int32(tss.bePorts[2]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}},
}}
tss.ls.intervals = []time.Duration{0}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countNormalRPC-1; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
for i := 0; i < countFailedToSend; i++ {
grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
}
cc.Close()

time.Sleep(1 * time.Second)
tss.ls.mu.Lock()
if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
}
tss.ls.mu.Unlock()
}

func TestGRPCLBStatsStreaming(t *testing.T) {
var (
countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
)
tss, cleanup, err := newLoadBalancer(3)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls = []*lbpb.ServerList{{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: true,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
DropForRateLimiting: true,
}, {
IpAddress: tss.beIPs[2],
Port: int32(tss.bePorts[2]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
}},
}}
tss.ls.intervals = []time.Duration{0}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
var stream testpb.TestService_FullDuplexCallClient
stream, err = testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countNormalRPC-1; i++ {
stream, err = testC.FullDuplexCall(context.Background())
if err == nil {
// Wait for stream to end if err is nil.
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
}
}
for i := 0; i < countFailedToSend; i++ {
grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
}
cc.Close()

time.Sleep(1 * time.Second)
tss.ls.mu.Lock()
if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
}
if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
}
tss.ls.mu.Unlock()
}
rpc_util.go (23 changed lines)

@@ -345,6 +345,29 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
return nil
}

type rpcInfo struct {
bytesSent bool
bytesReceived bool
}

type rpcInfoContextKey struct{}

func newContextWithRPCInfo(ctx context.Context) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{})
}

func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
s, ok = ctx.Value(rpcInfoContextKey{}).(*rpcInfo)
return
}

func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
if ss, ok := rpcInfoFromContext(ctx); ok {
*ss = s
}
return
}

// Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown.
//
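These helpers rely on storing a pointer to an rpcInfo in the context once, up front: later code mutates it in place, and anything holding the same context (for example the balancer's put callback) reads the final values afterwards. A self-contained sketch of the same pointer-in-context pattern, using local names rather than the grpc package's unexported ones:

package main

import (
	"context"
	"fmt"
)

type rpcInfo struct{ bytesSent, bytesReceived bool }

type rpcInfoKey struct{}

// withRPCInfo stores a pointer so later updates are visible to all readers of ctx.
func withRPCInfo(ctx context.Context) context.Context {
	return context.WithValue(ctx, rpcInfoKey{}, &rpcInfo{})
}

// updateRPCInfo overwrites the stored value in place, if present.
func updateRPCInfo(ctx context.Context, s rpcInfo) {
	if ss, ok := ctx.Value(rpcInfoKey{}).(*rpcInfo); ok {
		*ss = s
	}
}

func main() {
	ctx := withRPCInfo(context.Background())
	// ...the RPC runs; once the outcome is known it is recorded:
	updateRPCInfo(ctx, rpcInfo{bytesSent: true, bytesReceived: true})
	// A callback that captured ctx (like a balancer's put) can now classify the call:
	if s, ok := ctx.Value(rpcInfoKey{}).(*rpcInfo); ok {
		fmt.Printf("sent=%v received=%v\n", s.bytesSent, s.bytesReceived)
	}
}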
stream.go (16 changed lines)

@@ -151,6 +151,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})

@@ -193,14 +194,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok && put != nil {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return nil, toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return nil, toRPCErr(err)

@@ -463,6 +467,10 @@ func (cs *clientStream) finish(err error) {
o.after(&cs.c)
}
if cs.put != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{
bytesSent: cs.s.BytesSent(),
bytesReceived: cs.s.BytesReceived(),
})
cs.put()
cs.put = nil
}
@@ -493,6 +493,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, connectionErrorf(true, err, "transport: %v", err)
}
}
s.bytesSent = true

if t.statsHandler != nil {
outHeader := &stats.OutHeader{
Client: true,

@@ -958,6 +960,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if !ok {
return
}
s.bytesReceived = true
var state decodeState
for _, hf := range frame.Fields {
if err := state.processHeaderField(hf); err != nil {

@@ -220,6 +220,10 @@ type Stream struct {
rstStream bool
// rstError is the error that needs to be sent along with the RST_STREAM frame.
rstError http2.ErrCode
// bytesSent and bytesReceived indicates whether any bytes have been sent or
// received on this stream.
bytesSent bool
bytesReceived bool
}

// RecvCompress returns the compression algorithm applied to the inbound

@@ -341,6 +345,20 @@ func (s *Stream) finish(st *status.Status) {
close(s.done)
}

// BytesSent indicates whether any bytes have been sent on this stream.
func (s *Stream) BytesSent() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.bytesSent
}

// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.bytesReceived
}

// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {