Merge pull request #876 from menghanl/header_close
Close headerChan if processHeaderField sets error
This commit is contained in:
@ -852,6 +852,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
||||
state.processHeaderField(hf)
|
||||
}
|
||||
if state.err != nil {
|
||||
s.mu.Lock()
|
||||
if !s.headerDone {
|
||||
close(s.headerChan)
|
||||
s.headerDone = true
|
||||
}
|
||||
s.mu.Unlock()
|
||||
s.write(recvMsg{err: state.err})
|
||||
// Something wrong. Stops reading even when there is remaining.
|
||||
return
|
||||
|
@ -40,12 +40,14 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
@ -58,14 +60,15 @@ type server struct {
|
||||
}
|
||||
|
||||
var (
|
||||
expectedRequest = []byte("ping")
|
||||
expectedResponse = []byte("pong")
|
||||
expectedRequestLarge = make([]byte, initialWindowSize*2)
|
||||
expectedResponseLarge = make([]byte, initialWindowSize*2)
|
||||
expectedRequest = []byte("ping")
|
||||
expectedResponse = []byte("pong")
|
||||
expectedRequestLarge = make([]byte, initialWindowSize*2)
|
||||
expectedResponseLarge = make([]byte, initialWindowSize*2)
|
||||
expectedInvalidHeaderField = "invalid/content-type"
|
||||
)
|
||||
|
||||
type testStreamHandler struct {
|
||||
t ServerTransport
|
||||
t *http2Server
|
||||
}
|
||||
|
||||
type hType int
|
||||
@ -75,6 +78,7 @@ const (
|
||||
suspended
|
||||
misbehaved
|
||||
encodingRequiredStatus
|
||||
invalidHeaderField
|
||||
)
|
||||
|
||||
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||
@ -140,6 +144,16 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *
|
||||
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
|
||||
}
|
||||
|
||||
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
|
||||
<-h.t.writableChan
|
||||
h.t.hBuf.Reset()
|
||||
h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
|
||||
if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil {
|
||||
t.Fatalf("Failed to write headers: %v", err)
|
||||
}
|
||||
h.t.writableChan <- 0
|
||||
}
|
||||
|
||||
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
||||
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||
var err error
|
||||
@ -177,7 +191,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||
}
|
||||
s.conns[transport] = true
|
||||
s.mu.Unlock()
|
||||
h := &testStreamHandler{transport}
|
||||
h := &testStreamHandler{transport.(*http2Server)}
|
||||
switch ht {
|
||||
case suspended:
|
||||
go transport.HandleStreams(h.handleStreamSuspension)
|
||||
@ -189,6 +203,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||
go transport.HandleStreams(func(s *Stream) {
|
||||
go h.handleStreamEncodingRequiredStatus(t, s)
|
||||
})
|
||||
case invalidHeaderField:
|
||||
go transport.HandleStreams(func(s *Stream) {
|
||||
go h.handleStreamInvalidHeaderField(t, s)
|
||||
})
|
||||
default:
|
||||
go transport.HandleStreams(func(s *Stream) {
|
||||
go h.handleStream(t, s)
|
||||
@ -752,6 +770,32 @@ func TestEncodingRequiredStatus(t *testing.T) {
|
||||
server.stop()
|
||||
}
|
||||
|
||||
func TestInvalidHeaderField(t *testing.T) {
|
||||
server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo",
|
||||
}
|
||||
s, err := ct.NewStream(context.Background(), callHdr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
opts := Options{
|
||||
Last: true,
|
||||
Delay: false,
|
||||
}
|
||||
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||
t.Fatalf("Failed to write the request: %v", err)
|
||||
}
|
||||
p := make([]byte, http2MaxFrameLen)
|
||||
_, err = s.dec.Read(p)
|
||||
if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
||||
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
|
||||
}
|
||||
ct.Close()
|
||||
server.stop()
|
||||
}
|
||||
|
||||
func TestStreamContext(t *testing.T) {
|
||||
expectedStream := &Stream{}
|
||||
ctx := newContextWithStream(context.Background(), expectedStream)
|
||||
|
Reference in New Issue
Block a user