Use the same hpack encoder on a transport and share it between RPCs. (#1536)

Author: mmukhi (2017-09-28 13:37:13 -07:00)
Committer: dfawley
Parent: eaf555a871
Commit: 6014154b60
6 changed files with 149 additions and 149 deletions
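
Why share the encoder? HPACK is stateful: the encoder maintains a per-connection dynamic table, so keeping one encoder per transport both removes a buffer-and-encoder allocation from every RPC and lets headers that repeat across RPCs compress down to short table references. The sketch below illustrates that effect on its own; it is not code from this commit, and the header values are invented:

```go
package main

import (
	"bytes"
	"fmt"

	"golang.org/x/net/http2/hpack"
)

func main() {
	var buf bytes.Buffer
	enc := hpack.NewEncoder(&buf)

	// Illustrative per-RPC headers.
	headers := []hpack.HeaderField{
		{Name: ":method", Value: "POST"},
		{Name: ":path", Value: "/package.service/method"},
		{Name: "content-type", Value: "application/grpc"},
		{Name: "user-agent", Value: "example-client/0.1"},
	}

	// First RPC: most fields are written as literals and inserted
	// into the encoder's dynamic table.
	for _, f := range headers {
		enc.WriteField(f)
	}
	first := buf.Len()
	buf.Reset() // reuse the buffer, as the transport now does with hBuf

	// Second RPC on the same encoder: the repeated fields become
	// one- or two-byte dynamic-table index references.
	for _, f := range headers {
		enc.WriteField(f)
	}
	fmt.Printf("first header block: %d bytes, repeat: %d bytes\n", first, buf.Len())
}
```

The flip side is that a shared encoder must only ever be driven by one goroutine. The commit arranges that by moving all encoding into itemHandler, the single consumer of each transport's control buffer, as the diffs below show.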

File: stats/stats.go

@@ -135,8 +135,6 @@ func (s *OutPayload) isRPCStats() {}
 type OutHeader struct {
 	// Client is true if this OutHeader is from client side.
 	Client bool
-	// WireLength is the wire length of header.
-	WireLength int

 	// The following fields are valid only if Client is true.
 	// FullMethod is the full RPC method string, i.e., /package.service/method.

File: stats/stats_test.go

@@ -444,10 +444,6 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 	if !d.client {
 		if st.FullMethod != e.method {
 			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -530,18 +526,13 @@ func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
 func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
 	var (
 		ok bool
-		st *stats.InTrailer
 	)
-	if st, ok = d.s.(*stats.InTrailer); !ok {
+	if _, ok = d.s.(*stats.InTrailer); !ok {
 		t.Fatalf("got %T, want InTrailer", d.s)
 	}
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 }

 func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
@@ -555,10 +546,6 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 	if d.client {
 		if st.FullMethod != e.method {
 			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -642,10 +629,6 @@ func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
 	if st.Client {
 		t.Fatalf("st IsClient = true, want false")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 }

 func checkEnd(t *testing.T, d *gotData, e *expectedData) {

File: transport/control.go

@@ -26,6 +26,7 @@ import (
 	"time"

 	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/hpack"
 )

 const (
@@ -56,7 +57,9 @@
 // control tasks, e.g., flow control, settings, streaming resetting, etc.

 type headerFrame struct {
-	p http2.HeadersFrameParam
+	streamID  uint32
+	hf        []hpack.HeaderField
+	endStream bool
 }

 func (*headerFrame) item() {}
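
headerFrame previously wrapped an already-encoded http2.HeadersFrameParam; it now carries the raw header fields, and encoding is deferred to the writer side of the transport. A minimal sketch of the single-consumer queue this relies on, heavily simplified from the real controlBuffer in this package:

```go
package main

import "fmt"

// item mirrors the control-buffer element interface in the transport package.
type item interface{ item() }

type headerFrame struct{ streamID uint32 }

func (*headerFrame) item() {}

type controlBuffer struct{ ch chan item }

func (c *controlBuffer) put(i item) { c.ch <- i }

func main() {
	cb := &controlBuffer{ch: make(chan item, 4)}
	// Any number of RPC goroutines may enqueue concurrently...
	cb.put(&headerFrame{streamID: 1})
	cb.put(&headerFrame{streamID: 3})
	close(cb.ch)
	// ...but a single loop drains the queue, so an hpack.Encoder used
	// only inside this loop needs no locking.
	for i := range cb.ch {
		fmt.Printf("encode and write %T for stream %d\n", i, i.(*headerFrame).streamID)
	}
}
```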

File: transport/http2_client.go

@@ -193,6 +193,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
 		icwz = opts.InitialConnWindowSize
 		dynamicWindow = false
 	}
+	var buf bytes.Buffer
 	t := &http2Client{
 		ctx:    ctx,
 		target: addr.Addr,
@@ -209,6 +210,8 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
 		goAway:            make(chan struct{}),
 		awakenKeepalive:   make(chan struct{}, 1),
 		framer:            newFramer(conn),
+		hBuf:              &buf,
+		hEnc:              hpack.NewEncoder(&buf),
 		controlBuf:        newControlBuffer(),
 		fc:                &inFlow{limit: uint32(icwz)},
 		sendQuotaPool:     newQuotaPool(defaultWindowSize),
@@ -361,7 +364,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			authData[k] = v
 		}
 	}
-	callAuthData := make(map[string]string)
+	callAuthData := map[string]string{}
 	// Check if credentials.PerRPCCredentials were provided via call options.
 	// Note: if these credentials are provided both via dial options and call
 	// options, then both sets of credentials will be applied.
@@ -401,40 +404,40 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	if sq > 1 {
 		t.streamsQuota.add(sq - 1)
 	}
-	// HPACK encodes various headers.
-	hBuf := bytes.NewBuffer([]byte{})
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
-	hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
-	hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
-	hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
-	hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
-	hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
+	// TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields
+	// first and create a slice of that exact size.
+	// Make the slice of certain predictable size to reduce allocations made by append.
+	hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
+	hfLen += len(authData) + len(callAuthData)
+	headerFields := make([]hpack.HeaderField, 0, hfLen)
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
 	if callHdr.SendCompress != "" {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
 	}
 	if dl, ok := ctx.Deadline(); ok {
 		// Send out timeout regardless its value. The server can detect timeout context by itself.
+		// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
 		timeout := dl.Sub(time.Now())
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
 	}
 	for k, v := range authData {
-		hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 	}
 	for k, v := range callAuthData {
-		hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 	}
-	var (
-		endHeaders bool
-	)
 	if b := stats.OutgoingTags(ctx); b != nil {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
 	}
 	if b := stats.OutgoingTrace(ctx); b != nil {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
 	}
 	if md, ok := metadata.FromOutgoingContext(ctx); ok {
 		for k, vv := range md {
@@ -443,7 +446,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 				continue
 			}
 			for _, v := range vv {
-				hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 			}
 		}
 	}
@@ -453,7 +456,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			continue
 		}
 		for _, v := range vv {
-			hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
 }
@@ -482,34 +485,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		default:
 		}
 	}
-	first := true
-	bufLen := hBuf.Len()
-	// Sends the headers in a single batch even when they span multiple frames.
-	for !endHeaders {
-		size := hBuf.Len()
-		if size > http2MaxFrameLen {
-			size = http2MaxFrameLen
-		} else {
-			endHeaders = true
-		}
-		if first {
-			// Sends a HeadersFrame to server to start a new stream.
-			p := http2.HeadersFrameParam{
-				StreamID:      s.id,
-				BlockFragment: hBuf.Next(size),
-				EndStream:     false,
-				EndHeaders:    endHeaders,
-			}
-			// Do a force flush for the buffered frames iff it is the last headers frame
-			// and there is header metadata to be sent. Otherwise, there is flushing until
-			// the corresponding data frame is written.
-			t.controlBuf.put(&headerFrame{p})
-			first = false
-		} else {
-			// Sends Continuation frames for the leftover headers.
-			t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: hBuf.Next(size)})
-		}
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 	t.mu.Unlock()

 	s.mu.Lock()
@@ -519,7 +499,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	if t.statsHandler != nil {
 		outHeader := &stats.OutHeader{
 			Client:     true,
-			WireLength: bufLen,
 			FullMethod: callHdr.Method,
 			RemoteAddr: t.remoteAddr,
 			LocalAddr:  t.localAddr,
@@ -770,7 +749,7 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) {
 		return
 	}
 	if w := s.fc.maybeAdjust(n); w > 0 {
-		// Piggyback conneciton's window update along.
+		// Piggyback connection's window update along.
 		if cw := t.fc.resetPendingUpdate(); cw > 0 {
 			t.controlBuf.put(&windowUpdate{0, cw})
 		}
@@ -1200,6 +1179,9 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
 	}
 }

+// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
 func (t *http2Client) itemHandler(i item) error {
 	var err error
 	defer func() {
@@ -1214,9 +1196,38 @@ func (t *http2Client) itemHandler(i item) error {
 			i.f()
 		}
 	case *headerFrame:
-		err = t.framer.fr.WriteHeaders(i.p)
-	case *continuationFrame:
-		err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+		t.hBuf.Reset()
+		for _, f := range i.hf {
+			t.hEnc.WriteField(f)
+		}
+		endHeaders := false
+		first := true
+		for !endHeaders {
+			size := t.hBuf.Len()
+			if size > http2MaxFrameLen {
+				size = http2MaxFrameLen
+			} else {
+				endHeaders = true
+			}
+			if first {
+				first = false
+				err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+					StreamID:      i.streamID,
+					BlockFragment: t.hBuf.Next(size),
+					EndStream:     i.endStream,
+					EndHeaders:    endHeaders,
+				})
+			} else {
+				err = t.framer.fr.WriteContinuation(
+					i.streamID,
+					endHeaders,
+					t.hBuf.Next(size),
+				)
+			}
+			if err != nil {
+				return err
+			}
+		}
 	case *windowUpdate:
 		err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
 	case *settings:
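
This case also absorbs the old continuationFrame path: the encoded block is drained in fragments of at most http2MaxFrameLen, the first sent as a HEADERS frame and the rest as CONTINUATION frames, since HTTP/2 requires a header block to occupy a contiguous run of frames on the connection. A standalone sketch of just the splitting arithmetic, assuming the 16384-byte value the transport's constant holds:

```go
package main

import (
	"bytes"
	"fmt"
)

const maxFrameLen = 16384 // stand-in for http2MaxFrameLen

func main() {
	// Pretend hBuf holds a 40,000-byte encoded header block.
	hBuf := bytes.NewBuffer(make([]byte, 40000))
	first, endHeaders := true, false
	for !endHeaders {
		size := hBuf.Len()
		if size > maxFrameLen {
			size = maxFrameLen
		} else {
			endHeaders = true
		}
		fragment := hBuf.Next(size) // consumed in order, as in itemHandler
		if first {
			first = false
			fmt.Printf("HEADERS      %5d bytes, endHeaders=%v\n", len(fragment), endHeaders)
		} else {
			fmt.Printf("CONTINUATION %5d bytes, endHeaders=%v\n", len(fragment), endHeaders)
		}
	}
}
```

Running it prints one 16384-byte HEADERS frame, one 16384-byte CONTINUATION, and a final 7232-byte CONTINUATION with endHeaders=true.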

File: transport/http2_server.go

@@ -63,6 +63,8 @@ type http2Server struct {
 	// blocking forever after Close.
 	shutdownChan chan struct{}
 	framer       *framer
+	hBuf         *bytes.Buffer  // the buffer for HPACK encoding
+	hEnc         *hpack.Encoder // HPACK encoder
 	// The max number of concurrent streams.
 	maxStreams uint32
 	// controlBuf delivers all the control related tasks (e.g., window
@@ -175,6 +177,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
 	if kep.MinTime == 0 {
 		kep.MinTime = defaultKeepalivePolicyMinTime
 	}
+	var buf bytes.Buffer
 	t := &http2Server{
 		ctx:  context.Background(),
 		conn: conn,
@@ -182,6 +185,8 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
 		localAddr:   conn.LocalAddr(),
 		authInfo:    config.AuthInfo,
 		framer:      framer,
+		hBuf:        &buf,
+		hEnc:        hpack.NewEncoder(&buf),
 		maxStreams:  maxStreams,
 		inTapHandle: config.InTapHandle,
 		controlBuf:  newControlBuffer(),
@@ -639,7 +644,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) {
 	t.mu.Unlock()
 	if ns < 1 && !t.kep.PermitWithoutStream {
 		// Keepalive shouldn't be active thus, this new ping should
-		// have come after atleast defaultPingTimeout.
+		// have come after at least defaultPingTimeout.
 		if t.lastPingAt.Add(defaultPingTimeout).After(now) {
 			t.pingStrikes++
 		}
@@ -669,34 +674,6 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
 	}
 }

-func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error {
-	first := true
-	endHeaders := false
-	// Sends the headers in a single batch.
-	for !endHeaders {
-		size := b.Len()
-		if size > http2MaxFrameLen {
-			size = http2MaxFrameLen
-		} else {
-			endHeaders = true
-		}
-		if first {
-			p := http2.HeadersFrameParam{
-				StreamID:      s.id,
-				BlockFragment: b.Next(size),
-				EndStream:     endStream,
-				EndHeaders:    endHeaders,
-			}
-			t.controlBuf.put(&headerFrame{p})
-			first = false
-		} else {
-			t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: b.Next(size)})
-		}
-	}
-	atomic.StoreUint32(&t.resetPingStrikes, 1)
-	return nil
-}
-
 // WriteHeader sends the header metedata md back to the client.
 func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 	select {
@@ -722,13 +699,13 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 	}
 	md = s.header
 	s.mu.Unlock()
-	hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory later.
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	// TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields
+	// first and create a slice of that exact size.
+	headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
 	if s.sendCompress != "" {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
 	}
 	for k, vv := range md {
 		if isReservedHeader(k) {
@@ -736,16 +713,17 @@
 			continue
 		}
 		for _, v := range vv {
-			hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
-	bufLen := hBuf.Len()
-	if err := t.writeHeaders(s, hBuf, false); err != nil {
-		return err
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 	if t.stats != nil {
 		outHeader := &stats.OutHeader{
-			WireLength: bufLen,
+			//WireLength: // TODO(mmukhi): Revisit this later, if needed.
 		}
 		t.stats.HandleRPC(s.Context(), outHeader)
 	}
@@ -782,18 +760,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
 		headersSent = true
 	}
-	hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory.
-	hEnc := hpack.NewEncoder(hBuf)
+	// TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields
+	// first and create a slice of that exact size.
+	headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
 	if !headersSent {
-		hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
-		hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+		headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
 	}
-	hEnc.WriteField(
-		hpack.HeaderField{
-			Name:  "grpc-status",
-			Value: strconv.Itoa(int(st.Code())),
-		})
-	hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})

 	if p := st.Proto(); p != nil && len(p.Details) > 0 {
 		stBytes, err := proto.Marshal(p)
@@ -802,7 +777,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
 			panic(err)
 		}
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 	}

 	// Attach the trailer metadata.
@@ -812,19 +787,16 @@
 			continue
 		}
 		for _, v := range vv {
-			hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
-	bufLen := hBuf.Len()
-	if err := t.writeHeaders(s, hBuf, true); err != nil {
-		t.Close()
-		return err
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: true,
+	})
 	if t.stats != nil {
-		outTrailer := &stats.OutTrailer{
-			WireLength: bufLen,
-		}
-		t.stats.HandleRPC(s.Context(), outTrailer)
+		t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
 	}
 	t.closeStream(s)
 	return nil
@@ -904,7 +876,6 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
 	atomic.StoreUint32(&t.resetPingStrikes, 1)
 	success := func() {
 		t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
-			//fmt.Println("Adding quota back to localEendQuota", ps)
 			s.localSendQuota.add(ps)
 		}})
 		if ps < sq {
@@ -1007,6 +978,9 @@ func (t *http2Server) keepalive() {

 var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}

+// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
 func (t *http2Server) itemHandler(i item) error {
 	var err error
 	defer func() {
@@ -1022,9 +996,39 @@ func (t *http2Server) itemHandler(i item) error {
 			i.f()
 		}
 	case *headerFrame:
-		err = t.framer.fr.WriteHeaders(i.p)
-	case *continuationFrame:
-		err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+		t.hBuf.Reset()
+		for _, f := range i.hf {
+			t.hEnc.WriteField(f)
+		}
+		first := true
+		endHeaders := false
+		for !endHeaders {
+			size := t.hBuf.Len()
+			if size > http2MaxFrameLen {
+				size = http2MaxFrameLen
+			} else {
+				endHeaders = true
+			}
+			if first {
+				first = false
+				err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+					StreamID:      i.streamID,
+					BlockFragment: t.hBuf.Next(size),
+					EndStream:     i.endStream,
+					EndHeaders:    endHeaders,
+				})
+			} else {
+				err = t.framer.fr.WriteContinuation(
+					i.streamID,
+					endHeaders,
+					t.hBuf.Next(size),
+				)
+			}
+			if err != nil {
+				return err
+			}
+		}
+		atomic.StoreUint32(&t.resetPingStrikes, 1)
 	case *windowUpdate:
 		err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
 	case *settings:
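
As the TODO above says, this block is the client loop plus the resetPingStrikes store. One possible shape for the shared helper that TODO asks for; this is a hypothetical sketch, not code from this commit, and writeHeaderBlock with its signature is invented here:

```go
package transport // hypothetical placement, alongside the two itemHandlers

import (
	"bytes"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
)

// writeHeaderBlock (invented name) encodes hf with the transport's shared
// encoder and writes one HEADERS frame followed by CONTINUATION frames.
func writeHeaderBlock(fr *http2.Framer, hBuf *bytes.Buffer, hEnc *hpack.Encoder,
	streamID uint32, endStream bool, hf []hpack.HeaderField) error {
	const http2MaxFrameLen = 16384 // matches the transport package constant
	hBuf.Reset()
	for _, f := range hf {
		if err := hEnc.WriteField(f); err != nil {
			return err
		}
	}
	first, endHeaders := true, false
	for !endHeaders {
		size := hBuf.Len()
		if size > http2MaxFrameLen {
			size = http2MaxFrameLen
		} else {
			endHeaders = true
		}
		var err error
		if first {
			first = false
			err = fr.WriteHeaders(http2.HeadersFrameParam{
				StreamID:      streamID,
				BlockFragment: hBuf.Next(size),
				EndStream:     endStream,
				EndHeaders:    endHeaders,
			})
		} else {
			err = fr.WriteContinuation(streamID, endHeaders, hBuf.Next(size))
		}
		if err != nil {
			return err
		}
	}
	return nil
}
```

Each itemHandler would then shrink to a single call plus its transport-specific bookkeeping, such as the server's ping-strike reset.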

File: transport/transport_test.go

@@ -163,12 +163,13 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *
 }

 func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
-	hBuf := bytes.NewBuffer([]byte{})
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
-	if err := h.t.writeHeaders(s, hBuf, false); err != nil {
-		t.Fatalf("Failed to write headers: %v", err)
-	}
+	headerFields := []hpack.HeaderField{}
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
+	h.t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 }

 func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {