Cancel all active streams when a server connection is closed

This commit is contained in:
iamqizhao
2015-10-30 15:52:41 -07:00
parent 174192fc93
commit 1c2c309b25
2 changed files with 88 additions and 26 deletions

View File

@ -163,22 +163,6 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
if !endHeaders { if !endHeaders {
return s return s
} }
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
return nil
}
if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return nil
}
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s
t.mu.Unlock()
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
}
if hDec.state.timeoutSet { if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout) s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else { } else {
@ -202,6 +186,22 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
recv: s.buf, recv: s.buf,
} }
s.method = hDec.state.method s.method = hDec.state.method
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
return nil
}
if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return nil
}
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s
t.mu.Unlock()
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
}
handle(s) handle(s)
return nil return nil
} }
@ -660,9 +660,9 @@ func (t *http2Server) Close() (err error) {
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)
err = t.conn.Close() err = t.conn.Close()
// Notify all active streams. // Cancel all active streams.
for _, s := range streams { for _, s := range streams {
s.write(recvMsg{err: ErrConnClosing}) s.cancel()
} }
return return
} }
@ -684,9 +684,8 @@ func (t *http2Server) closeStream(s *Stream) {
s.state = streamDone s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
// In case stream sending and receiving are invoked in separate // In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), the caller needs // goroutines (e.g., bi-directional streaming), cancel needs to be
// to call cancel on the stream to interrupt the blocking on // called to interrupt the potential blocking on other goroutines.
// other goroutines.
s.cancel() s.cancel()
} }

View File

@ -86,11 +86,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
} }
p := make([]byte, len(req)) p := make([]byte, len(req))
_, err := io.ReadFull(s, p) _, err := io.ReadFull(s, p)
if err != nil || !bytes.Equal(p, req) { if err != nil {
if err == ErrConnClosing { return
return }
} if !bytes.Equal(p, req) {
t.Fatalf("handleStream got error: %v, want <nil>; result: %v, want %v", err, p, req) t.Fatalf("handleStream got %v, want %v", p, req)
} }
// send a response back to the client. // send a response back to the client.
h.t.Write(s, resp, &Options{}) h.t.Write(s, resp, &Options{})
@ -429,6 +429,69 @@ func TestMaxStreams(t *testing.T) {
server.stop() server.stop()
} }
func TestServerContextCanceledOnClosedConnection(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
}
var sc *http2Server
// Wait until the server transport is setup.
for {
server.mu.Lock()
if len(server.conns) == 0 {
server.mu.Unlock()
time.Sleep(time.Millisecond)
continue
}
for k := range server.conns {
var ok bool
sc, ok = k.(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", k)
}
}
server.mu.Unlock()
break
}
cc, ok := ct.(*http2Client)
if !ok {
t.Fatalf("Failed to convert %v to *http2Client", ct)
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
// Make sure the headers frame is flushed out.
<-cc.writableChan
if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
t.Fatalf("Failed to write data: %v", err)
}
cc.writableChan <- 0
// Loop until the server side stream is created.
var ss *Stream
for {
time.Sleep(time.Second)
sc.mu.Lock()
if len(sc.activeStreams) == 0 {
sc.mu.Unlock()
continue
}
ss = sc.activeStreams[s.id]
sc.mu.Unlock()
break
}
cc.Close()
select {
case <-ss.Context().Done():
if ss.Context().Err() != context.Canceled {
t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
}
case <-time.After(5 * time.Second):
t.Fatalf("Failed to cancel the context of the sever side stream.")
}
}
func TestServerWithMisbehavedClient(t *testing.T) { func TestServerWithMisbehavedClient(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended) server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{ callHdr := &CallHdr{