diff --git a/transport/http2_client.go b/transport/http2_client.go index 4892faab..e9cf6bec 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -852,6 +852,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { state.processHeaderField(hf) } if state.err != nil { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.mu.Unlock() s.write(recvMsg{err: state.err}) // Something wrong. Stops reading even when there is remaining. return diff --git a/transport/transport_test.go b/transport/transport_test.go index ac38c7a2..01d95e47 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -40,12 +40,14 @@ import ( "math" "net" "strconv" + "strings" "sync" "testing" "time" "golang.org/x/net/context" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" ) @@ -58,14 +60,15 @@ type server struct { } var ( - expectedRequest = []byte("ping") - expectedResponse = []byte("pong") - expectedRequestLarge = make([]byte, initialWindowSize*2) - expectedResponseLarge = make([]byte, initialWindowSize*2) + expectedRequest = []byte("ping") + expectedResponse = []byte("pong") + expectedRequestLarge = make([]byte, initialWindowSize*2) + expectedResponseLarge = make([]byte, initialWindowSize*2) + expectedInvalidHeaderField = "invalid/content-type" ) type testStreamHandler struct { - t ServerTransport + t *http2Server } type hType int @@ -75,6 +78,7 @@ const ( suspended misbehaved encodingRequiredStatus + invalidHeaderField ) func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { @@ -140,6 +144,16 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s * h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) } +func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { + <-h.t.writableChan + h.t.hBuf.Reset() + h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) + if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil { + t.Fatalf("Failed to write headers: %v", err) + } + h.t.writableChan <- 0 +} + // start starts server. Other goroutines should block on s.readyChan for further operations. func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { var err error @@ -177,7 +191,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { } s.conns[transport] = true s.mu.Unlock() - h := &testStreamHandler{transport} + h := &testStreamHandler{transport.(*http2Server)} switch ht { case suspended: go transport.HandleStreams(h.handleStreamSuspension) @@ -189,6 +203,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { go transport.HandleStreams(func(s *Stream) { go h.handleStreamEncodingRequiredStatus(t, s) }) + case invalidHeaderField: + go transport.HandleStreams(func(s *Stream) { + go h.handleStreamInvalidHeaderField(t, s) + }) default: go transport.HandleStreams(func(s *Stream) { go h.handleStream(t, s) @@ -752,6 +770,32 @@ func TestEncodingRequiredStatus(t *testing.T) { server.stop() } +func TestInvalidHeaderField(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + return + } + opts := Options{ + Last: true, + Delay: false, + } + if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + t.Fatalf("Failed to write the request: %v", err) + } + p := make([]byte, http2MaxFrameLen) + _, err = s.dec.Read(p) + if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) { + t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField) + } + ct.Close() + server.stop() +} + func TestStreamContext(t *testing.T) { expectedStream := &Stream{} ctx := newContextWithStream(context.Background(), expectedStream)