diff --git a/transport/http2_client.go b/transport/http2_client.go index d16f8dd8..d01fb892 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -593,7 +593,11 @@ func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame header hDec.state = decodeState{} } }() - endHeaders, err := hDec.decodeClientHTTP2Headers(s, frame) + endHeaders, err := hDec.decodeClientHTTP2Headers(frame) + if s == nil { + // s has been closed. + return nil + } if err != nil { s.write(recvMsg{err: err}) // Something wrong. Stops reading even when there is remaining. @@ -659,16 +663,13 @@ func (t *http2Client) reader() { } switch frame := frame.(type) { case *http2.HeadersFrame: - var ok bool - if curStream, ok = t.getStream(frame); !ok { - continue - } + // operateHeaders has to be invoked regardless the value of curStream + // because the HPACK decoder needs to be updated using the received + // headers. + curStream, _ = t.getStream(frame) endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) curStream = t.operateHeaders(hDec, curStream, frame, endStream) case *http2.ContinuationFrame: - if curStream == nil { - continue - } curStream = t.operateHeaders(hDec, curStream, frame, false) case *http2.DataFrame: t.handleData(frame) diff --git a/transport/http2_server.go b/transport/http2_server.go index a1d606f7..0399c3b0 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -143,7 +143,11 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header hDec.state = decodeState{} } }() - endHeaders, err := hDec.decodeServerHTTP2Headers(s, frame) + endHeaders, err := hDec.decodeServerHTTP2Headers(frame) + if s == nil { + // s has been closed. + return nil + } if err != nil { log.Printf("transport: http2Server.operateHeader found %v", err) if se, ok := err.(StreamError); ok { @@ -266,9 +270,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg) case *http2.ContinuationFrame: - if curStream == nil { - continue - } curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg) case *http2.DataFrame: t.handleData(frame) @@ -483,6 +484,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) } if err := t.writeHeaders(s, t.hBuf, true); err != nil { + t.Close() return err } t.closeStream(s) diff --git a/transport/http_util.go b/transport/http_util.go index 263bc5d1..fb67a0d2 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -176,7 +176,7 @@ func newHPACKDecoder() *hpackDecoder { return d } -func (d *hpackDecoder) decodeClientHTTP2Headers(s *Stream, frame headerFrame) (endHeaders bool, err error) { +func (d *hpackDecoder) decodeClientHTTP2Headers(frame headerFrame) (endHeaders bool, err error) { d.err = nil _, err = d.h.Write(frame.HeaderBlockFragment()) if err != nil { @@ -196,7 +196,7 @@ func (d *hpackDecoder) decodeClientHTTP2Headers(s *Stream, frame headerFrame) (e return } -func (d *hpackDecoder) decodeServerHTTP2Headers(s *Stream, frame headerFrame) (endHeaders bool, err error) { +func (d *hpackDecoder) decodeServerHTTP2Headers(frame headerFrame) (endHeaders bool, err error) { d.err = nil _, err = d.h.Write(frame.HeaderBlockFragment()) if err != nil {