diff --git a/balancer.go b/balancer.go index e561dbaf..419e2146 100644 --- a/balancer.go +++ b/balancer.go @@ -40,7 +40,6 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" - "google.golang.org/grpc/transport" ) // Address represents a server the client connects to. @@ -94,10 +93,10 @@ type Balancer interface { // instead of blocking. // // The function returns put which is called once the rpc has completed or failed. - // put can collect and report RPC stats to a remote load balancer. gRPC internals - // will try to call this again if err is non-nil (unless err is ErrClientConnClosing). + // put can collect and report RPC stats to a remote load balancer. // - // TODO: Add other non-recoverable errors? + // This function should only return the errors Balancer cannot recover by itself. + // gRPC internals will fail the RPC if an error is returned. Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) // Notify returns a channel that is used by gRPC internals to watch the addresses // gRPC needs to connect. The addresses might be from a name resolver or remote @@ -299,8 +298,19 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad } } } - // There is no address available. Wait on rr.waitCh. - // TODO(zhaoq): Handle the case when opts.BlockingWait is false. + if !opts.BlockingWait { + if len(rr.addrs) == 0 { + rr.mu.Unlock() + err = fmt.Errorf("there is no address available") + return + } + // Returns the next addr on rr.addrs for failfast RPCs. + addr = rr.addrs[rr.next].addr + rr.next++ + rr.mu.Unlock() + return + } + // Wait on rr.waitCh for non-failfast RPCs. if rr.waitCh == nil { ch = make(chan struct{}) rr.waitCh = ch @@ -311,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad for { select { case <-ctx.Done(): - err = transport.ContextErr(ctx.Err()) + err = ctx.Err() return case <-ch: rr.mu.Lock() diff --git a/balancer_test.go b/balancer_test.go index 9d8d2bcd..d0cf0611 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -239,7 +239,7 @@ func TestCloseWithPendingRPC(t *testing.T) { t.Fatalf("Failed to create ClientConn: %v", err) } var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) } // Remove the server. @@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) { // Loop until the above update applies. for { ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) } }() @@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) { defer wg.Done() var reply string time.Sleep(5 * time.Millisecond) - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) } }() @@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) { for { var reply string ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -305,7 +305,7 @@ func TestGetOnWaitChannel(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) } }() diff --git a/call.go b/call.go index d6d993b4..93262360 100644 --- a/call.go +++ b/call.go @@ -101,7 +101,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { - var c callInfo + c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) @@ -155,19 +155,17 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if err == ErrClientConnClosing { - return Errorf(codes.FailedPrecondition, "%v", err) + if _, ok := err.(rpcError); ok { + return err } - if _, ok := err.(transport.StreamError); ok { - return toRPCErr(err) - } - if _, ok := err.(transport.ConnectionError); ok { + if err == errConnClosing { if c.failFast { - return toRPCErr(err) + return Errorf(codes.Unavailable, "%v", errConnClosing) } + continue } - // All the remaining cases are treated as retryable. - continue + // All the other errors are treated as Internal errors. + return Errorf(codes.Internal, "%v", err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) diff --git a/clientconn.go b/clientconn.go index d374b3c9..c3c7691d 100644 --- a/clientconn.go +++ b/clientconn.go @@ -422,15 +422,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - // TODO(zhaoq): Implement fail-fast logic. addr, put, err := cc.dopts.balancer.Get(ctx, opts) if err != nil { - return nil, nil, err + return nil, nil, toRPCErr(err) } cc.mu.RLock() if cc.conns == nil { cc.mu.RUnlock() - return nil, nil, ErrClientConnClosing + return nil, nil, toRPCErr(ErrClientConnClosing) } ac, ok := cc.conns[addr] cc.mu.RUnlock() @@ -438,9 +437,9 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) if put != nil { put() } - return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } - t, err := ac.wait(ctx) + t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { if put != nil { put() @@ -647,8 +646,9 @@ func (ac *addrConn) transportMonitor() { } } -// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. -func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { +// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or +// iv) transport is in TransientFailure and the RPC is fail-fast. +func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) { for { ac.mu.Lock() switch { @@ -659,6 +659,9 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) ct := ac.transport ac.mu.Unlock() return ct, nil + case ac.state == TransientFailure && failFast: + ac.mu.Unlock() + return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure") default: ready := ac.ready if ready == nil { @@ -668,7 +671,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) ac.mu.Unlock() select { case <-ctx.Done(): - return nil, transport.ContextErr(ctx.Err()) + return nil, toRPCErr(ctx.Err()) // Wait until the new transport is ready or failed. case <-ready: } diff --git a/rpc_util.go b/rpc_util.go index 080ebb14..91342bd8 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -141,6 +141,8 @@ type callInfo struct { traceInfo traceInfo // in trace.go } +var defaultCallInfo = callInfo{failFast: true} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption { }) } +// FailFast configures the action to take when an RPC is attempted on broken +// connections or unreachable servers. If failfast is true, the RPC will fail +// immediately. Otherwise, the RPC client will block the call until a +// connection is available (or the call is canceled or times out) and will retry +// the call if it fails due to a transient error. Please refer to +// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md +func FailFast(failFast bool) CallOption { + return beforeCall(func(c *callInfo) error { + c.failFast = failFast + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 @@ -374,6 +389,25 @@ func toRPCErr(err error) error { code: codes.Internal, desc: e.Desc, } + default: + switch err { + case context.DeadlineExceeded: + return rpcError{ + code: codes.DeadlineExceeded, + desc: err.Error(), + } + case context.Canceled: + return rpcError{ + code: codes.Canceled, + desc: err.Error(), + } + case ErrClientConnClosing: + return rpcError{ + code: codes.FailedPrecondition, + desc: err.Error(), + } + } + } return Errorf(codes.Unknown, "%v", err) } diff --git a/stream.go b/stream.go index 25be4b81..73d1da23 100644 --- a/stream.go +++ b/stream.go @@ -102,16 +102,15 @@ type ClientStream interface { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { var ( t transport.ClientTransport + s *transport.Stream err error put func() ) - // TODO(zhaoq): CallOption is omitted. Add support when it is needed. - gopts := BalancerGetOptions{ - BlockingWait: false, - } - t, put, err = cc.getTransport(ctx, gopts) - if err != nil { - return nil, toRPCErr(err) + c := defaultCallInfo + for _, o := range opts { + if err := o.before(&c); err != nil { + return nil, toRPCErr(err) + } } callHdr := &transport.CallHdr{ Host: cc.authority, @@ -122,8 +121,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth callHdr.SendCompress = cc.dopts.cp.Type() } cs := &clientStream{ + opts: opts, + c: c, desc: desc, - put: put, codec: cc.dopts.codec, cp: cc.dopts.cp, dc: cc.dopts.dc, @@ -142,11 +142,44 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) ctx = trace.NewContext(ctx, cs.trInfo.tr) } - s, err := t.NewStream(ctx, callHdr) - if err != nil { - cs.finish(err) - return nil, toRPCErr(err) + gopts := BalancerGetOptions{ + BlockingWait: !c.failFast, } + for { + t, put, err = cc.getTransport(ctx, gopts) + if err != nil { + // TODO(zhaoq): Probably revisit the error handling. + if _, ok := err.(rpcError); ok { + return nil, err + } + if err == errConnClosing { + if c.failFast { + return nil, Errorf(codes.Unavailable, "%v", errConnClosing) + } + continue + } + // All the other errors are treated as Internal errors. + return nil, Errorf(codes.Internal, "%v", err) + } + + s, err = t.NewStream(ctx, callHdr) + if err != nil { + if put != nil { + put() + put = nil + } + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + cs.finish(err) + return nil, toRPCErr(err) + } + continue + } + return nil, toRPCErr(err) + } + break + } + cs.put = put cs.t = t cs.s = s cs.p = &parser{r: s} @@ -167,6 +200,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { + opts []CallOption + c callInfo t transport.ClientTransport s *transport.Stream p *parser @@ -312,15 +347,18 @@ func (cs *clientStream) closeTransportStream(err error) { } func (cs *clientStream) finish(err error) { - if !cs.tracing { - return - } cs.mu.Lock() defer cs.mu.Unlock() + for _, o := range cs.opts { + o.after(&cs.c) + } if cs.put != nil { cs.put() cs.put = nil } + if !cs.tracing { + return + } if cs.trInfo.tr != nil { if err == nil || err == io.EOF { cs.trInfo.tr.LazyPrintf("RPC: [OK]") diff --git a/test/end2end_test.go b/test/end2end_test.go index cb71db9c..245b8c0f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -550,7 +550,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } te.srv.Stop() @@ -558,12 +558,56 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { // notification in time the failure path of the 1st invoke of // ClientConn.wait hits the deadline exceeded error. ctx, _ := context.WithTimeout(context.Background(), -1) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) } awaitNewConnLogOutput() } +func TestFailFast(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testFailFast(t, e) + } +} + +func testFailFast(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer() + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + // Stop the server and tear down all the exisiting connections. + te.srv.Stop() + // Loop until the server teardown is propagated to the client. + for { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable { + break + } + time.Sleep(10 * time.Millisecond) + } + // The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + + awaitNewConnLogOutput() +} + func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { ctx, _ := context.WithTimeout(context.Background(), d) hc := healthpb.NewHealthClient(cc) @@ -879,7 +923,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup ResponseSize: proto.Int32(respSize), Payload: payload, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false)) if err != nil { t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) return @@ -1106,9 +1150,7 @@ func testNoService(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - // Make sure setting ack has been sent. - time.Sleep(20 * time.Millisecond) - stream, err := tc.FullDuplexCall(te.ctx) + stream, err := tc.FullDuplexCall(te.ctx, grpc.FailFast(false)) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) }