Merge pull request #738 from iamqizhao/master

[Pre-1.0 Semantics Change] Support Fail-fast and make it the default setup
This commit is contained in:
Menghan Li
2016-07-08 10:28:17 -07:00
committed by GitHub
7 changed files with 177 additions and 52 deletions

View File

@ -40,7 +40,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming" "google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
) )
// Address represents a server the client connects to. // Address represents a server the client connects to.
@ -94,10 +93,10 @@ type Balancer interface {
// instead of blocking. // instead of blocking.
// //
// The function returns put which is called once the rpc has completed or failed. // The function returns put which is called once the rpc has completed or failed.
// put can collect and report RPC stats to a remote load balancer. gRPC internals // put can collect and report RPC stats to a remote load balancer.
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
// //
// TODO: Add other non-recoverable errors? // This function should only return the errors Balancer cannot recover by itself.
// gRPC internals will fail the RPC if an error is returned.
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
// Notify returns a channel that is used by gRPC internals to watch the addresses // Notify returns a channel that is used by gRPC internals to watch the addresses
// gRPC needs to connect. The addresses might be from a name resolver or remote // gRPC needs to connect. The addresses might be from a name resolver or remote
@ -299,8 +298,19 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
} }
} }
} }
// There is no address available. Wait on rr.waitCh. if !opts.BlockingWait {
// TODO(zhaoq): Handle the case when opts.BlockingWait is false. if len(rr.addrs) == 0 {
rr.mu.Unlock()
err = fmt.Errorf("there is no address available")
return
}
// Returns the next addr on rr.addrs for failfast RPCs.
addr = rr.addrs[rr.next].addr
rr.next++
rr.mu.Unlock()
return
}
// Wait on rr.waitCh for non-failfast RPCs.
if rr.waitCh == nil { if rr.waitCh == nil {
ch = make(chan struct{}) ch = make(chan struct{})
rr.waitCh = ch rr.waitCh = ch
@ -311,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
err = transport.ContextErr(ctx.Err()) err = ctx.Err()
return return
case <-ch: case <-ch:
rr.mu.Lock() rr.mu.Lock()

View File

@ -239,7 +239,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
} }
// Remove the server. // Remove the server.
@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies. // Loop until the above update applies.
for { for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done() defer wg.Done()
var reply string var reply string
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) {
for { for {
var reply string var reply string
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -305,7 +305,7 @@ func TestGetOnWaitChannel(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
} }
}() }()

18
call.go
View File

@ -101,7 +101,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke is called by generated code. Also users can call Invoke directly when it // Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases. // is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
var c callInfo c := defaultCallInfo
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {
return toRPCErr(err) return toRPCErr(err)
@ -155,19 +155,17 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if err == ErrClientConnClosing { if _, ok := err.(rpcError); ok {
return Errorf(codes.FailedPrecondition, "%v", err) return err
} }
if _, ok := err.(transport.StreamError); ok { if err == errConnClosing {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast { if c.failFast {
return toRPCErr(err) return Errorf(codes.Unavailable, "%v", errConnClosing)
} }
continue
} }
// All the remaining cases are treated as retryable. // All the other errors are treated as Internal errors.
continue return Errorf(codes.Internal, "%v", err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)

View File

@ -422,15 +422,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
// TODO(zhaoq): Implement fail-fast logic.
addr, put, err := cc.dopts.balancer.Get(ctx, opts) addr, put, err := cc.dopts.balancer.Get(ctx, opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, toRPCErr(err)
} }
cc.mu.RLock() cc.mu.RLock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.RUnlock() cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing return nil, nil, toRPCErr(ErrClientConnClosing)
} }
ac, ok := cc.conns[addr] ac, ok := cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
@ -438,9 +437,9 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
if put != nil { if put != nil {
put() put()
} }
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
} }
t, err := ac.wait(ctx) t, err := ac.wait(ctx, !opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
@ -647,8 +646,9 @@ func (ac *addrConn) transportMonitor() {
} }
} }
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { // iv) transport is in TransientFailure and the RPC is fail-fast.
func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) {
for { for {
ac.mu.Lock() ac.mu.Lock()
switch { switch {
@ -659,6 +659,9 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
case ac.state == TransientFailure && failFast:
ac.mu.Unlock()
return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure")
default: default:
ready := ac.ready ready := ac.ready
if ready == nil { if ready == nil {
@ -668,7 +671,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
ac.mu.Unlock() ac.mu.Unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, transport.ContextErr(ctx.Err()) return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed. // Wait until the new transport is ready or failed.
case <-ready: case <-ready:
} }

View File

@ -141,6 +141,8 @@ type callInfo struct {
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
} }
var defaultCallInfo = callInfo{failFast: true}
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from
// a Call after it completes. // a Call after it completes.
type CallOption interface { type CallOption interface {
@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption {
}) })
} }
// FailFast configures the action to take when an RPC is attempted on broken
// connections or unreachable servers. If failfast is true, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will retry
// the call if it fails due to a transient error. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md
func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error {
c.failFast = failFast
return nil
})
}
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8
@ -374,6 +389,25 @@ func toRPCErr(err error) error {
code: codes.Internal, code: codes.Internal,
desc: e.Desc, desc: e.Desc,
} }
default:
switch err {
case context.DeadlineExceeded:
return rpcError{
code: codes.DeadlineExceeded,
desc: err.Error(),
}
case context.Canceled:
return rpcError{
code: codes.Canceled,
desc: err.Error(),
}
case ErrClientConnClosing:
return rpcError{
code: codes.FailedPrecondition,
desc: err.Error(),
}
}
} }
return Errorf(codes.Unknown, "%v", err) return Errorf(codes.Unknown, "%v", err)
} }

