streams: Stop cleaning up after orphaned streams (#1854)

This change introduces some behavior changes that should not impact users who
follow the proper stream protocol. Specifically, one of the following
conditions must be satisfied:

1. The user calls Close on the ClientConn.
2. The user cancels the context provided to NewClientStream, or its deadline
    expires. (Note that if the context is no longer needed before the deadline
    expires, it is still recommended to call cancel to prevent bloat.) It is always
    recommended to cancel contexts when they are no longer needed, and never to
    use the background context directly, so all users should already be doing
    this.
3. The user calls RecvMsg (or Recv in generated code) until a non-nil error is
    returned.
4. The user receives any error besides io.EOF from Header or SendMsg (or Send
    in generated code).

If none of the above happens, the stream will leak a goroutine and a context,
and grpc will not call the optionally-configured stats handler with a stats.End
message. A minimal sketch of the proper protocol follows.
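The sketch below uses NewClientStream directly; the drain helper and the
response placeholder type are illustrative only and not part of this change:

package example

import (
	"context"
	"io"
	"log"

	"google.golang.org/grpc"
)

// response stands in for a generated proto message type (illustrative only).
type response struct{}

// drain follows the stream protocol above: derive a cancelable context and
// always cancel it (condition 2), then call RecvMsg until it returns a
// non-nil error (condition 3).
func drain(cc *grpc.ClientConn, desc *grpc.StreamDesc, method string) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // condition 2: never leak the stream's context

	stream, err := grpc.NewClientStream(ctx, desc, cc, method)
	if err != nil {
		log.Fatalf("NewClientStream(_) = _, %v", err)
	}
	if err := stream.CloseSend(); err != nil {
		log.Fatalf("CloseSend() = %v", err) // condition 4 covers non-EOF errors
	}
	for {
		if err := stream.RecvMsg(&response{}); err != nil {
			if err != io.EOF {
				log.Printf("stream terminated with error: %v", err)
			}
			return // condition 3: the stream is fully consumed
		}
	}
}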

Before this change, if a user created a stream and the server ended it, the
stats handler would be invoked with a stats.End containing the final status of
the stream. Subsequent calls to RecvMsg would then trigger the stats handler
with InPayloads, which stats handlers may not expect.
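For illustration, a rough sketch of a stats handler that observes this
sequence; the stats.Handler interface is grpc's, while logHandler and its
logging body are hypothetical:

package example

import (
	"context"
	"log"

	"google.golang.org/grpc/stats"
)

// logHandler is an illustrative stats.Handler. After this change, it sees
// stats.End exactly once, when the stream actually finishes, and never an
// InPayload after End.
type logHandler struct{}

func (logHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }
func (logHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (logHandler) HandleConn(context.Context, stats.ConnStats) {}

func (logHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
	switch s := s.(type) {
	case *stats.InPayload:
		log.Printf("received payload of %d bytes", s.Length)
	case *stats.End:
		log.Printf("RPC finished; status error: %v", s.Error)
	}
}

Such a handler would be installed at dial time via
grpc.WithStatsHandler(logHandler{}).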
Author: dfawley
Committed: 2018-02-08 10:51:16 -08:00 (committed by GitHub)
Parent: 7646b5360d
Commit: 365770fcbd

6 changed files with 123 additions and 178 deletions


@@ -32,6 +32,7 @@ import (
 	"golang.org/x/net/trace"
 	"google.golang.org/grpc/balancer"
 	_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/grpclog"
@@ -40,17 +41,17 @@ import (
 	_ "google.golang.org/grpc/resolver/dns"         // To register dns resolver.
 	_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
 	"google.golang.org/grpc/stats"
+	"google.golang.org/grpc/status"
 	"google.golang.org/grpc/transport"
 )

 var (
 	// ErrClientConnClosing indicates that the operation is illegal because
 	// the ClientConn is closing.
-	ErrClientConnClosing = errors.New("grpc: the client connection is closing")
-	// ErrClientConnTimeout indicates that the ClientConn cannot establish the
-	// underlying connections within the specified timeout.
-	// DEPRECATED: Please use context.DeadlineExceeded instead.
-	ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
+	//
+	// Deprecated: this error should not be relied upon by users; use the status
+	// code of Canceled instead.
+	ErrClientConnClosing = status.Error(codes.Canceled, "grpc: the client connection is closing")
 	// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
 	errConnDrain = errors.New("grpc: the connection is drained")
 	// errConnClosing indicates that the connection is closing.
@@ -1374,3 +1375,10 @@ func (ac *addrConn) getState() connectivity.State {
 	defer ac.mu.Unlock()
 	return ac.state
 }
+
+// ErrClientConnTimeout indicates that the ClientConn cannot establish the
+// underlying connections within the specified timeout.
+//
+// Deprecated: This error is never returned by grpc and should not be
+// referenced by users.
+var ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
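Because ErrClientConnClosing is now a status error with code Canceled, callers
that compared against the sentinel directly can inspect the status code
instead; a minimal sketch (the isConnClosing helper is hypothetical):

package example

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// isConnClosing reports whether err carries the Canceled status code, which
// now covers a closing ClientConn; comparing against the deprecated sentinel
// grpc.ErrClientConnClosing is discouraged.
func isConnClosing(err error) bool {
	s, ok := status.FromError(err)
	return ok && s.Code() == codes.Canceled
}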


@@ -48,6 +48,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error
 // toRPCErr converts an error into an error from the status package.
 func toRPCErr(err error) error {
+	if err == nil || err == io.EOF {
+		return err
+	}
 	if _, ok := status.FromError(err); ok {
 		return err
 	}
@@ -62,8 +65,6 @@ func toRPCErr(err error) error {
 			return status.Error(codes.DeadlineExceeded, err.Error())
 		case context.Canceled:
 			return status.Error(codes.Canceled, err.Error())
-		case ErrClientConnClosing:
-			return status.Error(codes.FailedPrecondition, err.Error())
 		}
 	}
 	return status.Error(codes.Unknown, err.Error())


@@ -49,6 +49,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error
 // toRPCErr converts an error into an error from the status package.
 func toRPCErr(err error) error {
+	if err == nil || err == io.EOF {
+		return err
+	}
 	if _, ok := status.FromError(err); ok {
 		return err
 	}
@@ -63,8 +66,6 @@ func toRPCErr(err error) error {
 			return status.Error(codes.DeadlineExceeded, err.Error())
 		case context.Canceled, netctx.Canceled:
 			return status.Error(codes.Canceled, err.Error())
-		case ErrClientConnClosing:
-			return status.Error(codes.FailedPrecondition, err.Error())
 		}
 	}
 	return status.Error(codes.Unknown, err.Error())


@@ -260,7 +260,6 @@ type rpcConfig struct {
 	success    bool    // Whether the RPC should succeed or return error.
 	failfast   bool
 	callType   rpcType // Type of RPC.
-	noLastRecv bool    // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
 }

 func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
@@ -313,15 +312,9 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
 	if err = stream.CloseSend(); err != nil && err != io.EOF {
 		return reqs, resps, err
 	}
-	if !c.noLastRecv {
-		if _, err = stream.Recv(); err != io.EOF {
-			return reqs, resps, err
-		}
-	} else {
-		// In the case of not calling the last recv, sleep to avoid
-		// returning too fast to miss the remaining stats (InTrailer and End).
-		time.Sleep(time.Second)
-	}
+	if _, err = stream.Recv(); err != io.EOF {
+		return reqs, resps, err
+	}
 	return reqs, resps, nil
 }

@@ -651,7 +644,7 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) {
 	actual, ok := status.FromError(st.Error)
 	if !ok {
-		t.Fatalf("expected st.Error to be a statusError, got %T", st.Error)
+		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
 	}

 	expectedStatus, _ := status.FromError(e.err)
@@ -1222,20 +1215,6 @@ func TestClientStatsFullDuplexRPCError(t *testing.T) {
 	})
 }

-// If the user doesn't call the last recv() on clientStream.
-func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) {
-	count := 1
-	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{
-		begin:      {checkBegin, 1},
-		outHeader:  {checkOutHeader, 1},
-		outPayload: {checkOutPayload, count},
-		inHeader:   {checkInHeader, 1},
-		inPayload:  {checkInPayload, count},
-		inTrailer:  {checkInTrailer, 1},
-		end:        {checkEnd, 1},
-	})
-}
-
 func TestTags(t *testing.T) {
 	b := []byte{5, 2, 4, 3, 1}
 	ctx := stats.SetTags(context.Background(), b)

stream.go

@@ -114,34 +114,28 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 }

 func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
-	var (
-		t      transport.ClientTransport
-		s      *transport.Stream
-		done   func(balancer.DoneInfo)
-		cancel context.CancelFunc
-	)
 	c := defaultCallInfo()
 	mc := cc.GetMethodConfig(method)
 	if mc.WaitForReady != nil {
 		c.failFast = !*mc.WaitForReady
 	}
-	if mc.Timeout != nil && *mc.Timeout >= 0 {
-		// The cancel function for this context will only be called when RecvMsg
-		// returns non-nil error, which means the stream finishes with error or
-		// io.EOF. https://github.com/grpc/grpc-go/issues/1818.
-		//
-		// Possible context leak:
-		// - If user provided context is Background, and the user doesn't call
-		//   RecvMsg() for the final status, this ctx will be leaked after the
-		//   stream is done, until the service config timeout happens.
+
+	// Possible context leak:
+	// The cancel function for the child context we create will only be called
+	// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
+	// an error is generated by SendMsg.
+	// https://github.com/grpc/grpc-go/issues/1818.
+	var cancel context.CancelFunc
+	if mc.Timeout != nil && *mc.Timeout >= 0 {
 		ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
-		defer func() {
-			if err != nil {
-				cancel()
-			}
-		}()
+	} else {
+		ctx, cancel = context.WithCancel(ctx)
 	}
+	defer func() {
+		if err != nil {
+			cancel()
+		}
+	}()

 	opts = append(cc.dopts.callOptions, opts...)
 	for _, o := range opts {
@@ -228,6 +222,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		}()
 	}

+	var (
+		t    transport.ClientTransport
+		s    *transport.Stream
+		done func(balancer.DoneInfo)
+	)
 	for {
 		// Check to make sure the context has expired. This will prevent us from
 		// looping forever if an error occurs for wait-for-ready RPCs where no data
@@ -283,29 +282,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		statsHandler: cc.dopts.copts.StatsHandler,
 	}
 	if desc != unaryStreamDesc {
-		// Listen on s.Context().Done() to detect cancellation and s.Done() to
-		// detect normal termination when there is no pending I/O operations on
-		// this stream. Not necessary for "unary" streams, since we are guaranteed
-		// to always be in stream functions.
+		// Listen on cc and stream contexts to cleanup when the user closes the
+		// ClientConn or cancels the stream context. In all other cases, an error
+		// should already be injected into the recv buffer by the transport, which
+		// the client will eventually receive, and then we will cancel the stream's
+		// context in clientStream.finish.
 		go func() {
 			select {
-			case <-t.Error():
-				// Incur transport error, simply exit.
 			case <-cc.ctx.Done():
 				cs.finish(ErrClientConnClosing)
-				cs.closeTransportStream(ErrClientConnClosing)
-			case <-s.Done():
-				// TODO: The trace of the RPC is terminated here when there is no pending
-				// I/O, which is probably not the optimal solution.
-				cs.finish(s.Status().Err())
-				cs.closeTransportStream(nil)
-			case <-s.GoAway():
-				cs.finish(errConnDrain)
-				cs.closeTransportStream(errConnDrain)
-			case <-s.Context().Done():
-				err := s.Context().Err()
-				cs.finish(err)
-				cs.closeTransportStream(transport.ContextErr(err))
+			case <-ctx.Done():
+				cs.finish(toRPCErr(s.Context().Err()))
 			}
 		}()
 	}
@@ -337,7 +324,6 @@ type clientStream struct {
 	mu       sync.Mutex
 	done     func(balancer.DoneInfo)
 	sentLast bool // sent an end stream
-	closed   bool
 	finished bool
 	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
 	// and is set to nil when the clientStream's finish method is called.
@@ -357,9 +343,8 @@ func (cs *clientStream) Context() context.Context {
 func (cs *clientStream) Header() (metadata.MD, error) {
 	m, err := cs.s.Header()
 	if err != nil {
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
-		}
+		err = toRPCErr(err)
+		cs.finish(err)
 	}
 	return m, err
 }
@@ -380,23 +365,18 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 	// TODO Investigate how to signal the stats handling party.
 	// generate error stats if err != nil && err != io.EOF?
 	defer func() {
-		if err == nil {
-			return
-		}
-		cs.finish(err)
-		if err == io.EOF {
-			// SendMsg is only called once for non-client-streams. io.EOF needs to be
-			// skipped when the rpc is early finished (before the stream object is
-			// created.).
-			if !cs.desc.ClientStreams {
-				err = nil
-			}
-			return
+		// For non-client-streaming RPCs, we return nil instead of EOF on success
+		// because the generated code requires it. finish is not called; RecvMsg()
+		// will call it with the stream's status independently.
+		if err == io.EOF && !cs.desc.ClientStreams {
+			err = nil
 		}
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
+		if err != nil && err != io.EOF {
+			// Call finish for errors generated by this SendMsg call. (Transport
+			// errors are converted to an io.EOF error below; the real error will be
+			// returned from RecvMsg eventually in that case.)
+			cs.finish(err)
 		}
-		err = toRPCErr(err)
 	}()
 	var outPayload *stats.OutPayload
 	if cs.statsHandler != nil {
@@ -408,9 +388,6 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 	if err != nil {
 		return err
 	}
-	if cs.c.maxSendMessageSize == nil {
-		return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
-	}
 	if len(data) > *cs.c.maxSendMessageSize {
 		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
 	}
@@ -418,21 +395,21 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 		cs.sentLast = true
 	}
 	err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
-	if err == nil && outPayload != nil {
-		outPayload.SentTime = time.Now()
-		cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
+	if err == nil {
+		if outPayload != nil {
+			outPayload.SentTime = time.Now()
+			cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
+		}
+		return nil
 	}
-	return err
+	return io.EOF
 }

 func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 	defer func() {
-		// err != nil indicates the termination of the stream.
-		if err != nil {
+		if err != nil || !cs.desc.ServerStreams {
+			// err != nil or non-server-streaming indicates end of stream.
 			cs.finish(err)
-			if cs.cancel != nil {
-				cs.cancel()
-			}
 		}
 	}()
 	var inPayload *stats.InPayload
@@ -441,9 +418,6 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 			Client: true,
 		}
 	}
-	if cs.c.maxReceiveMessageSize == nil {
-		return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
-	}
 	if !cs.decompSet {
 		// Block until we receive headers containing received message encoding.
 		if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
@@ -462,14 +436,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 	}
 	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
 	if err != nil {
-		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.closeTransportStream(err)
-		}
 		if err == io.EOF {
 			if statusErr := cs.s.Status().Err(); statusErr != nil {
 				return statusErr
 			}
-			return io.EOF // indicates end of stream.
+			return io.EOF // indicates successful end of stream.
 		}
 		return toRPCErr(err)
 	}
@@ -484,22 +455,18 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 		cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
 	}
 	if cs.desc.ServerStreams {
+		// Subsequent messages should be received by subsequent RecvMsg calls.
 		return nil
 	}

 	// Special handling for non-server-stream rpcs.
 	// This recv expects EOF or errors, so we don't collect inPayload.
 	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
-	cs.closeTransportStream(err)
 	if err == nil {
 		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
 	}
 	if err == io.EOF {
-		if se := cs.s.Status().Err(); se != nil {
-			return se
-		}
-		cs.finish(err)
-		return nil
+		return cs.s.Status().Err() // non-server streaming Recv returns nil on success
 	}
 	return toRPCErr(err)
 }
@@ -509,41 +476,26 @@ func (cs *clientStream) CloseSend() error {
 		return nil
 	}
 	cs.sentLast = true
-	err := cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
-	if err == nil || err == io.EOF {
-		return nil
-	}
-	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.closeTransportStream(err)
-	}
-	err = toRPCErr(err)
-	if err != nil {
-		cs.finish(err)
-	}
-	return err
-}
-
-func (cs *clientStream) closeTransportStream(err error) {
-	cs.mu.Lock()
-	if cs.closed {
-		cs.mu.Unlock()
-		return
-	}
-	cs.closed = true
-	cs.mu.Unlock()
-	cs.t.CloseStream(cs.s, err)
+	cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
+	// We ignore errors from Write and always return nil here. Any error it
+	// would return would also be returned by a subsequent RecvMsg call, and the
+	// user is supposed to always finish the stream by calling RecvMsg until it
+	// returns err != nil.
+	return nil
 }

 func (cs *clientStream) finish(err error) {
-	// Do not call cs.cancel in this function. Only call it when RecvMsg()
-	// returns non-nil error because of
-	// https://github.com/grpc/grpc-go/issues/1818.
+	if err == io.EOF {
+		// Ending a stream with EOF indicates a success.
+		err = nil
+	}
 	cs.mu.Lock()
 	defer cs.mu.Unlock()
 	if cs.finished {
 		return
 	}
 	cs.finished = true
+	cs.t.CloseStream(cs.s, err)
 	for _, o := range cs.opts {
 		o.after(cs.c)
 	}
@@ -559,18 +511,16 @@ func (cs *clientStream) finish(err error) {
 		end := &stats.End{
 			Client:  true,
 			EndTime: time.Now(),
+			Error:   err,
 		}
-		if err != io.EOF {
-			// end.Error is nil if the RPC finished successfully.
-			end.Error = toRPCErr(err)
-		}
 		cs.statsHandler.HandleRPC(cs.statsCtx, end)
 	}
+	cs.cancel()
 	if !cs.tracing {
 		return
 	}
 	if cs.trInfo.tr != nil {
-		if err == nil || err == io.EOF {
+		if err == nil {
 			cs.trInfo.tr.LazyPrintf("RPC: [OK]")
 		} else {
 			cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)


@@ -1164,10 +1164,22 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
 		ResponseParameters: respParam,
 		Payload:            payload,
 	}
-	if err := stream.Send(req); err == nil {
-		if _, err := stream.Recv(); err == nil {
-			t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
-		}
+	sendStart := time.Now()
+	for {
+		if err := stream.Send(req); err == io.EOF {
+			// stream.Send should eventually send io.EOF
+			break
+		} else if err != nil {
+			// Send should never return a transport-level error.
+			t.Fatalf("stream.Send(%v) = %v; want <nil or io.EOF>", req, err)
+		}
+		if time.Since(sendStart) > 2*time.Second {
+			t.Fatalf("stream.Send(_) did not return io.EOF after 2s")
+		}
+		time.Sleep(time.Millisecond)
+	}
+	if _, err := stream.Recv(); err == nil || err == io.EOF {
+		t.Fatalf("%v.Recv() = _, %v, want _, <non-nil, non-EOF>", stream, err)
 	}
 	<-ch
 	awaitNewConnLogOutput()
@@ -1190,7 +1202,9 @@ func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) {
 	cc := te.clientConn()
 	tc := testpb.NewTestServiceClient(cc)
-	if _, err := tc.FullDuplexCall(context.Background()); err != nil {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	if _, err := tc.FullDuplexCall(ctx); err != nil {
 		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err)
 	}
 	done := make(chan struct{})
@@ -2334,24 +2348,6 @@ func testHealthCheckServingStatus(t *testing.T, e env) {

 }

-func TestErrorChanNoIO(t *testing.T) {
-	defer leakcheck.Check(t)
-	for _, e := range listTestEnv() {
-		testErrorChanNoIO(t, e)
-	}
-}
-
-func testErrorChanNoIO(t *testing.T, e env) {
-	te := newTest(t, e)
-	te.startServer(&testServer{security: e.security})
-	defer te.tearDown()
-
-	tc := testpb.NewTestServiceClient(te.clientConn())
-	if _, err := tc.FullDuplexCall(context.Background()); err != nil {
-		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
-	}
-}
-
 func TestEmptyUnaryWithUserAgent(t *testing.T) {
 	defer leakcheck.Check(t)
 	for _, e := range listTestEnv() {
@@ -3811,22 +3807,24 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
 	cc := te.clientConn()
 	tc := testpb.NewTestServiceClient(cc)
-	if _, err := tc.StreamingInputCall(context.Background()); err != nil {
-		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	if _, err := tc.StreamingInputCall(ctx); err != nil {
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, <nil>", err)
 	}

 	// Loop until the new max stream setting is effective.
 	for {
 		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
-		defer cancel()
 		_, err := tc.StreamingInputCall(ctx)
+		cancel()
 		if err == nil {
-			time.Sleep(50 * time.Millisecond)
+			time.Sleep(5 * time.Millisecond)
 			continue
 		}
 		if status.Code(err) == codes.DeadlineExceeded {
 			break
 		}
-		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %s", err, codes.DeadlineExceeded)
 	}

 	var wg sync.WaitGroup
@@ -3848,11 +3846,19 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
 			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 			defer cancel()
 			if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); status.Code(err) != codes.DeadlineExceeded {
-				t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
+				t.Errorf("tc.UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
 			}
 		}()
 	}
 	wg.Wait()
+	cancel()
+
+	// A new stream should be allowed after canceling the first one.
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	if _, err := tc.StreamingInputCall(ctx); err != nil {
+		t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %v", err, nil)
+	}
 }

 func TestCompressServerHasNoSupport(t *testing.T) {