Merge pull request #570 from bradfitz/meta

transport: use http2.Framer's MetaHeadersFrame functionality
This commit is contained in:
Qi Zhao
2016-02-26 13:37:15 -08:00
4 changed files with 126 additions and 199 deletions

View File

@ -39,6 +39,7 @@ import (
"math" "math"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -95,7 +96,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
return return
} }
if v != expectedRequest { if v != expectedRequest {
h.t.WriteStatus(s, codes.Internal, string(make([]byte, sizeLargeErr))) h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr))
return return
} }
} }

View File

@ -550,14 +550,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if t.activeStreams == nil { s, ok := t.activeStreams[f.Header().StreamID]
// The transport is closing. return s, ok
return nil, false
}
if s, ok := t.activeStreams[f.Header().StreamID]; ok {
return s, true
}
return nil, false
} }
// updateWindow adjusts the inbound quota for the stream and the transport. // updateWindow adjusts the inbound quota for the stream and the transport.
@ -680,54 +674,49 @@ func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
} }
} }
// operateHeader takes action on the decoded headers. It returns the current // operateHeaders takes action on the decoded headers.
// stream if there are remaining headers on the wire (in the following func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
// Continuation frame). s, ok := t.getStream(frame)
func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool) (pendingStream *Stream) { if !ok {
defer func() { return
if pendingStream == nil {
hDec.state = decodeState{}
} }
}() var state decodeState
endHeaders, err := hDec.decodeClientHTTP2Headers(frame) for _, hf := range frame.Fields {
if s == nil { state.processHeaderField(hf)
// s has been closed.
return nil
} }
if err != nil { if state.err != nil {
s.write(recvMsg{err: err}) s.write(recvMsg{err: state.err})
// Something wrong. Stops reading even when there is remaining. // Something wrong. Stops reading even when there is remaining.
return nil return
}
if !endHeaders {
return s
} }
endStream := frame.StreamEnded()
s.mu.Lock() s.mu.Lock()
if !endStream { if !endStream {
s.recvCompress = hDec.state.encoding s.recvCompress = state.encoding
} }
if !s.headerDone { if !s.headerDone {
if !endStream && len(hDec.state.mdata) > 0 { if !endStream && len(state.mdata) > 0 {
s.header = hDec.state.mdata s.header = state.mdata
} }
close(s.headerChan) close(s.headerChan)
s.headerDone = true s.headerDone = true
} }
if !endStream || s.state == streamDone { if !endStream || s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
return nil return
} }
if len(hDec.state.mdata) > 0 { if len(state.mdata) > 0 {
s.trailer = hDec.state.mdata s.trailer = state.mdata
} }
s.state = streamDone s.state = streamDone
s.statusCode = hDec.state.statusCode s.statusCode = state.statusCode
s.statusDesc = hDec.state.statusDesc s.statusDesc = state.statusDesc
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
return nil
} }
// reader runs as a separate goroutine in charge of reading data from network // reader runs as a separate goroutine in charge of reading data from network
@ -750,8 +739,6 @@ func (t *http2Client) reader() {
} }
t.handleSettings(sf) t.handleSettings(sf)
hDec := newHPACKDecoder()
var curStream *Stream
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
@ -760,15 +747,8 @@ func (t *http2Client) reader() {
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.HeadersFrame: case *http2.MetaHeadersFrame:
// operateHeaders has to be invoked regardless the value of curStream t.operateHeaders(frame)
// because the HPACK decoder needs to be updated using the received
// headers.
curStream, _ = t.getStream(frame)
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream)
case *http2.ContinuationFrame:
curStream = t.operateHeaders(hDec, curStream, frame, frame.HeadersEnded())
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
@ -866,6 +846,17 @@ func (t *http2Client) Error() <-chan struct{} {
func (t *http2Client) notifyError(err error) { func (t *http2Client) notifyError(err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
// Abort an active stream if the http2.Framer returns a
// http2.StreamError. This can happen only if the server's response
// is malformed http2.
if se, ok := err.(http2.StreamError); ok {
if s, ok := t.activeStreams[se.StreamID]; ok {
s.write(recvMsg{err: StreamErrorf(http2ErrConvTab[se.Code], "%v", err)})
return
}
}
// make sure t.errorChan is closed only once. // make sure t.errorChan is closed only once.
if t.state == reachable { if t.state == reachable {
t.state = unreachable t.state = unreachable

View File

@ -136,37 +136,38 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
return t, nil return t, nil
} }
// operateHeader takes action on the decoded headers. It returns the current // operateHeader takes action on the decoded headers.
// stream if there are remaining headers on the wire (in the following func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
// Continuation frame). buf := newRecvBuffer()
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) { fc := &inFlow{
defer func() { limit: initialWindowSize,
if pendingStream == nil { conn: t.fc,
hDec.state = decodeState{}
} }
}() s := &Stream{
endHeaders, err := hDec.decodeServerHTTP2Headers(frame) id: frame.Header().StreamID,
if s == nil { st: t,
// s has been closed. buf: buf,
return nil fc: fc,
} }
if err != nil {
grpclog.Printf("transport: http2Server.operateHeader found %v", err) var state decodeState
for _, hf := range frame.Fields {
state.processHeaderField(hf)
}
if err := state.err; err != nil {
if se, ok := err.(StreamError); ok { if se, ok := err.(StreamError); ok {
t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]})
} }
return nil return
} }
if endStream {
if frame.StreamEnded() {
// s is just created by the caller. No lock needed. // s is just created by the caller. No lock needed.
s.state = streamReadDone s.state = streamReadDone
} }
if !endHeaders { s.recvCompress = state.encoding
return s if state.timeoutSet {
} s.ctx, s.cancel = context.WithTimeout(context.TODO(), state.timeout)
s.recvCompress = hDec.state.encoding
if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else { } else {
s.ctx, s.cancel = context.WithCancel(context.TODO()) s.ctx, s.cancel = context.WithCancel(context.TODO())
} }
@ -183,25 +184,25 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
// back to the client (unary call only). // back to the client (unary call only).
s.ctx = newContextWithStream(s.ctx, s) s.ctx = newContextWithStream(s.ctx, s)
// Attach the received metadata to the context. // Attach the received metadata to the context.
if len(hDec.state.mdata) > 0 { if len(state.mdata) > 0 {
s.ctx = metadata.NewContext(s.ctx, hDec.state.mdata) s.ctx = metadata.NewContext(s.ctx, state.mdata)
} }
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, recv: s.buf,
} }
s.recvCompress = hDec.state.encoding s.recvCompress = state.encoding
s.method = hDec.state.method s.method = state.method
t.mu.Lock() t.mu.Lock()
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil return
} }
if uint32(len(t.activeStreams)) >= t.maxStreams { if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return nil return
} }
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
@ -210,7 +211,6 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
} }
handle(s) handle(s)
return nil
} }
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
@ -243,8 +243,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
t.handleSettings(sf) t.handleSettings(sf)
hDec := newHPACKDecoder()
var curStream *Stream
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
if err != nil { if err != nil {
@ -252,7 +250,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.HeadersFrame: case *http2.MetaHeadersFrame:
id := frame.Header().StreamID id := frame.Header().StreamID
if id%2 != 1 || id <= t.maxStreamID { if id%2 != 1 || id <= t.maxStreamID {
// illegal gRPC stream id. // illegal gRPC stream id.
@ -261,21 +259,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
break break
} }
t.maxStreamID = id t.maxStreamID = id
buf := newRecvBuffer() t.operateHeaders(frame, handle)
fc := &inFlow{
limit: initialWindowSize,
conn: t.fc,
}
curStream = &Stream{
id: frame.Header().StreamID,
st: t,
buf: buf,
fc: fc,
}
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
case *http2.ContinuationFrame:
curStream = t.operateHeaders(hDec, curStream, frame, frame.HeadersEnded(), handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:

View File

@ -90,6 +90,8 @@ var (
// Records the states during HPACK decoding. Must be reset once the // Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished. // decoding of the entire headers are finished.
type decodeState struct { type decodeState struct {
err error // first error encountered decoding
encoding string encoding string
// statusCode caches the stream status received from the trailer // statusCode caches the stream status received from the trailer
// the server sent. Client side only. // the server sent. Client side only.
@ -103,20 +105,6 @@ type decodeState struct {
mdata map[string][]string mdata map[string][]string
} }
// An hpackDecoder decodes HTTP2 headers which may span multiple frames.
type hpackDecoder struct {
h *hpack.Decoder
state decodeState
err error // The err when decoding
}
// A headerFrame is either a http2.HeaderFrame or http2.ContinuationFrame.
type headerFrame interface {
Header() http2.FrameHeader
HeaderBlockFragment() []byte
HeadersEnded() bool
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers // isReservedHeader checks whether hdr belongs to HTTP2 headers
// reserved by gRPC protocol. Any other headers are classified as the // reserved by gRPC protocol. Any other headers are classified as the
// user-specified metadata. // user-specified metadata.
@ -138,36 +126,40 @@ func isReservedHeader(hdr string) bool {
} }
} }
func newHPACKDecoder() *hpackDecoder { func (d *decodeState) setErr(err error) {
d := &hpackDecoder{} if d.err == nil {
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) { d.err = err
}
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name { switch f.Name {
case "content-type": case "content-type":
if !strings.Contains(f.Value, "application/grpc") { if !strings.Contains(f.Value, "application/grpc") {
d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value) d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
return return
} }
case "grpc-encoding": case "grpc-encoding":
d.state.encoding = f.Value d.encoding = f.Value
case "grpc-status": case "grpc-status":
code, err := strconv.Atoi(f.Value) code, err := strconv.Atoi(f.Value)
if err != nil { if err != nil {
d.err = StreamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) d.setErr(StreamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err))
return return
} }
d.state.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.state.statusDesc = f.Value d.statusDesc = f.Value
case "grpc-timeout": case "grpc-timeout":
d.state.timeoutSet = true d.timeoutSet = true
var err error var err error
d.state.timeout, err = timeoutDecode(f.Value) d.timeout, err = timeoutDecode(f.Value)
if err != nil { if err != nil {
d.err = StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err) d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
} }
case ":path": case ":path":
d.state.method = f.Value d.method = f.Value
default: default:
if !isReservedHeader(f.Name) { if !isReservedHeader(f.Name) {
if f.Name == "user-agent" { if f.Name == "user-agent" {
@ -179,59 +171,17 @@ func newHPACKDecoder() *hpackDecoder {
// Extract the application user agent string. // Extract the application user agent string.
f.Value = f.Value[:i] f.Value = f.Value[:i]
} }
if d.state.mdata == nil { if d.mdata == nil {
d.state.mdata = make(map[string][]string) d.mdata = make(map[string][]string)
} }
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
if err != nil { if err != nil {
grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err) grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err)
return return
} }
d.state.mdata[k] = append(d.state.mdata[k], v) d.mdata[k] = append(d.mdata[k], v)
} }
} }
})
return d
}
func (d *hpackDecoder) decodeClientHTTP2Headers(frame headerFrame) (endHeaders bool, err error) {
d.err = nil
_, err = d.h.Write(frame.HeaderBlockFragment())
if err != nil {
err = StreamErrorf(codes.Internal, "transport: HPACK header decode error: %v", err)
}
if frame.HeadersEnded() {
if closeErr := d.h.Close(); closeErr != nil && err == nil {
err = StreamErrorf(codes.Internal, "transport: HPACK decoder close error: %v", closeErr)
}
endHeaders = true
}
if err == nil && d.err != nil {
err = d.err
}
return
}
func (d *hpackDecoder) decodeServerHTTP2Headers(frame headerFrame) (endHeaders bool, err error) {
d.err = nil
_, err = d.h.Write(frame.HeaderBlockFragment())
if err != nil {
err = StreamErrorf(codes.Internal, "transport: HPACK header decode error: %v", err)
}
if frame.HeadersEnded() {
if closeErr := d.h.Close(); closeErr != nil && err == nil {
err = StreamErrorf(codes.Internal, "transport: HPACK decoder close error: %v", closeErr)
}
endHeaders = true
}
if err == nil && d.err != nil {
err = d.err
}
return
} }
type timeoutUnit uint8 type timeoutUnit uint8
@ -326,6 +276,7 @@ func newFramer(conn net.Conn) *framer {
writer: bufio.NewWriterSize(conn, http2IOBufSize), writer: bufio.NewWriterSize(conn, http2IOBufSize),
} }
f.fr = http2.NewFramer(f.writer, f.reader) f.fr = http2.NewFramer(f.writer, f.reader)
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f return f
} }