View File

@ -102,16 +102,15 @@ type ClientStream interface {
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
var ( var (
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream
err error err error
put func() put func()
) )
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. c := defaultCallInfo
gopts := BalancerGetOptions{ for _, o := range opts {
BlockingWait: false, if err := o.before(&c); err != nil {
} return nil, toRPCErr(err)
t, put, err = cc.getTransport(ctx, gopts) }
if err != nil {
return nil, toRPCErr(err)
} }
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
@ -122,8 +121,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
cs := &clientStream{ cs := &clientStream{
opts: opts,
c: c,
desc: desc, desc: desc,
put: put,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cc.dopts.cp, cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
@ -142,11 +142,44 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
ctx = trace.NewContext(ctx, cs.trInfo.tr) ctx = trace.NewContext(ctx, cs.trInfo.tr)
} }
s, err := t.NewStream(ctx, callHdr) gopts := BalancerGetOptions{
if err != nil { BlockingWait: !c.failFast,
cs.finish(err)
return nil, toRPCErr(err)
} }
for {
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(rpcError); ok {
return nil, err
}
if err == errConnClosing {
if c.failFast {
return nil, Errorf(codes.Unavailable, "%v", errConnClosing)
}
continue
}
// All the other errors are treated as Internal errors.
return nil, Errorf(codes.Internal, "%v", err)
}
s, err = t.NewStream(ctx, callHdr)
if err != nil {
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
cs.finish(err)
return nil, toRPCErr(err)
}
continue
}
return nil, toRPCErr(err)
}
break
}
cs.put = put
cs.t = t cs.t = t
cs.s = s cs.s = s
cs.p = &parser{r: s} cs.p = &parser{r: s}
@ -167,6 +200,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption
c callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
@ -312,15 +347,18 @@ func (cs *clientStream) closeTransportStream(err error) {
} }
func (cs *clientStream) finish(err error) { func (cs *clientStream) finish(err error) {
if !cs.tracing {
return
}
cs.mu.Lock() cs.mu.Lock()
defer cs.mu.Unlock() defer cs.mu.Unlock()
for _, o := range cs.opts {
o.after(&cs.c)
}
if cs.put != nil { if cs.put != nil {
cs.put() cs.put()
cs.put = nil cs.put = nil
} }
if !cs.tracing {
return
}
if cs.trInfo.tr != nil { if cs.trInfo.tr != nil {
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
cs.trInfo.tr.LazyPrintf("RPC: [OK]") cs.trInfo.tr.LazyPrintf("RPC: [OK]")

View File

@ -550,7 +550,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
cc := te.clientConn() cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err) t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
} }
te.srv.Stop() te.srv.Stop()
@ -558,12 +558,56 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
// notification in time the failure path of the 1st invoke of // notification in time the failure path of the 1st invoke of
// ClientConn.wait hits the deadline exceeded error. // ClientConn.wait hits the deadline exceeded error.
ctx, _ := context.WithTimeout(context.Background(), -1) ctx, _ := context.WithTimeout(context.Background(), -1)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
} }
awaitNewConnLogOutput() awaitNewConnLogOutput()
} }
func TestFailFast(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testFailFast(t, e)
}
}
func testFailFast(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
"grpc: Conn.resetTransport failed to create client transport: connection error",
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
)
te.startServer()
defer te.tearDown()
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
}
// Stop the server and tear down all the exisiting connections.
te.srv.Stop()
// Loop until the server teardown is propagated to the client.
for {
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable {
break
}
time.Sleep(10 * time.Millisecond)
}
// The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable.
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Unavailable)
}
if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Unavailable {
t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Unavailable)
}
awaitNewConnLogOutput()
}
func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
ctx, _ := context.WithTimeout(context.Background(), d) ctx, _ := context.WithTimeout(context.Background(), d)
hc := healthpb.NewHealthClient(cc) hc := healthpb.NewHealthClient(cc)
@ -879,7 +923,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup
ResponseSize: proto.Int32(respSize), ResponseSize: proto.Int32(respSize),
Payload: payload, Payload: payload,
} }
reply, err := tc.UnaryCall(context.Background(), req) reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false))
if err != nil { if err != nil {
t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err) t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
return return
@ -1106,9 +1150,7 @@ func testNoService(t *testing.T, e env) {
cc := te.clientConn() cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
// Make sure setting ack has been sent. stream, err := tc.FullDuplexCall(te.ctx, grpc.FailFast(false))
time.Sleep(20 * time.Millisecond)
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil { if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err) t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
} }