From f3c6dc545fecad27bd1c8c2301345f3e6ae72c01 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 29 Feb 2016 17:45:54 -0800 Subject: [PATCH] Fix err handling of malformed http2 polish the commit --- transport/http2_client.go | 39 ++++++++++++++++++++++++------------- transport/transport_test.go | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/transport/http2_client.go b/transport/http2_client.go index 19780539..66fabbba 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -719,6 +719,16 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { s.write(recvMsg{err: io.EOF}) } +func handleMalformedHTTP2(s *Stream, err http2.StreamError) { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.mu.Unlock() + s.write(recvMsg{err: StreamErrorf(http2ErrConvTab[err.Code], "%v", err)}) +} + // reader runs as a separate goroutine in charge of reading data from network // connection. // @@ -743,8 +753,22 @@ func (t *http2Client) reader() { for { frame, err := t.framer.readFrame() if err != nil { - t.notifyError(err) - return + // Abort an active stream if the http2.Framer returns a + // http2.StreamError. This can happen only if the server's response + // is malformed http2. + if se, ok := err.(http2.StreamError); ok { + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + handleMalformedHTTP2(s, se) + } + continue + } else { + // Transport error. + t.notifyError(err) + return + } } switch frame := frame.(type) { case *http2.MetaHeadersFrame: @@ -846,17 +870,6 @@ func (t *http2Client) Error() <-chan struct{} { func (t *http2Client) notifyError(err error) { t.mu.Lock() defer t.mu.Unlock() - - // Abort an active stream if the http2.Framer returns a - // http2.StreamError. This can happen only if the server's response - // is malformed http2. - if se, ok := err.(http2.StreamError); ok { - if s, ok := t.activeStreams[se.StreamID]; ok { - s.write(recvMsg{err: StreamErrorf(http2ErrConvTab[se.Code], "%v", err)}) - return - } - } - // make sure t.errorChan is closed only once. if t.state == reachable { t.state = unreachable diff --git a/transport/transport_test.go b/transport/transport_test.go index cb11ab2f..b8655c02 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -75,6 +75,7 @@ const ( normal hType = iota suspended misbehaved + malformedStatus ) func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { @@ -127,6 +128,12 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { } } +func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) { + // raw newline is not accepted by http2 framer and a http2.StreamError is + // generated. + h.t.WriteStatus(s, codes.Internal, "\n") +} + // start starts server. Other goroutines should block on s.readyChan for futher operations. func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { var err error @@ -172,6 +179,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { go transport.HandleStreams(func(s *Stream) { go h.handleStreamMisbehave(t, s) }) + case malformedStatus: + go transport.HandleStreams(func(s *Stream) { + go h.handleStreamMalformedStatus(t, s) + }) default: go transport.HandleStreams(func(s *Stream) { go h.handleStream(t, s) @@ -652,6 +663,32 @@ func TestClientWithMisbehavedServer(t *testing.T) { server.stop() } +func TestMalformedStatus(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, malformedStatus) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + return + } + opts := Options{ + Last: true, + Delay: false, + } + if err := ct.Write(s, expectedRequest, &opts); err != nil { + t.Fatalf("Failed to write the request: %v", err) + } + p := make([]byte, http2MaxFrameLen) + expectedErr := StreamErrorf(codes.Internal, "stream error: stream ID 1; PROTOCOL_ERROR") + if _, err = s.dec.Read(p); err != expectedErr { + t.Fatalf("Read the err %v, want %v", err, expectedErr) + } + ct.Close() + server.stop() +} + func TestStreamContext(t *testing.T) { expectedStream := Stream{} ctx := newContextWithStream(context.Background(), &expectedStream)