From 1c2c309b25815bea38013789dfd6e6e3d02fd860 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Fri, 30 Oct 2015 15:52:41 -0700 Subject: [PATCH] Cancel all active streams when a server connection is closed --- transport/http2_server.go | 41 ++++++++++----------- transport/transport_test.go | 73 ++++++++++++++++++++++++++++++++++--- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/transport/http2_server.go b/transport/http2_server.go index f3488f83..ceb30557 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -163,22 +163,6 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header if !endHeaders { return s } - t.mu.Lock() - if t.state != reachable { - t.mu.Unlock() - return nil - } - if uint32(len(t.activeStreams)) >= t.maxStreams { - t.mu.Unlock() - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) - return nil - } - s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) - t.activeStreams[s.id] = s - t.mu.Unlock() - s.windowHandler = func(n int) { - t.updateWindow(s, uint32(n)) - } if hDec.state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout) } else { @@ -202,6 +186,22 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header recv: s.buf, } s.method = hDec.state.method + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + return nil + } + if uint32(len(t.activeStreams)) >= t.maxStreams { + t.mu.Unlock() + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + return nil + } + s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) + t.activeStreams[s.id] = s + t.mu.Unlock() + s.windowHandler = func(n int) { + t.updateWindow(s, uint32(n)) + } handle(s) return nil } @@ -660,9 +660,9 @@ func (t *http2Server) Close() (err error) { t.mu.Unlock() close(t.shutdownChan) err = t.conn.Close() - // Notify all active streams. + // Cancel all active streams. for _, s := range streams { - s.write(recvMsg{err: ErrConnClosing}) + s.cancel() } return } @@ -684,9 +684,8 @@ func (t *http2Server) closeStream(s *Stream) { s.state = streamDone s.mu.Unlock() // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), the caller needs - // to call cancel on the stream to interrupt the blocking on - // other goroutines. + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. s.cancel() } diff --git a/transport/transport_test.go b/transport/transport_test.go index 9bf3ed3c..06847cea 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -86,11 +86,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { } p := make([]byte, len(req)) _, err := io.ReadFull(s, p) - if err != nil || !bytes.Equal(p, req) { - if err == ErrConnClosing { - return - } - t.Fatalf("handleStream got error: %v, want ; result: %v, want %v", err, p, req) + if err != nil { + return + } + if !bytes.Equal(p, req) { + t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. h.t.Write(s, resp, &Options{}) @@ -429,6 +429,69 @@ func TestMaxStreams(t *testing.T) { server.stop() } +func TestServerContextCanceledOnClosedConnection(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, suspended) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + var sc *http2Server + // Wait until the server transport is setup. + for { + server.mu.Lock() + if len(server.conns) == 0 { + server.mu.Unlock() + time.Sleep(time.Millisecond) + continue + } + for k := range server.conns { + var ok bool + sc, ok = k.(*http2Server) + if !ok { + t.Fatalf("Failed to convert %v to *http2Server", k) + } + } + server.mu.Unlock() + break + } + cc, ok := ct.(*http2Client) + if !ok { + t.Fatalf("Failed to convert %v to *http2Client", ct) + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + // Make sure the headers frame is flushed out. + <-cc.writableChan + if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil { + t.Fatalf("Failed to write data: %v", err) + } + cc.writableChan <- 0 + // Loop until the server side stream is created. + var ss *Stream + for { + time.Sleep(time.Second) + sc.mu.Lock() + if len(sc.activeStreams) == 0 { + sc.mu.Unlock() + continue + } + ss = sc.activeStreams[s.id] + sc.mu.Unlock() + break + } + cc.Close() + select { + case <-ss.Context().Done(): + if ss.Context().Err() != context.Canceled { + t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) + } + case <-time.After(5 * time.Second): + t.Fatalf("Failed to cancel the context of the sever side stream.") + } +} + func TestServerWithMisbehavedClient(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, suspended) callHdr := &CallHdr{