diff --git a/transport/control.go b/transport/control.go index c5fbe23a..464bf10a 100644 --- a/transport/control.go +++ b/transport/control.go @@ -104,8 +104,14 @@ type quotaPool struct { // newQuotaPool creates a quotaPool which has quota q available to consume. func newQuotaPool(q int) *quotaPool { - qb := "aPool{c: make(chan int, 1)} - qb.c <- q + qb := "aPool{ + c: make(chan int, 1), + } + if q > 0 { + qb.c <- q + } else { + qb.quota = q + } return qb } diff --git a/transport/http2_client.go b/transport/http2_client.go index f956b1ea..e9e85434 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -196,7 +196,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return t, nil } -func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, sq bool) *Stream { +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { fc := &inFlow{ limit: initialWindowSize, conn: t.fc, @@ -206,7 +206,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, sq bool) id: t.nextID, method: callHdr.Method, buf: newRecvBuffer(), - updateStreams: sq, fc: fc, sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), headerChan: make(chan struct{}), @@ -267,7 +266,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, err } t.mu.Lock() - s := t.newStream(ctx, callHdr, checkStreamsQuota) + s := t.newStream(ctx, callHdr) t.activeStreams[s.id] = s t.mu.Unlock() // HPACK encodes various headers. Note that once WriteField(...) is @@ -336,10 +335,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea // CloseStream clears the footprint of a stream when the stream is not needed any more. // This must not be executed in reader's goroutine. func (t *http2Client) CloseStream(s *Stream, err error) { + var updateStreams bool t.mu.Lock() + if t.streamsQuota != nil { + updateStreams = true + } delete(t.activeStreams, s.id) t.mu.Unlock() - if s.updateStreams { + if updateStreams { t.streamsQuota.add(1) } s.mu.Lock() @@ -737,7 +740,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) { t.mu.Lock() reset := t.streamsQuota != nil if !reset { - t.streamsQuota = newQuotaPool(int(s.Val)) + t.streamsQuota = newQuotaPool(int(s.Val) - len(t.activeStreams)) } ms := t.maxStreams t.maxStreams = int(s.Val) diff --git a/transport/transport.go b/transport/transport.go index c2ac3f88..58436f01 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -169,17 +169,11 @@ type Stream struct { ctx context.Context cancel context.CancelFunc // method records the associated RPC method of the stream. - method string - buf *recvBuffer - dec io.Reader - - // updateStreams indicates whether the transport's streamsQuota needed - // to be updated when this stream is closed. It is false when the transport - // sticks to the initial infinite value of the number of concurrent streams. - // Ture otherwise. - updateStreams bool - fc *inFlow - recvQuota uint32 + method string + buf *recvBuffer + dec io.Reader + fc *inFlow + recvQuota uint32 // The accumulated inbound quota pending for window update. updateQuota uint32 // The handler to control the window update procedure for both this diff --git a/transport/transport_test.go b/transport/transport_test.go index 8529e2af..e79acec4 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -355,6 +355,75 @@ func TestLargeMessageSuspension(t *testing.T) { server.stop() } +func TestMaxStreams(t *testing.T) { + server, ct := setUp(t, 0, 1, suspended) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Large", + } + // Have a pending stream which takes all streams quota. + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + cc, ok := ct.(*http2Client) + if !ok { + t.Fatalf("Failed to convert %v to *http2Client", ct) + } + done := make(chan struct{}) + ch := make(chan int) + go func() { + for { + select { + case <-time.After(5 * time.Millisecond): + ch <- 0 + case <-time.After(5 * time.Second): + close(done) + return + } + } + }() + for { + select { + case <-ch: + case <-done: + t.Fatalf("Client has not received the max stream setting in 5 seconds.") + } + cc.mu.Lock() + // cc.streamsQuota should be initialized once receiving the 1st setting frame from + // the server. + if cc.streamsQuota != nil { + cc.mu.Unlock() + select { + case <-cc.streamsQuota.acquire(): + t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.") + default: + if cc.streamsQuota.quota != 0 { + t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.") + } + } + break + } + cc.mu.Unlock() + } + // Close the pending stream so that the streams quota becomes available for the next new stream. + ct.CloseStream(s, nil) + select { + case i := <-cc.streamsQuota.acquire(): + if i != 1 { + t.Fatalf("streamsQuota.acquire() got %d quota, want 1.", i) + } + cc.streamsQuota.add(i) + default: + t.Fatalf("streamsQuota.acquire() is not readable.") + } + if _, err := ct.NewStream(context.Background(), callHdr); err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + ct.Close() + server.stop() +} + func TestServerWithMisbehavedClient(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, suspended) callHdr := &CallHdr{