diff --git a/clientconn.go b/clientconn.go
index 19ee8e01..2c22d628 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -32,6 +32,7 @@ import (
 	"golang.org/x/net/trace"
 	"google.golang.org/grpc/balancer"
 	_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/grpclog"
@@ -40,17 +41,17 @@ import (
 	_ "google.golang.org/grpc/resolver/dns"         // To register dns resolver.
 	_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
 	"google.golang.org/grpc/stats"
+	"google.golang.org/grpc/status"
 	"google.golang.org/grpc/transport"
 )
 
 var (
 	// ErrClientConnClosing indicates that the operation is illegal because
 	// the ClientConn is closing.
-	ErrClientConnClosing = errors.New("grpc: the client connection is closing")
-	// ErrClientConnTimeout indicates that the ClientConn cannot establish the
-	// underlying connections within the specified timeout.
-	// DEPRECATED: Please use context.DeadlineExceeded instead.
-	ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
+	//
+	// Deprecated: this error should not be relied upon by users; use the status
+	// code of Canceled instead.
+	ErrClientConnClosing = status.Error(codes.Canceled, "grpc: the client connection is closing")
 	// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
 	errConnDrain = errors.New("grpc: the connection is drained")
 	// errConnClosing indicates that the connection is closing.
@@ -1374,3 +1375,10 @@ func (ac *addrConn) getState() connectivity.State {
 	defer ac.mu.Unlock()
 	return ac.state
 }
+
+// ErrClientConnTimeout indicates that the ClientConn cannot establish the
+// underlying connections within the specified timeout.
+//
+// Deprecated: This error is never returned by grpc and should not be
+// referenced by users.
+var ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
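For illustration only, not part of the patch: with ErrClientConnClosing now built on the status package, callers are expected to branch on the Canceled status code rather than compare against the sentinel value. A minimal sketch of that pattern, assuming the generated grpc_testing stubs that this repository's tests already use:

package example

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	testpb "google.golang.org/grpc/test/grpc_testing"
)

func callOnce(ctx context.Context, cc *grpc.ClientConn) {
	tc := testpb.NewTestServiceClient(cc)
	_, err := tc.EmptyCall(ctx, &testpb.Empty{})
	switch status.Code(err) {
	case codes.OK:
		log.Println("RPC succeeded")
	case codes.Canceled:
		// Covers both a canceled context and a closing ClientConn, since
		// ErrClientConnClosing now carries codes.Canceled.
		log.Println("RPC canceled:", err)
	default:
		log.Println("RPC failed:", err)
	}
}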
diff --git a/go16.go b/go16.go
index f3dbf217..0ae4dbda 100644
--- a/go16.go
+++ b/go16.go
@@ -48,6 +48,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
 
 // toRPCErr converts an error into an error from the status package.
 func toRPCErr(err error) error {
+	if err == nil || err == io.EOF {
+		return err
+	}
 	if _, ok := status.FromError(err); ok {
 		return err
 	}
@@ -62,8 +65,6 @@ func toRPCErr(err error) error {
 			return status.Error(codes.DeadlineExceeded, err.Error())
 		case context.Canceled:
 			return status.Error(codes.Canceled, err.Error())
-		case ErrClientConnClosing:
-			return status.Error(codes.FailedPrecondition, err.Error())
 		}
 	}
 	return status.Error(codes.Unknown, err.Error())
diff --git a/go17.go b/go17.go
index de23098e..53908828 100644
--- a/go17.go
+++ b/go17.go
@@ -49,6 +49,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
 
 // toRPCErr converts an error into an error from the status package.
 func toRPCErr(err error) error {
+	if err == nil || err == io.EOF {
+		return err
+	}
 	if _, ok := status.FromError(err); ok {
 		return err
 	}
@@ -63,8 +66,6 @@ func toRPCErr(err error) error {
 			return status.Error(codes.DeadlineExceeded, err.Error())
 		case context.Canceled, netctx.Canceled:
 			return status.Error(codes.Canceled, err.Error())
-		case ErrClientConnClosing:
-			return status.Error(codes.FailedPrecondition, err.Error())
 		}
 	}
 	return status.Error(codes.Unknown, err.Error())
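Illustrative sketch, not part of the patch: toRPCErr must now return nil and io.EOF unchanged because callers detect normal end-of-stream by comparing the error against io.EOF by identity; converting it into a status error would break that check. The usual client receive loop that depends on this, written against the grpc_testing stubs:

package example

import (
	"io"
	"log"

	testpb "google.golang.org/grpc/test/grpc_testing"
)

func drainStream(stream testpb.TestService_FullDuplexCallClient) error {
	for {
		msg, err := stream.Recv()
		if err == io.EOF {
			// Normal termination: the server closed the stream with an OK status.
			return nil
		}
		if err != nil {
			// Anything else already carries a gRPC status code.
			return err
		}
		log.Printf("got message: %v", msg)
	}
}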
diff --git a/stats/stats_test.go b/stats/stats_test.go
index fef0d7c6..b6c7b998 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -256,11 +256,10 @@ const (
 )
 
 type rpcConfig struct {
-	count      int  // Number of requests and responses for streaming RPCs.
-	success    bool // Whether the RPC should succeed or return error.
-	failfast   bool
-	callType   rpcType // Type of RPC.
-	noLastRecv bool    // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
+	count    int  // Number of requests and responses for streaming RPCs.
+	success  bool // Whether the RPC should succeed or return error.
+	failfast bool
+	callType rpcType // Type of RPC.
 }
 
 func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
@@ -313,14 +312,8 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
 	if err = stream.CloseSend(); err != nil && err != io.EOF {
 		return reqs, resps, err
 	}
-	if !c.noLastRecv {
-		if _, err = stream.Recv(); err != io.EOF {
-			return reqs, resps, err
-		}
-	} else {
-		// In the case of not calling the last recv, sleep to avoid
-		// returning too fast to miss the remaining stats (InTrailer and End).
-		time.Sleep(time.Second)
+	if _, err = stream.Recv(); err != io.EOF {
+		return reqs, resps, err
 	}
 
 	return reqs, resps, nil
@@ -651,7 +644,7 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) {
 
 	actual, ok := status.FromError(st.Error)
 	if !ok {
-		t.Fatalf("expected st.Error to be a statusError, got %T", st.Error)
+		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
 	}
 
 	expectedStatus, _ := status.FromError(e.err)
@@ -1222,20 +1215,6 @@ func TestClientStatsFullDuplexRPCError(t *testing.T) {
 	})
 }
 
-// If the user doesn't call the last recv() on clientStream.
-func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) {
-	count := 1
-	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{
-		begin:      {checkBegin, 1},
-		outHeader:  {checkOutHeader, 1},
-		outPayload: {checkOutPayload, count},
-		inHeader:   {checkInHeader, 1},
-		inPayload:  {checkInPayload, count},
-		inTrailer:  {checkInTrailer, 1},
-		end:        {checkEnd, 1},
-	})
-}
-
 func TestTags(t *testing.T) {
 	b := []byte{5, 2, 4, 3, 1}
 	ctx := stats.SetTags(context.Background(), b)
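Illustrative sketch, not part of the patch: checkEnd above inspects the *stats.End event delivered to a stats.Handler, and with this change End.Error is the final error exactly as returned to the caller (nil on success) rather than a value converted through toRPCErr. A minimal handler that observes those events might look like this; the dial-option wiring in the trailing comment is shown only as an assumed usage:

package example

import (
	"context"
	"log"

	"google.golang.org/grpc/stats"
)

type endLogger struct{}

func (endLogger) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }

// HandleRPC receives every RPC-level stats event; only the End event, which
// checkEnd in the tests above asserts on, is logged here.
func (endLogger) HandleRPC(_ context.Context, s stats.RPCStats) {
	if e, ok := s.(*stats.End); ok {
		log.Printf("RPC finished: client=%v err=%v", e.Client, e.Error)
	}
}

func (endLogger) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }

func (endLogger) HandleConn(context.Context, stats.ConnStats) {}

// Usage: grpc.Dial(target, grpc.WithInsecure(), grpc.WithStatsHandler(endLogger{}))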
diff --git a/stream.go b/stream.go
index 50ae3ea1..deb73592 100644
--- a/stream.go
+++ b/stream.go
@@ -114,34 +114,28 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 }
 
 func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
-	var (
-		t      transport.ClientTransport
-		s      *transport.Stream
-		done   func(balancer.DoneInfo)
-		cancel context.CancelFunc
-	)
 	c := defaultCallInfo()
 	mc := cc.GetMethodConfig(method)
 	if mc.WaitForReady != nil {
 		c.failFast = !*mc.WaitForReady
 	}
 
+	// Possible context leak:
+	// The cancel function for the child context we create will only be called
+	// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
+	// an error is generated by SendMsg.
+	// https://github.com/grpc/grpc-go/issues/1818.
+	var cancel context.CancelFunc
 	if mc.Timeout != nil && *mc.Timeout >= 0 {
-		// The cancel function for this context will only be called when RecvMsg
-		// returns non-nil error, which means the stream finishes with error or
-		// io.EOF. https://github.com/grpc/grpc-go/issues/1818.
-		//
-		// Possible context leak:
-		// - If user provided context is Background, and the user doesn't call
-		//   RecvMsg() for the final status, this ctx will be leaked after the
-		//   stream is done, until the service config timeout happens.
 		ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
-		defer func() {
-			if err != nil {
-				cancel()
-			}
-		}()
+	} else {
+		ctx, cancel = context.WithCancel(ctx)
 	}
+	defer func() {
+		if err != nil {
+			cancel()
+		}
+	}()
 
 	opts = append(cc.dopts.callOptions, opts...)
 	for _, o := range opts {
@@ -228,6 +222,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		}()
 	}
 
+	var (
+		t    transport.ClientTransport
+		s    *transport.Stream
+		done func(balancer.DoneInfo)
+	)
 	for {
 		// Check to make sure the context has expired. This will prevent us from
 		// looping forever if an error occurs for wait-for-ready RPCs where no data
@@ -283,29 +282,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		statsHandler: cc.dopts.copts.StatsHandler,
 	}
 	if desc != unaryStreamDesc {
-		// Listen on s.Context().Done() to detect cancellation and s.Done() to
-		// detect normal termination when there is no pending I/O operations on
-		// this stream. Not necessary for "unary" streams, since we are guaranteed
-		// to always be in stream functions.
+		// Listen on cc and stream contexts to cleanup when the user closes the
+		// ClientConn or cancels the stream context. In all other cases, an error
+		// should already be injected into the recv buffer by the transport, which
+		// the client will eventually receive, and then we will cancel the stream's
+		// context in clientStream.finish.
 		go func() {
 			select {
-			case <-t.Error():
-				// Incur transport error, simply exit.
 			case <-cc.ctx.Done():
 				cs.finish(ErrClientConnClosing)
-				cs.closeTransportStream(ErrClientConnClosing)
-			case <-s.Done():
-				// TODO: The trace of the RPC is terminated here when there is no pending
-				// I/O, which is probably not the optimal solution.
-				cs.finish(s.Status().Err())
-				cs.closeTransportStream(nil)
-			case <-s.GoAway():
-				cs.finish(errConnDrain)
-				cs.closeTransportStream(errConnDrain)
-			case <-s.Context().Done():
-				err := s.Context().Err()
-				cs.finish(err)
-				cs.closeTransportStream(transport.ContextErr(err))
+			case <-ctx.Done():
+				cs.finish(toRPCErr(s.Context().Err()))
 			}
 		}()
 	}
@@ -337,7 +324,6 @@ type clientStream struct {
 	mu       sync.Mutex
 	done     func(balancer.DoneInfo)
 	sentLast bool // sent an end stream
-	closed   bool
 	finished bool
 	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
 	// and is set to nil when the clientStream's finish method is called.
@@ -357,9 +343,8 @@ func (cs *clientStream) Context() context.Context {
 func (cs *clientStream) Header() (metadata.MD, error) {
 	m, err := cs.s.Header()
 	if err != nil {
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
-		}
+		err = toRPCErr(err)
+		cs.finish(err)
 	}
 	return m, err
 }
@@ -380,23 +365,18 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 	// TODO Investigate how to signal the stats handling party.
 	// generate error stats if err != nil && err != io.EOF?
 	defer func() {
-		if err == nil {
-			return
+		// For non-client-streaming RPCs, we return nil instead of EOF on success
+		// because the generated code requires it. finish is not called; RecvMsg()
+		// will call it with the stream's status independently.
+		if err == io.EOF && !cs.desc.ClientStreams {
+			err = nil
 		}
-		cs.finish(err)
-		if err == io.EOF {
-			// SendMsg is only called once for non-client-streams. io.EOF needs to be
-			// skipped when the rpc is early finished (before the stream object is
-			// created.).
-			if !cs.desc.ClientStreams {
-				err = nil
-			}
-			return
+		if err != nil && err != io.EOF {
+			// Call finish for errors generated by this SendMsg call. (Transport
+			// errors are converted to an io.EOF error below; the real error will be
+			// returned from RecvMsg eventually in that case.)
+			cs.finish(err)
 		}
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
-		}
-		err = toRPCErr(err)
 	}()
 	var outPayload *stats.OutPayload
 	if cs.statsHandler != nil {
@@ -408,9 +388,6 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 	if err != nil {
 		return err
 	}
-	if cs.c.maxSendMessageSize == nil {
-		return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
-	}
 	if len(data) > *cs.c.maxSendMessageSize {
 		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
 	}
@@ -418,21 +395,21 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 		cs.sentLast = true
 	}
 	err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
-	if err == nil && outPayload != nil {
-		outPayload.SentTime = time.Now()
-		cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
+	if err == nil {
+		if outPayload != nil {
+			outPayload.SentTime = time.Now()
+			cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
+		}
+		return nil
 	}
-	return err
+	return io.EOF
 }
 
 func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 	defer func() {
-		// err != nil indicates the termination of the stream.
-		if err != nil {
+		if err != nil || !cs.desc.ServerStreams {
+			// err != nil or non-server-streaming indicates end of stream.
 			cs.finish(err)
-			if cs.cancel != nil {
-				cs.cancel()
-			}
 		}
 	}()
 	var inPayload *stats.InPayload
@@ -441,9 +418,6 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 			Client: true,
 		}
 	}
-	if cs.c.maxReceiveMessageSize == nil {
-		return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
-	}
 	if !cs.decompSet {
 		// Block until we receive headers containing received message encoding.
 		if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
@@ -462,14 +436,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 	}
 	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
 	if err != nil {
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
-		}
 		if err == io.EOF {
 			if statusErr := cs.s.Status().Err(); statusErr != nil {
 				return statusErr
 			}
-			return io.EOF // indicates end of stream.
+			return io.EOF // indicates successful end of stream.
 		}
 		return toRPCErr(err)
 	}
@@ -484,22 +455,18 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 		cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
 	}
 	if cs.desc.ServerStreams {
+		// Subsequent messages should be received by subsequent RecvMsg calls.
 		return nil
 	}
 
 	// Special handling for non-server-stream rpcs.
 	// This recv expects EOF or errors, so we don't collect inPayload.
 	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
-	cs.closeTransportStream(err)
 	if err == nil {
		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
 	}
 	if err == io.EOF {
-		if se := cs.s.Status().Err(); se != nil {
-			return se
-		}
-		cs.finish(err)
-		return nil
+		return cs.s.Status().Err() // non-server streaming Recv returns nil on success
 	}
 	return toRPCErr(err)
 }
@@ -509,41 +476,26 @@ func (cs *clientStream) CloseSend() error {
 		return nil
 	}
 	cs.sentLast = true
-	err := cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
-	if err == nil || err == io.EOF {
-		return nil
-	}
-	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.closeTransportStream(err)
-	}
-	err = toRPCErr(err)
-	if err != nil {
-		cs.finish(err)
-	}
-	return err
-}
-
-func (cs *clientStream) closeTransportStream(err error) {
-	cs.mu.Lock()
-	if cs.closed {
-		cs.mu.Unlock()
-		return
-	}
-	cs.closed = true
-	cs.mu.Unlock()
-	cs.t.CloseStream(cs.s, err)
+	cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
+	// We ignore errors from Write and always return nil here. Any error it
+	// would return would also be returned by a subsequent RecvMsg call, and the
+	// user is supposed to always finish the stream by calling RecvMsg until it
+	// returns err != nil.
+	return nil
 }
 
 func (cs *clientStream) finish(err error) {
-	// Do not call cs.cancel in this function. Only call it when RecvMag()
-	// returns non-nil error because of
-	// https://github.com/grpc/grpc-go/issues/1818.
+	if err == io.EOF {
+		// Ending a stream with EOF indicates a success.
+		err = nil
+	}
 	cs.mu.Lock()
 	defer cs.mu.Unlock()
 	if cs.finished {
 		return
 	}
 	cs.finished = true
+	cs.t.CloseStream(cs.s, err)
 	for _, o := range cs.opts {
 		o.after(cs.c)
 	}
@@ -559,18 +511,16 @@ func (cs *clientStream) finish(err error) {
 		end := &stats.End{
 			Client:  true,
 			EndTime: time.Now(),
-		}
-		if err != io.EOF {
-			// end.Error is nil if the RPC finished successfully.
-			end.Error = toRPCErr(err)
+			Error:   err,
 		}
 		cs.statsHandler.HandleRPC(cs.statsCtx, end)
 	}
+	cs.cancel()
 	if !cs.tracing {
 		return
 	}
 	if cs.trInfo.tr != nil {
-		if err == nil || err == io.EOF {
+		if err == nil {
 			cs.trInfo.tr.LazyPrintf("RPC: [OK]")
 		} else {
 			cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)
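A sketch of the calling pattern the new stream.go code assumes; illustrative only, not part of the patch, and written against the grpc_testing stubs. SendMsg and CloseSend now report transport-level problems as io.EOF (or swallow them), and the RPC's real status is surfaced only by RecvMsg, so a caller should always drive the stream to completion with Recv:

package example

import (
	"context"
	"io"

	testpb "google.golang.org/grpc/test/grpc_testing"
)

func runFullDuplex(ctx context.Context, tc testpb.TestServiceClient, reqs []*testpb.StreamingOutputCallRequest) error {
	stream, err := tc.FullDuplexCall(ctx)
	if err != nil {
		return err
	}
	for _, req := range reqs {
		if err := stream.Send(req); err == io.EOF {
			// The stream is broken; the real error will come from Recv below.
			break
		} else if err != nil {
			return err
		}
	}
	if err := stream.CloseSend(); err != nil {
		return err // with this patch CloseSend always returns nil
	}
	for {
		if _, err := stream.Recv(); err == io.EOF {
			return nil // the server finished with an OK status
		} else if err != nil {
			return err // the RPC's actual status error
		}
	}
}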
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 772f6e60..93620695 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -1164,10 +1164,22 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
 		ResponseParameters: respParam,
 		Payload:            payload,
 	}
-	if err := stream.Send(req); err == nil {
-		if _, err := stream.Recv(); err == nil {
-			t.Fatalf("%v.Recv() = _, %v, want _, <non-nil error>", stream, err)
+	sendStart := time.Now()
+	for {
+		if err := stream.Send(req); err == io.EOF {
+			// stream.Send should eventually send io.EOF
+			break
+		} else if err != nil {
+			// Send should never return a transport-level error.
+			t.Fatalf("stream.Send(%v) = %v; want <nil or io.EOF>", req, err)
 		}
+		if time.Since(sendStart) > 2*time.Second {
+			t.Fatalf("stream.Send(_) did not return io.EOF after 2s")
+		}
+		time.Sleep(time.Millisecond)
+	}
+	if _, err := stream.Recv(); err == nil || err == io.EOF {
+		t.Fatalf("%v.Recv() = _, %v, want _, <non-nil, non-EOF error>", stream, err)
 	}
 	<-ch
 	awaitNewConnLogOutput()
@@ -1190,7 +1202,9 @@ func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) {
 
 	cc := te.clientConn()
 	tc := testpb.NewTestServiceClient(cc)
-	if _, err := tc.FullDuplexCall(context.Background()); err != nil {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	if _, err := tc.FullDuplexCall(ctx); err != nil {
 		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err)
 	}
 	done := make(chan struct{})
@@ -2334,24 +2348,6 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
 }
 
-func TestErrorChanNoIO(t *testing.T) {
-	defer leakcheck.Check(t)
-	for _, e := range listTestEnv() {
-		testErrorChanNoIO(t, e)
-	}
-}
-
-func testErrorChanNoIO(t *testing.T, e env) {
-	te := newTest(t, e)
-	te.startServer(&testServer{security: e.security})
-	defer te.tearDown()
-
-	tc := testpb.NewTestServiceClient(te.clientConn())
-	if _, err := tc.FullDuplexCall(context.Background()); err != nil {
-		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
-	}
-}
-
 func TestEmptyUnaryWithUserAgent(t *testing.T) {
 	defer leakcheck.Check(t)
 	for _, e := range listTestEnv() {
@@ -3811,22 +3807,24 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
 
 	cc := te.clientConn()
 	tc := testpb.NewTestServiceClient(cc)
-	if _, err := tc.StreamingInputCall(context.Background()); err != nil {
-		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	if _, err := tc.StreamingInputCall(ctx); err != nil {
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, <nil>", err)
 	}
 
 	// Loop until the new max stream setting is effective.
 	for {
 		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
-		defer cancel()
 		_, err := tc.StreamingInputCall(ctx)
+		cancel()
 		if err == nil {
-			time.Sleep(50 * time.Millisecond)
+			time.Sleep(5 * time.Millisecond)
 			continue
 		}
 		if status.Code(err) == codes.DeadlineExceeded {
 			break
 		}
-		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %s", err, codes.DeadlineExceeded)
 	}
 
 	var wg sync.WaitGroup
@@ -3848,11 +3846,19 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
 			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 			defer cancel()
 			if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); status.Code(err) != codes.DeadlineExceeded {
-				t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
+				t.Errorf("tc.UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
 			}
 		}()
 	}
 	wg.Wait()
+
+	cancel()
+	// A new stream should be allowed after canceling the first one.
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	if _, err := tc.StreamingInputCall(ctx); err != nil {
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %v", err, nil)
+	}
 }
 
 func TestCompressServerHasNoSupport(t *testing.T) {
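For illustration only, not part of the patch: the test changes above all follow the same user-facing rule that motivates this change — a stream that may be abandoned before being fully drained should be created from a cancelable context, with cancel deferred, since the stream's own child context is otherwise only released by a RecvMsg/SendMsg error or by closing the ClientConn. A minimal sketch using the repository's grpc_testing stubs:

package example

import (
	"context"

	testpb "google.golang.org/grpc/test/grpc_testing"
)

func openAndMaybeAbandon(tc testpb.TestServiceClient) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // guarantees the stream is torn down even if it is never drained

	stream, err := tc.StreamingInputCall(ctx)
	if err != nil {
		return err
	}
	// ... use stream; if this function returns early, the deferred cancel cleans it up.
	_ = stream
	return nil
}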