Cancel all active streams when a server connection is closed
This commit is contained in:
@ -163,22 +163,6 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
|
|||||||
if !endHeaders {
|
if !endHeaders {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
t.mu.Lock()
|
|
||||||
if t.state != reachable {
|
|
||||||
t.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if uint32(len(t.activeStreams)) >= t.maxStreams {
|
|
||||||
t.mu.Unlock()
|
|
||||||
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
|
||||||
t.activeStreams[s.id] = s
|
|
||||||
t.mu.Unlock()
|
|
||||||
s.windowHandler = func(n int) {
|
|
||||||
t.updateWindow(s, uint32(n))
|
|
||||||
}
|
|
||||||
if hDec.state.timeoutSet {
|
if hDec.state.timeoutSet {
|
||||||
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
|
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
|
||||||
} else {
|
} else {
|
||||||
@ -202,6 +186,22 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
|
|||||||
recv: s.buf,
|
recv: s.buf,
|
||||||
}
|
}
|
||||||
s.method = hDec.state.method
|
s.method = hDec.state.method
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.state != reachable {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if uint32(len(t.activeStreams)) >= t.maxStreams {
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
||||||
|
t.activeStreams[s.id] = s
|
||||||
|
t.mu.Unlock()
|
||||||
|
s.windowHandler = func(n int) {
|
||||||
|
t.updateWindow(s, uint32(n))
|
||||||
|
}
|
||||||
handle(s)
|
handle(s)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -660,9 +660,9 @@ func (t *http2Server) Close() (err error) {
|
|||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
close(t.shutdownChan)
|
close(t.shutdownChan)
|
||||||
err = t.conn.Close()
|
err = t.conn.Close()
|
||||||
// Notify all active streams.
|
// Cancel all active streams.
|
||||||
for _, s := range streams {
|
for _, s := range streams {
|
||||||
s.write(recvMsg{err: ErrConnClosing})
|
s.cancel()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -684,9 +684,8 @@ func (t *http2Server) closeStream(s *Stream) {
|
|||||||
s.state = streamDone
|
s.state = streamDone
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
// In case stream sending and receiving are invoked in separate
|
// In case stream sending and receiving are invoked in separate
|
||||||
// goroutines (e.g., bi-directional streaming), the caller needs
|
// goroutines (e.g., bi-directional streaming), cancel needs to be
|
||||||
// to call cancel on the stream to interrupt the blocking on
|
// called to interrupt the potential blocking on other goroutines.
|
||||||
// other goroutines.
|
|
||||||
s.cancel()
|
s.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,11 +86,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
|||||||
}
|
}
|
||||||
p := make([]byte, len(req))
|
p := make([]byte, len(req))
|
||||||
_, err := io.ReadFull(s, p)
|
_, err := io.ReadFull(s, p)
|
||||||
if err != nil || !bytes.Equal(p, req) {
|
if err != nil {
|
||||||
if err == ErrConnClosing {
|
return
|
||||||
return
|
}
|
||||||
}
|
if !bytes.Equal(p, req) {
|
||||||
t.Fatalf("handleStream got error: %v, want <nil>; result: %v, want %v", err, p, req)
|
t.Fatalf("handleStream got %v, want %v", p, req)
|
||||||
}
|
}
|
||||||
// send a response back to the client.
|
// send a response back to the client.
|
||||||
h.t.Write(s, resp, &Options{})
|
h.t.Write(s, resp, &Options{})
|
||||||
@ -429,6 +429,69 @@ func TestMaxStreams(t *testing.T) {
|
|||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerContextCanceledOnClosedConnection(t *testing.T) {
|
||||||
|
server, ct := setUp(t, 0, math.MaxUint32, suspended)
|
||||||
|
callHdr := &CallHdr{
|
||||||
|
Host: "localhost",
|
||||||
|
Method: "foo",
|
||||||
|
}
|
||||||
|
var sc *http2Server
|
||||||
|
// Wait until the server transport is setup.
|
||||||
|
for {
|
||||||
|
server.mu.Lock()
|
||||||
|
if len(server.conns) == 0 {
|
||||||
|
server.mu.Unlock()
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for k := range server.conns {
|
||||||
|
var ok bool
|
||||||
|
sc, ok = k.(*http2Server)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Failed to convert %v to *http2Server", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
server.mu.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cc, ok := ct.(*http2Client)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Failed to convert %v to *http2Client", ct)
|
||||||
|
}
|
||||||
|
s, err := ct.NewStream(context.Background(), callHdr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open stream: %v", err)
|
||||||
|
}
|
||||||
|
// Make sure the headers frame is flushed out.
|
||||||
|
<-cc.writableChan
|
||||||
|
if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
|
||||||
|
t.Fatalf("Failed to write data: %v", err)
|
||||||
|
}
|
||||||
|
cc.writableChan <- 0
|
||||||
|
// Loop until the server side stream is created.
|
||||||
|
var ss *Stream
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
sc.mu.Lock()
|
||||||
|
if len(sc.activeStreams) == 0 {
|
||||||
|
sc.mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ss = sc.activeStreams[s.id]
|
||||||
|
sc.mu.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cc.Close()
|
||||||
|
select {
|
||||||
|
case <-ss.Context().Done():
|
||||||
|
if ss.Context().Err() != context.Canceled {
|
||||||
|
t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatalf("Failed to cancel the context of the sever side stream.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerWithMisbehavedClient(t *testing.T) {
|
func TestServerWithMisbehavedClient(t *testing.T) {
|
||||||
server, ct := setUp(t, 0, math.MaxUint32, suspended)
|
server, ct := setUp(t, 0, math.MaxUint32, suspended)
|
||||||
callHdr := &CallHdr{
|
callHdr := &CallHdr{
|
||||||
|
Reference in New Issue
Block a user