Merge pull request #876 from menghanl/header_close

Close headerChan if processHeaderField sets error
This commit is contained in:
Qi Zhao
2016-09-06 13:15:33 -07:00
committed by GitHub
2 changed files with 56 additions and 6 deletions

View File

@ -852,6 +852,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
state.processHeaderField(hf) state.processHeaderField(hf)
} }
if state.err != nil { if state.err != nil {
s.mu.Lock()
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
s.mu.Unlock()
s.write(recvMsg{err: state.err}) s.write(recvMsg{err: state.err})
// Something wrong. Stops reading even when there is remaining. // Something wrong. Stops reading even when there is remaining.
return return

View File

@ -40,12 +40,14 @@ import (
"math" "math"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
) )
@ -58,14 +60,15 @@ type server struct {
} }
var ( var (
expectedRequest = []byte("ping") expectedRequest = []byte("ping")
expectedResponse = []byte("pong") expectedResponse = []byte("pong")
expectedRequestLarge = make([]byte, initialWindowSize*2) expectedRequestLarge = make([]byte, initialWindowSize*2)
expectedResponseLarge = make([]byte, initialWindowSize*2) expectedResponseLarge = make([]byte, initialWindowSize*2)
expectedInvalidHeaderField = "invalid/content-type"
) )
type testStreamHandler struct { type testStreamHandler struct {
t ServerTransport t *http2Server
} }
type hType int type hType int
@ -75,6 +78,7 @@ const (
suspended suspended
misbehaved misbehaved
encodingRequiredStatus encodingRequiredStatus
invalidHeaderField
) )
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
@ -140,6 +144,16 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
} }
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
<-h.t.writableChan
h.t.hBuf.Reset()
h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil {
t.Fatalf("Failed to write headers: %v", err)
}
h.t.writableChan <- 0
}
// start starts server. Other goroutines should block on s.readyChan for further operations. // start starts server. Other goroutines should block on s.readyChan for further operations.
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
var err error var err error
@ -177,7 +191,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
} }
s.conns[transport] = true s.conns[transport] = true
s.mu.Unlock() s.mu.Unlock()
h := &testStreamHandler{transport} h := &testStreamHandler{transport.(*http2Server)}
switch ht { switch ht {
case suspended: case suspended:
go transport.HandleStreams(h.handleStreamSuspension) go transport.HandleStreams(h.handleStreamSuspension)
@ -189,6 +203,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStreamEncodingRequiredStatus(t, s) go h.handleStreamEncodingRequiredStatus(t, s)
}) })
case invalidHeaderField:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamInvalidHeaderField(t, s)
})
default: default:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s) go h.handleStream(t, s)
@ -752,6 +770,32 @@ func TestEncodingRequiredStatus(t *testing.T) {
server.stop() server.stop()
} }
func TestInvalidHeaderField(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
return
}
opts := Options{
Last: true,
Delay: false,
}
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
_, err = s.dec.Read(p)
if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
}
ct.Close()
server.stop()
}
func TestStreamContext(t *testing.T) { func TestStreamContext(t *testing.T) {
expectedStream := &Stream{} expectedStream := &Stream{}
ctx := newContextWithStream(context.Background(), expectedStream) ctx := newContextWithStream(context.Background(), expectedStream)