streams: Stop cleaning up after orphaned streams (#1854)

This change introduces some behavior changes that should not impact users that
are following the proper stream protocol. Specifically, one of the following
conditions must be satisfied:

1. The user calls Close on the ClientConn.
2. The user cancels the context provided to NewClientStream, or its deadline
    expires. (Note that it if the context is no longer needed before the deadline
    expires, it is still recommended to call cancel to prevent bloat.) It is always
    recommended to cancel contexts when they are no longer needed, and to
    never use the background context directly, so all users should always be
    doing this.
3. The user calls RecvMsg (or Recv in generated code) until a non-nil error is
    returned.
4. The user receives any error from Header or SendMsg (or Send in generated
    code) besides io.EOF.  If none of the above happen, this will leak a goroutine
    and a context, and grpc will not call the optionally-configured stats handler
    with a stats.End message.

Before this change, if a user created a stream and the server ended the stream,
the stats handler would be invoked with a stats.End containing the final status
of the stream. Subsequent calls to RecvMsg would then trigger the stats handler
with InPayloads, which may be unexpected by stats handlers.
This commit is contained in:
dfawley
2018-02-08 10:51:16 -08:00
committed by GitHub
parent 7646b5360d
commit 365770fcbd
6 changed files with 123 additions and 178 deletions

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
@ -40,17 +41,17 @@ import (
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
var (
// ErrClientConnClosing indicates that the operation is illegal because
// the ClientConn is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
// DEPRECATED: Please use context.DeadlineExceeded instead.
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
//
// Deprecated: this error should not be relied upon by users; use the status
// code of Canceled instead.
ErrClientConnClosing = status.Error(codes.Canceled, "grpc: the client connection is closing")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
@ -1374,3 +1375,10 @@ func (ac *addrConn) getState() connectivity.State {
defer ac.mu.Unlock()
return ac.state
}
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
//
// Deprecated: This error is never returned by grpc and should not be
// referenced by users.
var ErrClientConnTimeout = errors.New("grpc: timed out when dialing")

View File

@ -48,6 +48,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
@ -62,8 +65,6 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())

View File

@ -49,6 +49,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
@ -63,8 +66,6 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())

View File

@ -260,7 +260,6 @@ type rpcConfig struct {
success bool // Whether the RPC should succeed or return error.
failfast bool
callType rpcType // Type of RPC.
noLastRecv bool // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
}
func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
@ -313,15 +312,9 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
if err = stream.CloseSend(); err != nil && err != io.EOF {
return reqs, resps, err
}
if !c.noLastRecv {
if _, err = stream.Recv(); err != io.EOF {
return reqs, resps, err
}
} else {
// In the case of not calling the last recv, sleep to avoid
// returning too fast to miss the remaining stats (InTrailer and End).
time.Sleep(time.Second)
}
return reqs, resps, nil
}
@ -651,7 +644,7 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) {
actual, ok := status.FromError(st.Error)
if !ok {
t.Fatalf("expected st.Error to be a statusError, got %T", st.Error)
t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
}
expectedStatus, _ := status.FromError(e.err)
@ -1222,20 +1215,6 @@ func TestClientStatsFullDuplexRPCError(t *testing.T) {
})
}
// If the user doesn't call the last recv() on clientStream.
func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) {
count := 1
testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, count},
inHeader: {checkInHeader, 1},
inPayload: {checkInPayload, count},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
})
}
func TestTags(t *testing.T) {
b := []byte{5, 2, 4, 3, 1}
ctx := stats.SetTags(context.Background(), b)

160
stream.go
View File

@ -114,34 +114,28 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream
done func(balancer.DoneInfo)
cancel context.CancelFunc
)
c := defaultCallInfo()
mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
if mc.Timeout != nil && *mc.Timeout >= 0 {
// The cancel function for this context will only be called when RecvMsg
// returns non-nil error, which means the stream finishes with error or
// io.EOF. https://github.com/grpc/grpc-go/issues/1818.
//
// Possible context leak:
// - If user provided context is Background, and the user doesn't call
// RecvMsg() for the final status, this ctx will be leaked after the
// stream is done, until the service config timeout happens.
// The cancel function for the child context we create will only be called
// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
// an error is generated by SendMsg.
// https://github.com/grpc/grpc-go/issues/1818.
var cancel context.CancelFunc
if mc.Timeout != nil && *mc.Timeout >= 0 {
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
} else {
ctx, cancel = context.WithCancel(ctx)
}
defer func() {
if err != nil {
cancel()
}
}()
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
@ -228,6 +222,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}()
}
var (
t transport.ClientTransport
s *transport.Stream
done func(balancer.DoneInfo)
)
for {
// Check to make sure the context has expired. This will prevent us from
// looping forever if an error occurs for wait-for-ready RPCs where no data
@ -283,29 +282,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
statsHandler: cc.dopts.copts.StatsHandler,
}
if desc != unaryStreamDesc {
// Listen on s.Context().Done() to detect cancellation and s.Done() to
// detect normal termination when there is no pending I/O operations on
// this stream. Not necessary for "unary" streams, since we are guaranteed
// to always be in stream functions.
// Listen on cc and stream contexts to cleanup when the user closes the
// ClientConn or cancels the stream context. In all other cases, an error
// should already be injected into the recv buffer by the transport, which
// the client will eventually receive, and then we will cancel the stream's
// context in clientStream.finish.
go func() {
select {
case <-t.Error():
// Incur transport error, simply exit.
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
cs.closeTransportStream(ErrClientConnClosing)
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
cs.finish(s.Status().Err())
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done():
err := s.Context().Err()
cs.finish(err)
cs.closeTransportStream(transport.ContextErr(err))
case <-ctx.Done():
cs.finish(toRPCErr(s.Context().Err()))
}
}()
}
@ -337,7 +324,6 @@ type clientStream struct {
mu sync.Mutex
done func(balancer.DoneInfo)
sentLast bool // sent an end stream
closed bool
finished bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
@ -357,9 +343,8 @@ func (cs *clientStream) Context() context.Context {
func (cs *clientStream) Header() (metadata.MD, error) {
m, err := cs.s.Header()
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
cs.finish(err)
}
return m, err
}
@ -380,23 +365,18 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
// TODO Investigate how to signal the stats handling party.
// generate error stats if err != nil && err != io.EOF?
defer func() {
if err == nil {
return
}
cs.finish(err)
if err == io.EOF {
// SendMsg is only called once for non-client-streams. io.EOF needs to be
// skipped when the rpc is early finished (before the stream object is
// created.).
if !cs.desc.ClientStreams {
// For non-client-streaming RPCs, we return nil instead of EOF on success
// because the generated code requires it. finish is not called; RecvMsg()
// will call it with the stream's status independently.
if err == io.EOF && !cs.desc.ClientStreams {
err = nil
}
return
if err != nil && err != io.EOF {
// Call finish for errors generated by this SendMsg call. (Transport
// errors are converted to an io.EOF error below; the real error will be
// returned from RecvMsg eventually in that case.)
cs.finish(err)
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
}()
var outPayload *stats.OutPayload
if cs.statsHandler != nil {
@ -408,9 +388,6 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil {
return err
}
if cs.c.maxSendMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(data) > *cs.c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
}
@ -418,21 +395,21 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
cs.sentLast = true
}
err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
if err == nil && outPayload != nil {
if err == nil {
if outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
}
return err
return nil
}
return io.EOF
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
if err != nil || !cs.desc.ServerStreams {
// err != nil or non-server-streaming indicates end of stream.
cs.finish(err)
if cs.cancel != nil {
cs.cancel()
}
}
}()
var inPayload *stats.InPayload
@ -441,9 +418,6 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
Client: true,
}
}
if cs.c.maxReceiveMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
if !cs.decompSet {
// Block until we receive headers containing received message encoding.
if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
@ -462,14 +436,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
if err == io.EOF {
if statusErr := cs.s.Status().Err(); statusErr != nil {
return statusErr
}
return io.EOF // indicates end of stream.
return io.EOF // indicates successful end of stream.
}
return toRPCErr(err)
}
@ -484,22 +455,18 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
}
if cs.desc.ServerStreams {
// Subsequent messages should be received by subsequent RecvMsg calls.
return nil
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
if se := cs.s.Status().Err(); se != nil {
return se
}
cs.finish(err)
return nil
return cs.s.Status().Err() // non-server streaming Recv returns nil on success
}
return toRPCErr(err)
}
@ -509,41 +476,26 @@ func (cs *clientStream) CloseSend() error {
return nil
}
cs.sentLast = true
err := cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
if err == nil || err == io.EOF {
cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
// We ignore errors from Write and always return nil here. Any error it
// would return would also be returned by a subsequent RecvMsg call, and the
// user is supposed to always finish the stream by calling RecvMsg until it
// returns err != nil.
return nil
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
if err != nil {
cs.finish(err)
}
return err
}
func (cs *clientStream) closeTransportStream(err error) {
cs.mu.Lock()
if cs.closed {
cs.mu.Unlock()
return
}
cs.closed = true
cs.mu.Unlock()
cs.t.CloseStream(cs.s, err)
}
func (cs *clientStream) finish(err error) {
// Do not call cs.cancel in this function. Only call it when RecvMag()
// returns non-nil error because of
// https://github.com/grpc/grpc-go/issues/1818.
if err == io.EOF {
// Ending a stream with EOF indicates a success.
err = nil
}
cs.mu.Lock()
defer cs.mu.Unlock()
if cs.finished {
return
}
cs.finished = true
cs.t.CloseStream(cs.s, err)
for _, o := range cs.opts {
o.after(cs.c)
}
@ -559,18 +511,16 @@ func (cs *clientStream) finish(err error) {
end := &stats.End{
Client: true,
EndTime: time.Now(),
}
if err != io.EOF {
// end.Error is nil if the RPC finished successfully.
end.Error = toRPCErr(err)
Error: err,
}
cs.statsHandler.HandleRPC(cs.statsCtx, end)
}
cs.cancel()
if !cs.tracing {
return
}
if cs.trInfo.tr != nil {
if err == nil || err == io.EOF {
if err == nil {
cs.trInfo.tr.LazyPrintf("RPC: [OK]")
} else {
cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)

View File

@ -1164,10 +1164,22 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(req); err == nil {
if _, err := stream.Recv(); err == nil {
t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
sendStart := time.Now()
for {
if err := stream.Send(req); err == io.EOF {
// stream.Send should eventually send io.EOF
break
} else if err != nil {
// Send should never return a transport-level error.
t.Fatalf("stream.Send(%v) = %v; want <nil or io.EOF>", req, err)
}
if time.Since(sendStart) > 2*time.Second {
t.Fatalf("stream.Send(_) did not return io.EOF after 2s")
}
time.Sleep(time.Millisecond)
}
if _, err := stream.Recv(); err == nil || err == io.EOF {
t.Fatalf("%v.Recv() = _, %v, want _, <non-nil, non-EOF>", stream, err)
}
<-ch
awaitNewConnLogOutput()
@ -1190,7 +1202,9 @@ func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.FullDuplexCall(context.Background()); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if _, err := tc.FullDuplexCall(ctx); err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err)
}
done := make(chan struct{})
@ -2334,24 +2348,6 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
}
func TestErrorChanNoIO(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
testErrorChanNoIO(t, e)
}
}
func testErrorChanNoIO(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
if _, err := tc.FullDuplexCall(context.Background()); err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
}
func TestEmptyUnaryWithUserAgent(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
@ -3811,22 +3807,24 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.StreamingInputCall(context.Background()); err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, <nil>", err)
}
// Loop until the new max stream setting is effective.
for {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := tc.StreamingInputCall(ctx)
cancel()
if err == nil {
time.Sleep(50 * time.Millisecond)
time.Sleep(5 * time.Millisecond)
continue
}
if status.Code(err) == codes.DeadlineExceeded {
break
}
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %s", err, codes.DeadlineExceeded)
}
var wg sync.WaitGroup
@ -3848,11 +3846,19 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); status.Code(err) != codes.DeadlineExceeded {
t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
t.Errorf("tc.UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
}
}()
}
wg.Wait()
cancel()
// A new stream should be allowed after canceling the first one.
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %v", err, nil)
}
}
func TestCompressServerHasNoSupport(t *testing.T) {