From 3e71fb360ddc2df6701f6d6f4c11bb3351ab0730 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 27 Jun 2016 14:36:59 -0700 Subject: [PATCH] Support fail-fast mode and make it the default --- balancer.go | 22 +++++++++++++++----- balancer_test.go | 12 +++++------ call.go | 10 ++++++--- clientconn.go | 19 ++++++++++++++---- rpc_util.go | 14 +++++++++++++ stream.go | 22 +++++++++++++++----- test/end2end_test.go | 48 +++++++++++++++++++++++++++++++++++++++++--- 7 files changed, 121 insertions(+), 26 deletions(-) diff --git a/balancer.go b/balancer.go index c298ae91..307e3dc1 100644 --- a/balancer.go +++ b/balancer.go @@ -94,10 +94,10 @@ type Balancer interface { // instead of blocking. // // The function returns put which is called once the rpc has completed or failed. - // put can collect and report RPC stats to a remote load balancer. gRPC internals - // will try to call this again if err is non-nil (unless err is ErrClientConnClosing). + // put can collect and report RPC stats to a remote load balancer. // - // TODO: Add other non-recoverable errors? + // This function should only return the errors Balancer cannot recover by itself. + // gRPC internals will fail the RPC if an error is returned. Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) // Notify returns a channel that is used by gRPC internals to watch the addresses // gRPC needs to connect. The addresses might be from a name resolver or remote @@ -298,8 +298,20 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad } } } - // There is no address available. Wait on rr.waitCh. - // TODO(zhaoq): Handle the case when opts.BlockingWait is false. + // There is no address available. + if !opts.BlockingWait { + if len(rr.addrs) == 0 { + rr.mu.Unlock() + err = fmt.Errorf("there is no address available") + return + } + // Returns the next addr on rr.addrs for failfast RPCs. 
+ addr = rr.addrs[rr.next].addr + rr.next++ + rr.mu.Unlock() + return + } + // Wait on rr.waitCh for non-failfast RPCs. if rr.waitCh == nil { ch = make(chan struct{}) rr.waitCh = ch diff --git a/balancer_test.go b/balancer_test.go index 9d8d2bcd..d0cf0611 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -239,7 +239,7 @@ func TestCloseWithPendingRPC(t *testing.T) { t.Fatalf("Failed to create ClientConn: %v", err) } var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) } // Remove the server. @@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) { // Loop until the above update applies. for { ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) } }() @@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) { defer wg.Done() var reply string time.Sleep(5 * time.Millisecond) - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) 
} }() @@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) { for { var reply string ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -305,7 +305,7 @@ func TestGetOnWaitChannel(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) } }() diff --git a/call.go b/call.go index d6d993b4..fb2144b2 100644 --- a/call.go +++ b/call.go @@ -35,6 +35,7 @@ package grpc import ( "bytes" + //"fmt" "io" "time" @@ -101,7 +102,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { - var c callInfo + c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) @@ -165,9 +166,12 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.failFast { return toRPCErr(err) } + continue } - // All the remaining cases are treated as retryable. - continue + // All the other errors are treated as Internal errors. + return Errorf(codes.Internal, "%v", err) + // All the remaining cases are treated as fatal. 
+ //panic(fmt.Sprintf("ClientConn.getTransport got an unsupported error: %v", err)) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) diff --git a/clientconn.go b/clientconn.go index 9b9c78d0..79965ec3 100644 --- a/clientconn.go +++ b/clientconn.go @@ -424,7 +424,6 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - // TODO(zhaoq): Implement fail-fast logic. addr, put, err := cc.balancer.Get(ctx, opts) if err != nil { return nil, nil, err @@ -442,7 +441,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) } return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } - t, err := ac.wait(ctx) + t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { if put != nil { put() @@ -649,8 +648,9 @@ func (ac *addrConn) transportMonitor() { } } -// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. -func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { +// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or +// iv) transport is in TransientFailure and the RPC is fail-fast. 
+func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) { for { ac.mu.Lock() switch { @@ -662,6 +662,10 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) ac.mu.Unlock() return ct, nil default: + if ac.state == TransientFailure && failFast { + ac.mu.Unlock() + return nil, transport.StreamErrorf(codes.Canceled, "grpc: RPC failed fast due to transport failure") + } ready := ac.ready if ready == nil { ready = make(chan struct{}) @@ -673,6 +677,13 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) return nil, transport.ContextErr(ctx.Err()) // Wait until the new transport is ready or failed. case <-ready: + ac.mu.Lock() + if ac.state == TransientFailure && failFast { + ac.mu.Unlock() + return nil, transport.StreamErrorf(codes.Canceled, "grpc: RPC failed fast due to transport failure") + } + ac.mu.Unlock() + } } } diff --git a/rpc_util.go b/rpc_util.go index 080ebb14..36c60183 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -141,6 +141,8 @@ type callInfo struct { traceInfo traceInfo // in trace.go } +var defaultCallInfo = callInfo{failFast: true} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -179,6 +181,18 @@ func Trailer(md *metadata.MD) CallOption { }) } +// FailFast configures the action to take when an RPC is attempted on broken +// connections or unreachable servers. If failfast is true, the RPC will fail +// immediately. Otherwise, the RPC client will block the call until a +// connection is available (or the call is canceled or times out) and will retry +// the call if it fails due to a transient error. +func FailFast(failFast bool) CallOption { + return beforeCall(func(c *callInfo) error { + c.failFast = failFast + return nil + }) +} + // The format of the payload: compressed or not? 
type payloadFormat uint8 diff --git a/stream.go b/stream.go index 25be4b81..21b73a9c 100644 --- a/stream.go +++ b/stream.go @@ -105,9 +105,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth err error put func() ) - // TODO(zhaoq): CallOption is omitted. Add support when it is needed. + c := defaultCallInfo + for _, o := range opts { + if err := o.before(&c); err != nil { + return nil, toRPCErr(err) + } + } gopts := BalancerGetOptions{ - BlockingWait: false, + BlockingWait: !c.failFast, } t, put, err = cc.getTransport(ctx, gopts) if err != nil { @@ -122,6 +127,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth callHdr.SendCompress = cc.dopts.cp.Type() } cs := &clientStream{ + opts: opts, + c: c, desc: desc, put: put, codec: cc.dopts.codec, @@ -167,6 +174,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { + opts []CallOption + c callInfo t transport.ClientTransport s *transport.Stream p *parser @@ -312,15 +321,18 @@ func (cs *clientStream) closeTransportStream(err error) { } func (cs *clientStream) finish(err error) { - if !cs.tracing { - return - } cs.mu.Lock() defer cs.mu.Unlock() + for _, o := range cs.opts { + o.after(&cs.c) + } if cs.put != nil { cs.put() cs.put = nil } + if !cs.tracing { + return + } if cs.trInfo.tr != nil { if err == nil || err == io.EOF { cs.trInfo.tr.LazyPrintf("RPC: [OK]") diff --git a/test/end2end_test.go b/test/end2end_test.go index b539584b..5a0759f8 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -550,7 +550,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = 
_, %v, want _, ", err) } te.srv.Stop() @@ -558,12 +558,54 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { // notification in time the failure path of the 1st invoke of // ClientConn.wait hits the deadline exceeded error. ctx, _ := context.WithTimeout(context.Background(), -1) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) } awaitNewConnLogOutput() } +func TestFailFast(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testFailFast(t, e) + } +} + +func testFailFast(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer() + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + // Stop the server and tear down all the existing connections. + te.srv.Stop() + // Issue an RPC to make sure the server teardown is propagated to the client already. 
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) + } + // The client keeps reconnecting and ongoing fail-fast RPCs should fail with codes.Canceled. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Canceled { + t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Canceled) + } + if _, err := tc.StreamingInputCall(ctx); grpc.Code(err) != codes.Canceled { + t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Canceled) + } + + awaitNewConnLogOutput() +} + func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { ctx, _ := context.WithTimeout(context.Background(), d) hc := healthpb.NewHealthClient(cc) @@ -879,7 +921,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup ResponseSize: proto.Int32(respSize), Payload: payload, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false)) if err != nil { t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) return