Merge pull request #876 from menghanl/header_close
Close headerChan if processHeaderField sets error
This commit is contained in:
@ -852,6 +852,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||||||
state.processHeaderField(hf)
|
state.processHeaderField(hf)
|
||||||
}
|
}
|
||||||
if state.err != nil {
|
if state.err != nil {
|
||||||
|
s.mu.Lock()
|
||||||
|
if !s.headerDone {
|
||||||
|
close(s.headerChan)
|
||||||
|
s.headerDone = true
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: state.err})
|
s.write(recvMsg{err: state.err})
|
||||||
// Something wrong. Stops reading even when there is remaining.
|
// Something wrong. Stops reading even when there is remaining.
|
||||||
return
|
return
|
||||||
|
@ -40,12 +40,14 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -58,14 +60,15 @@ type server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
expectedRequest = []byte("ping")
|
expectedRequest = []byte("ping")
|
||||||
expectedResponse = []byte("pong")
|
expectedResponse = []byte("pong")
|
||||||
expectedRequestLarge = make([]byte, initialWindowSize*2)
|
expectedRequestLarge = make([]byte, initialWindowSize*2)
|
||||||
expectedResponseLarge = make([]byte, initialWindowSize*2)
|
expectedResponseLarge = make([]byte, initialWindowSize*2)
|
||||||
|
expectedInvalidHeaderField = "invalid/content-type"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testStreamHandler struct {
|
type testStreamHandler struct {
|
||||||
t ServerTransport
|
t *http2Server
|
||||||
}
|
}
|
||||||
|
|
||||||
type hType int
|
type hType int
|
||||||
@ -75,6 +78,7 @@ const (
|
|||||||
suspended
|
suspended
|
||||||
misbehaved
|
misbehaved
|
||||||
encodingRequiredStatus
|
encodingRequiredStatus
|
||||||
|
invalidHeaderField
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||||
@ -140,6 +144,16 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *
|
|||||||
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
|
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
|
||||||
|
<-h.t.writableChan
|
||||||
|
h.t.hBuf.Reset()
|
||||||
|
h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
|
||||||
|
if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil {
|
||||||
|
t.Fatalf("Failed to write headers: %v", err)
|
||||||
|
}
|
||||||
|
h.t.writableChan <- 0
|
||||||
|
}
|
||||||
|
|
||||||
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
||||||
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||||
var err error
|
var err error
|
||||||
@ -177,7 +191,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
|||||||
}
|
}
|
||||||
s.conns[transport] = true
|
s.conns[transport] = true
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
h := &testStreamHandler{transport}
|
h := &testStreamHandler{transport.(*http2Server)}
|
||||||
switch ht {
|
switch ht {
|
||||||
case suspended:
|
case suspended:
|
||||||
go transport.HandleStreams(h.handleStreamSuspension)
|
go transport.HandleStreams(h.handleStreamSuspension)
|
||||||
@ -189,6 +203,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
|||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
go h.handleStreamEncodingRequiredStatus(t, s)
|
go h.handleStreamEncodingRequiredStatus(t, s)
|
||||||
})
|
})
|
||||||
|
case invalidHeaderField:
|
||||||
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
|
go h.handleStreamInvalidHeaderField(t, s)
|
||||||
|
})
|
||||||
default:
|
default:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
go h.handleStream(t, s)
|
go h.handleStream(t, s)
|
||||||
@ -752,6 +770,32 @@ func TestEncodingRequiredStatus(t *testing.T) {
|
|||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInvalidHeaderField(t *testing.T) {
|
||||||
|
server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField)
|
||||||
|
callHdr := &CallHdr{
|
||||||
|
Host: "localhost",
|
||||||
|
Method: "foo",
|
||||||
|
}
|
||||||
|
s, err := ct.NewStream(context.Background(), callHdr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opts := Options{
|
||||||
|
Last: true,
|
||||||
|
Delay: false,
|
||||||
|
}
|
||||||
|
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||||
|
t.Fatalf("Failed to write the request: %v", err)
|
||||||
|
}
|
||||||
|
p := make([]byte, http2MaxFrameLen)
|
||||||
|
_, err = s.dec.Read(p)
|
||||||
|
if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
||||||
|
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
|
||||||
|
}
|
||||||
|
ct.Close()
|
||||||
|
server.stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamContext(t *testing.T) {
|
func TestStreamContext(t *testing.T) {
|
||||||
expectedStream := &Stream{}
|
expectedStream := &Stream{}
|
||||||
ctx := newContextWithStream(context.Background(), expectedStream)
|
ctx := newContextWithStream(context.Background(), expectedStream)
|
||||||
|
Reference in New Issue
Block a user