diff --git a/call.go b/call.go index 368980d4..744514c3 100644 --- a/call.go +++ b/call.go @@ -29,7 +29,6 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" - "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -208,59 +207,48 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Last: true, Delay: false, } + callHdr := &transport.CallHdr{ + Host: cc.authority, + Method: method, + } + if cc.dopts.cp != nil { + callHdr.SendCompress = cc.dopts.cp.Type() + } + if c.creds != nil { + callHdr.Creds = c.creds + } + if c.compressorType != "" { + callHdr.SendCompress = c.compressorType + } else if cc.dopts.cp != nil { + callHdr.SendCompress = cc.dopts.cp.Type() + } + firstAttempt := true + for { - var ( - err error - t transport.ClientTransport - stream *transport.Stream - // Record the done handler from Balancer.Get(...). It is called once the - // RPC has completed or failed. - done func(balancer.DoneInfo) - ) - // TODO(zhaoq): Need a formal spec of fail-fast. - callHdr := &transport.CallHdr{ - Host: cc.authority, - Method: method, - } - if c.compressorType != "" { - callHdr.SendCompress = c.compressorType - } else if cc.dopts.cp != nil { - callHdr.SendCompress = cc.dopts.cp.Type() - } - if c.creds != nil { - callHdr.Creds = c.creds + // Check to make sure the context has expired. This will prevent us from + // looping forever if an error occurs for wait-for-ready RPCs where no data + // is sent on the wire. + select { + case <-ctx.Done(): + return toRPCErr(ctx.Err()) + default: } - t, done, err = cc.getTransport(ctx, c.failFast) + // Record the done handler from Balancer.Get(...). It is called once the + // RPC has completed or failed. + t, done, err := cc.getTransport(ctx, c.failFast) if err != nil { - // TODO(zhaoq): Probably revisit the error handling. - if _, ok := status.FromError(err); ok { - return err - } - if err == errConnClosing || err == errConnUnavailable { - if c.failFast { - return Errorf(codes.Unavailable, "%v", err) - } - continue - } - // All the other errors are treated as Internal errors. - return Errorf(codes.Internal, "%v", err) + return err } - if c.traceInfo.tr != nil { - c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) - } - stream, err = t.NewStream(ctx, callHdr) + stream, err := t.NewStream(ctx, callHdr) if err != nil { if done != nil { - if _, ok := err.(transport.ConnectionError); ok { - // If error is connection error, transport was sending data on wire, - // and we are not sure if anything has been sent on wire. - // If error is not connection error, we are sure nothing has been sent. - updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) - } done(balancer.DoneInfo{Err: err}) } - if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { + // In the event of any error from NewStream, we never attempted to write + // anything to the wire, so we can retry indefinitely for non-fail-fast + // RPCs. + if !c.failFast { continue } return toRPCErr(err) @@ -268,20 +256,30 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if peer, ok := peer.FromContext(stream.Context()); ok { c.peer = peer } + if c.traceInfo.tr != nil { + c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) + } err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) if err != nil { if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: stream.BytesSent(), + bytesSent: true, bytesReceived: stream.BytesReceived(), }) done(balancer.DoneInfo{Err: err}) } // Retry a non-failfast RPC when - // i) there is a connection error; or - // ii) the server started to drain before this RPC was initiated. - if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { - continue + // i) the server started to drain before this RPC was initiated. + // ii) the server refused the stream. + if !c.failFast && stream.Unprocessed() { + // In this case, the server did not receive the data, but we still + // created wire traffic, so we should not retry indefinitely. + if firstAttempt { + // TODO: Add a field to header for grpc-transparent-retry-attempts + firstAttempt = false + continue + } + // Otherwise, give up and return an error anyway. } return toRPCErr(err) } @@ -289,13 +287,20 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if err != nil { if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: stream.BytesSent(), + bytesSent: true, bytesReceived: stream.BytesReceived(), }) done(balancer.DoneInfo{Err: err}) } - if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { - continue + if !c.failFast && stream.Unprocessed() { + // In these cases, the server did not receive the data, but we still + // created wire traffic, so we should not retry indefinitely. + if firstAttempt { + // TODO: Add a field to header for grpc-transparent-retry-attempts + firstAttempt = false + continue + } + // Otherwise, give up and return an error anyway. } return toRPCErr(err) } @@ -305,11 +310,20 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t.CloseStream(stream, nil) if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: stream.BytesSent(), + bytesSent: true, bytesReceived: stream.BytesReceived(), }) done(balancer.DoneInfo{Err: err}) } + if !c.failFast && stream.Unprocessed() { + // In these cases, the server did not receive the data, but we still + // created wire traffic, so we should not retry indefinitely. + if firstAttempt { + // TODO: Add a field to header for grpc-transparent-retry-attempts + firstAttempt = false + continue + } + } return stream.Status().Err() } } diff --git a/clientconn.go b/clientconn.go index 8a4e8826..c374a595 100644 --- a/clientconn.go +++ b/clientconn.go @@ -51,7 +51,20 @@ var ( // underlying connections within the specified timeout. // DEPRECATED: Please use context.DeadlineExceeded instead. ErrClientConnTimeout = errors.New("grpc: timed out when dialing") + // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. + errConnDrain = errors.New("grpc: the connection is drained") + // errConnClosing indicates that the connection is closing. + errConnClosing = errors.New("grpc: the connection is closing") + // errConnUnavailable indicates that the connection is unavailable. + errConnUnavailable = errors.New("grpc: the connection is unavailable") + // errBalancerClosed indicates that the balancer is closed. + errBalancerClosed = errors.New("grpc: balancer is closed") + // minimum time to give a connection to complete + minConnectTimeout = 20 * time.Second +) +// The following errors are returned from Dial and DialContext +var ( // errNoTransportSecurity indicates that there is no transport security // being set for ClientConn. Users should either set one or explicitly // call WithInsecure DialOption to disable security. @@ -65,16 +78,6 @@ var ( errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") // errNetworkIO indicates that the connection is down due to some network I/O error. errNetworkIO = errors.New("grpc: failed with network I/O error") - // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. - errConnDrain = errors.New("grpc: the connection is drained") - // errConnClosing indicates that the connection is closing. - errConnClosing = errors.New("grpc: the connection is closing") - // errConnUnavailable indicates that the connection is unavailable. - errConnUnavailable = errors.New("grpc: the connection is unavailable") - // errBalancerClosed indicates that the balancer is closed. - errBalancerClosed = errors.New("grpc: balancer is closed") - // minimum time to give a connection to complete - minConnectTimeout = 20 * time.Second ) // dialOptions configure a Dial call. dialOptions are set by the DialOption @@ -879,7 +882,7 @@ type addrConn struct { // receiving a GoAway. func (ac *addrConn) adjustParams(r transport.GoAwayReason) { switch r { - case transport.TooManyPings: + case transport.GoAwayTooManyPings: v := 2 * ac.dopts.copts.KeepaliveParams.Time ac.cc.mu.Lock() if v > ac.cc.mkp.Time { diff --git a/internal/internal.go b/internal/internal.go index 07083832..53f17752 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -19,13 +19,6 @@ // the godoc of the top-level grpc package. package internal -// TestingCloseConns closes all existing transports but keeps -// grpcServer.lis accepting new connections. -// -// The provided grpcServer must be of type *grpc.Server. It is untyped -// for circular dependency reasons. -var TestingCloseConns func(grpcServer interface{}) - // TestingUseHandlerImpl enables the http.Handler-based server implementation. // It must be called before Serve and requires TLS credentials. // diff --git a/rpc_util.go b/rpc_util.go index eccf84de..7c39ed15 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -195,11 +195,15 @@ func Peer(peer *peer.Peer) CallOption { } // FailFast configures the action to take when an RPC is attempted on broken -// connections or unreachable servers. If failFast is true, the RPC will fail +// connections or unreachable servers. If failFast is true, the RPC will fail // immediately. Otherwise, the RPC client will block the call until a -// connection is available (or the call is canceled or times out). Please refer -// to https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md. -// The default behavior of RPCs is to fail fast. +// connection is available (or the call is canceled or times out) and will +// retry the call if it fails due to a transient error. gRPC will not retry if +// data was written to the wire unless the server indicates it did not process +// the data. Please refer to +// https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md. +// +// By default, RPCs are "Fail Fast". func FailFast(failFast bool) CallOption { return beforeCall(func(c *callInfo) error { c.failFast = failFast diff --git a/server.go b/server.go index 8cc29152..02316e8b 100644 --- a/server.go +++ b/server.go @@ -1184,25 +1184,11 @@ func (s *Server) GracefulStop() { } func init() { - internal.TestingCloseConns = func(arg interface{}) { - arg.(*Server).testingCloseConns() - } internal.TestingUseHandlerImpl = func(arg interface{}) { arg.(*Server).opts.useHandlerImpl = true } } -// testingCloseConns closes all existing transports but keeps s.lis -// accepting new connections. -func (s *Server) testingCloseConns() { - s.mu.Lock() - for c := range s.conns { - c.Close() - delete(s.conns, c) - } - s.mu.Unlock() -} - // SetHeader sets the header metadata. // When called multiple times, all the provided metadata will be merged. // All the metadata will be sent out when one of the following happens: diff --git a/stream.go b/stream.go index a659f14e..44547b79 100644 --- a/stream.go +++ b/stream.go @@ -199,42 +199,39 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } + for { + // Check to make sure the context has expired. This will prevent us from + // looping forever if an error occurs for wait-for-ready RPCs where no data + // is sent on the wire. + select { + case <-ctx.Done(): + return nil, toRPCErr(ctx.Err()) + default: + } + t, done, err = cc.getTransport(ctx, c.failFast) if err != nil { - // TODO(zhaoq): Probably revisit the error handling. - if _, ok := status.FromError(err); ok { - return nil, err - } - if err == errConnClosing || err == errConnUnavailable { - if c.failFast { - return nil, Errorf(codes.Unavailable, "%v", err) - } - continue - } - // All the other errors are treated as Internal errors. - return nil, Errorf(codes.Internal, "%v", err) + return nil, err } s, err = t.NewStream(ctx, callHdr) if err != nil { - if _, ok := err.(transport.ConnectionError); ok && done != nil { - // If error is connection error, transport was sending data on wire, - // and we are not sure if anything has been sent on wire. - // If error is not connection error, we are sure nothing has been sent. - updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) - } if done != nil { done(balancer.DoneInfo{Err: err}) done = nil } - if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { + // In the event of any error from NewStream, we never attempted to write + // anything to the wire, so we can retry indefinitely for non-fail-fast + // RPCs. + if !c.failFast { continue } return nil, toRPCErr(err) } break } + // Set callInfo.peer object from stream's context. if peer, ok := peer.FromContext(s.Context()); ok { c.peer = peer @@ -260,8 +257,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth statsCtx: ctx, statsHandler: cc.dopts.copts.StatsHandler, } - // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination - // when there is no pending I/O operations on this stream. + // Listen on s.Context().Done() to detect cancellation and s.Done() to detect + // normal termination when there is no pending I/O operations on this stream. go func() { select { case <-t.Error(): @@ -502,7 +499,7 @@ func (cs *clientStream) finish(err error) { } if cs.done != nil { updateRPCInfoInContext(cs.s.Context(), rpcInfo{ - bytesSent: cs.s.BytesSent(), + bytesSent: true, bytesReceived: cs.s.BytesReceived(), }) cs.done(balancer.DoneInfo{Err: err}) diff --git a/test/end2end_test.go b/test/end2end_test.go index 77945734..e2d1d2d6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -818,7 +818,17 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } te.srv.Stop() - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + + // Wait for the client to notice the connection is gone. + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + state := cc.GetState() + for ; state == connectivity.Ready && cc.WaitForStateChange(ctx, state); state = cc.GetState() { + } + cancel() + if state == connectivity.Ready { + t.Fatalf("Timed out waiting for non-ready state") + } + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond) _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)) cancel() if e.balancer != "" && grpc.Code(err) != codes.DeadlineExceeded { @@ -932,7 +942,9 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { close(ch) }() // Loop until the server side GoAway signal is propagated to the client. - for { + abort := false + time.AfterFunc(time.Second, func() { abort = true }) + for !abort { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil { cancel() @@ -940,11 +952,11 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { } cancel() } - respParam := []*testpb.ResponseParameters{ - { - Size: 1, - }, + // Don't bother stopping the timer; it will have no effect past here. + if abort { + t.Fatalf("GoAway never received by client") } + respParam := []*testpb.ResponseParameters{{Size: 1}} payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100)) if err != nil { t.Fatal(err) @@ -961,6 +973,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { if _, err := stream.Recv(); err != nil { t.Fatalf("%v.Recv() = _, %v, want _, ", stream, err) } + // The RPC will run until canceled. cancel() <-ch awaitNewConnLogOutput() @@ -1217,18 +1230,22 @@ func testFailFast(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } // Stop the server and tear down all the exisiting connections. te.srv.Stop() // Loop until the server teardown is propagated to the client. for { - _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err := tc.EmptyCall(ctx, &testpb.Empty{}) + cancel() if grpc.Code(err) == codes.Unavailable { break } - fmt.Printf("%v.EmptyCall(_, _) = _, %v", tc, err) + t.Logf("%v.EmptyCall(_, _) = _, %v", tc, err) time.Sleep(10 * time.Millisecond) } // The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable. @@ -3006,35 +3023,6 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { } } -func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup) { - defer wg.Done() - const argSize = 2718 - const respSize = 314 - - payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) - if err != nil { - t.Error(err) - return - } - - req := &testpb.SimpleRequest{ - ResponseType: testpb.PayloadType_COMPRESSABLE, - ResponseSize: respSize, - Payload: payload, - } - reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false)) - if err != nil { - t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) - return - } - pt := reply.GetPayload().GetType() - ps := len(reply.GetPayload().GetBody()) - if pt != testpb.PayloadType_COMPRESSABLE || ps != respSize { - t.Errorf("Got reply with type %d len %d; want %d, %d", pt, ps, testpb.PayloadType_COMPRESSABLE, respSize) - return - } -} - func TestRetry(t *testing.T) { defer leakcheck.Check(t) for _, e := range listTestEnv() { @@ -3046,49 +3034,54 @@ func TestRetry(t *testing.T) { } } -// This test mimics a user who sends 1000 RPCs concurrently on a faulty transport. -// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy -// and error-prone paths. +// This test make sure RPCs are retried times when they receive a RST_STREAM +// with the REFUSED_STREAM error code, which the InTapHandle provokes. func testRetry(t *testing.T, e env) { te := newTest(t, e) - te.declareLogNoise("transport: http2Client.notifyError got notified that the client transport was broken") + attempts := 0 + successAttempt := 2 + te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { + attempts++ + if attempts < successAttempt { + return nil, errors.New("not now") + } + return ctx, nil + } te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - var wg sync.WaitGroup + tsc := testpb.NewTestServiceClient(cc) + testCases := []struct { + successAttempt int + failFast bool + errCode codes.Code + }{{ + successAttempt: 1, + }, { + successAttempt: 2, + }, { + successAttempt: 3, + errCode: codes.Unavailable, + }, { + successAttempt: 1, + failFast: true, + }, { + successAttempt: 2, + failFast: true, + errCode: codes.Unavailable, // We won't retry on fail fast. + }} + for _, tc := range testCases { + attempts = 0 + successAttempt = tc.successAttempt - numRPC := 1000 - rpcSpacing := 2 * time.Millisecond - if raceMode { - // The race detector has a limit on how many goroutines it can track. - // This test is near the upper limit, and goes over the limit - // depending on the environment (the http.Handler environment uses - // more goroutines) - t.Logf("Shortening test in race mode.") - numRPC /= 2 - rpcSpacing *= 2 + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(tc.failFast)) + cancel() + if grpc.Code(err) != tc.errCode { + t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode) + } } - - wg.Add(1) - go func() { - // Halfway through starting RPCs, kill all connections: - time.Sleep(time.Duration(numRPC/2) * rpcSpacing) - - // The server shuts down the network connection to make a - // transport error which will be detected by the client side - // code. - internal.TestingCloseConns(te.srv) - wg.Done() - }() - // All these RPCs should succeed eventually. - for i := 0; i < numRPC; i++ { - time.Sleep(rpcSpacing) - wg.Add(1) - go performOneRPC(t, tc, &wg) - } - wg.Wait() } func TestRPCTimeout(t *testing.T) { diff --git a/transport/control.go b/transport/control.go index 8aee96b7..8bfa6c3d 100644 --- a/transport/control.go +++ b/transport/control.go @@ -195,7 +195,7 @@ func (qb *quotaPool) get(v int, wc waiters) (int, uint32, error) { case <-wc.done: return 0, 0, io.EOF case <-wc.goAway: - return 0, 0, ErrStreamDrain + return 0, 0, errStreamDrain case <-qb.c: qb.mu.Lock() if qb.quota > 0 { diff --git a/transport/http2_client.go b/transport/http2_client.go index 74c1ee62..f6bd24a0 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -410,7 +410,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if t.state == draining { t.mu.Unlock() - return nil, ErrStreamDrain + return nil, errStreamDrain } if t.state != reachable { t.mu.Unlock() @@ -481,7 +481,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if t.state == draining { t.mu.Unlock() t.streamsQuota.add(1) - return nil, ErrStreamDrain + return nil, errStreamDrain } if t.state != reachable { t.mu.Unlock() @@ -509,10 +509,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea }) t.mu.Unlock() - s.mu.Lock() - s.bytesSent = true - s.mu.Unlock() - if t.statsHandler != nil { outHeader := &stats.OutHeader{ Client: true, @@ -586,16 +582,16 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // Close kicks off the shutdown process of the transport. This should be called // only once on a transport. Once it is called, the transport should not be // accessed any more. -func (t *http2Client) Close() (err error) { +func (t *http2Client) Close() error { t.mu.Lock() if t.state == closing { t.mu.Unlock() - return + return nil } t.state = closing t.mu.Unlock() t.cancel() - err = t.conn.Close() + err := t.conn.Close() t.mu.Lock() streams := t.activeStreams t.activeStreams = nil @@ -901,7 +897,13 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { close(s.headerChan) s.headerDone = true } - statusCode, ok := http2ErrConvTab[http2.ErrCode(f.ErrCode)] + + code := http2.ErrCode(f.ErrCode) + if code == http2.ErrCodeRefusedStream { + // The stream was unprocessed by the server. + s.unprocessed = true + } + statusCode, ok := http2ErrConvTab[code] if !ok { warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) statusCode = codes.Unknown @@ -983,12 +985,16 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.Close() return } - // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). - // The idea is that the first GoAway will be sent with an ID of MaxInt32 and the second GoAway will be sent after an RTT delay - // with the ID of the last stream the server will process. - // Therefore, when we get the first GoAway we don't really close any streams. While in case of second GoAway we - // close all streams created after the second GoAwayId. This way streams that were in-flight while the GoAway from server - // was being sent don't get killed. + // A client can recieve multiple GoAways from the server (see + // https://github.com/grpc/grpc-go/issues/1387). The idea is that the first + // GoAway will be sent with an ID of MaxInt32 and the second GoAway will be + // sent after an RTT delay with the ID of the last stream the server will + // process. + // + // Therefore, when we get the first GoAway we don't necessarily close any + // streams. While in case of second GoAway we close all streams created after + // the GoAwayId. This way streams that were in-flight while the GoAway from + // server was being sent don't get killed. select { case <-t.goAway: // t.goAway has been closed (i.e.,multiple GoAways). // If there are multiple GoAways the first one should always have an ID greater than the following ones. @@ -1010,6 +1016,11 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { } for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { + // The stream was unprocessed by the server. + stream.mu.Lock() + stream.unprocessed = true + stream.finish(statusGoAway) + stream.mu.Unlock() close(stream.goAway) } } @@ -1026,11 +1037,11 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // It expects a lock on transport's mutext to be held by // the caller. func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { - t.goAwayReason = NoReason + t.goAwayReason = GoAwayNoReason switch f.ErrCode { case http2.ErrCodeEnhanceYourCalm: if string(f.DebugData()) == "too_many_pings" { - t.goAwayReason = TooManyPings + t.goAwayReason = GoAwayTooManyPings } } } diff --git a/transport/transport.go b/transport/transport.go index 9a09aa9a..d48e0611 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -17,7 +17,8 @@ */ // Package transport defines and implements message oriented communication -// channel to complete various transactions (e.g., an RPC). +// channel to complete various transactions (e.g., an RPC). It is meant for +// grpc-internal usage and is not intended to be imported directly by users. package transport // import "google.golang.org/grpc/transport" import ( @@ -133,7 +134,7 @@ func (r *recvBufferReader) read(p []byte) (n int, err error) { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) case <-r.goAway: - return 0, ErrStreamDrain + return 0, errStreamDrain case m := <-r.recv.get(): r.recv.load() if m.err != nil { @@ -210,19 +211,13 @@ const ( // Stream represents an RPC in the transport layer. type Stream struct { - id uint32 - // nil for client side Stream. - st ServerTransport - // ctx is the associated context of the stream. - ctx context.Context - // cancel is always nil for client side Stream. - cancel context.CancelFunc - // done is closed when the final status arrives. - done chan struct{} - // goAway is closed when the server sent GoAways signal before this stream was initiated. - goAway chan struct{} - // method records the associated RPC method of the stream. - method string + id uint32 + st ServerTransport // nil for client side Stream + ctx context.Context // the associated context of the stream + cancel context.CancelFunc // always nil for client side Stream + done chan struct{} // closed when the final status arrives + goAway chan struct{} // closed when a GOAWAY control message is received + method string // the associated RPC method of the stream recvCompress string sendCompress string buf *recvBuffer @@ -231,40 +226,27 @@ type Stream struct { recvQuota uint32 waiters waiters - // TODO: Remote this unused variable. - // The accumulated inbound quota pending for window update. - updateQuota uint32 - // Callback to state application's intentions to read data. This - // is used to adjust flow control, if need be. + // is used to adjust flow control, if needed. requestRead func(int) sendQuotaPool *quotaPool - // Close headerChan to indicate the end of reception of header metadata. - headerChan chan struct{} - // header caches the received header metadata. - header metadata.MD - // The key-value map of trailer metadata. - trailer metadata.MD + headerChan chan struct{} // closed to indicate the end of header metadata. + headerDone bool // set when headerChan is closed. Used to avoid closing headerChan multiple times. + header metadata.MD // the received header metadata. + trailer metadata.MD // the key-value map of trailer metadata. - mu sync.RWMutex // guard the following - // headerOK becomes true from the first header is about to send. - headerOk bool + mu sync.RWMutex // guard the following + headerOk bool // becomes true from the first header is about to send state streamState - // true iff headerChan is closed. Used to avoid closing headerChan - // multiple times. - headerDone bool - // the status error received from the server. - status *status.Status - // rstStream indicates whether a RST_STREAM frame needs to be sent - // to the server to signify that this stream is closing. - rstStream bool - // rstError is the error that needs to be sent along with the RST_STREAM frame. - rstError http2.ErrCode - // bytesSent and bytesReceived indicates whether any bytes have been sent or - // received on this stream. - bytesSent bool - bytesReceived bool + + status *status.Status // the status error received from the server + + rstStream bool // indicates whether a RST_STREAM frame needs to be sent + rstError http2.ErrCode // the error that needs to be sent along with the RST_STREAM frame + + bytesReceived bool // indicates whether any bytes have been received on this stream + unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream } // RecvCompress returns the compression algorithm applied to the inbound @@ -299,7 +281,7 @@ func (s *Stream) Header() (metadata.MD, error) { case <-s.ctx.Done(): err = ContextErr(s.ctx.Err()) case <-s.goAway: - err = ErrStreamDrain + err = errStreamDrain case <-s.headerChan: return s.header.Copy(), nil } @@ -416,14 +398,6 @@ func (s *Stream) finish(st *status.Status) { close(s.done) } -// BytesSent indicates whether any bytes have been sent on this stream. -func (s *Stream) BytesSent() bool { - s.mu.Lock() - bs := s.bytesSent - s.mu.Unlock() - return bs -} - // BytesReceived indicates whether any bytes have been received on this stream. func (s *Stream) BytesReceived() bool { s.mu.Lock() @@ -432,6 +406,15 @@ func (s *Stream) BytesReceived() bool { return br } +// Unprocessed indicates whether the server did not process this stream -- +// i.e. it sent a refused stream or GOAWAY including this stream ID. +func (s *Stream) Unprocessed() bool { + s.mu.Lock() + br := s.unprocessed + s.mu.Unlock() + return br +} + // GoString is implemented by Stream so context.String() won't // race when printing %#v. func (s *Stream) GoString() string { @@ -686,9 +669,13 @@ func (e ConnectionError) Origin() error { var ( // ErrConnClosing indicates that the transport is closing. ErrConnClosing = connectionErrorf(true, nil, "transport is closing") - // ErrStreamDrain indicates that the stream is rejected by the server because + // errStreamDrain indicates that the stream is rejected by the server because // the server stops accepting new RPCs. - ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs") + // TODO: delete this error; it is no longer necessary. + errStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs") + // StatusGoAway indicates that the server sent a GOAWAY that included this + // stream's ID in unprocessed RPCs. + statusGoAway = status.New(codes.Unavailable, "the server stopped accepting new RPCs") ) // TODO: See if we can replace StreamError with status package errors. @@ -716,13 +703,14 @@ type waiters struct { type GoAwayReason uint8 const ( - // Invalid indicates that no GoAway frame is received. - Invalid GoAwayReason = 0 - // NoReason is the default value when GoAway frame is received. - NoReason GoAwayReason = 1 - // TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm - // was received and that the debug data said "too_many_pings". - TooManyPings GoAwayReason = 2 + // GoAwayInvalid indicates that no GoAway frame is received. + GoAwayInvalid GoAwayReason = 0 + // GoAwayNoReason is the default value when GoAway frame is received. + GoAwayNoReason GoAwayReason = 1 + // GoAwayTooManyPings indicates that a GoAway frame with + // ErrCodeEnhanceYourCalm was received and that the debug data said + // "too_many_pings". + GoAwayTooManyPings GoAwayReason = 2 ) // loopyWriter is run in a separate go routine. It is the single code path that will diff --git a/transport/transport_test.go b/transport/transport_test.go index c9b3e656..9ffdd486 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -1019,8 +1019,8 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) + if _, err := ct.NewStream(context.Background(), callHdr); err != errStreamDrain { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, errStreamDrain) } }() }