diff --git a/stats/stats.go b/stats/stats.go index b64c4295..d5aa2f79 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -135,8 +135,6 @@ func (s *OutPayload) isRPCStats() {} type OutHeader struct { // Client is true if this OutHeader is from client side. Client bool - // WireLength is the wire length of header. - WireLength int // The following fields are valid only if Client is true. // FullMethod is the full RPC method string, i.e., /package.service/method. diff --git a/stats/stats_test.go b/stats/stats_test.go index 8865d3fb..d66485fa 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -444,10 +444,6 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } if !d.client { if st.FullMethod != e.method { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) @@ -530,18 +526,13 @@ func checkInPayload(t *testing.T, d *gotData, e *expectedData) { func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { var ( ok bool - st *stats.InTrailer ) - if st, ok = d.s.(*stats.InTrailer); !ok { + if _, ok = d.s.(*stats.InTrailer); !ok { t.Fatalf("got %T, want InTrailer", d.s) } if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } } func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { @@ -555,10 +546,6 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } if d.client { if st.FullMethod != e.method { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) @@ -642,10 +629,6 @@ func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { if st.Client { t.Fatalf("st IsClient = true, want false") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } } func checkEnd(t *testing.T, d *gotData, e *expectedData) { diff --git a/transport/control.go b/transport/control.go index 77914de1..dd1a8d42 100644 --- a/transport/control.go +++ b/transport/control.go @@ -26,6 +26,7 @@ import ( "time" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" ) const ( @@ -56,7 +57,9 @@ const ( // control tasks, e.g., flow control, settings, streaming resetting, etc. type headerFrame struct { - p http2.HeadersFrameParam + streamID uint32 + hf []hpack.HeaderField + endStream bool } func (*headerFrame) item() {} diff --git a/transport/http2_client.go b/transport/http2_client.go index 31fed9e3..92ad868f 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -193,6 +193,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( icwz = opts.InitialConnWindowSize dynamicWindow = false } + var buf bytes.Buffer t := &http2Client{ ctx: ctx, target: addr.Addr, @@ -209,6 +210,8 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), framer: newFramer(conn), + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), controlBuf: newControlBuffer(), fc: &inFlow{limit: uint32(icwz)}, sendQuotaPool: newQuotaPool(defaultWindowSize), @@ -361,7 +364,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea authData[k] = v } } - callAuthData := make(map[string]string) + callAuthData := map[string]string{} // Check if credentials.PerRPCCredentials were provided via call options. // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. @@ -401,40 +404,40 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if sq > 1 { t.streamsQuota.add(sq - 1) } - // HPACK encodes various headers. - hBuf := bytes.NewBuffer([]byte{}) - hEnc := hpack.NewEncoder(hBuf) - hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}) - hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme}) - hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method}) - hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) - hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) - hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) - hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + // Make the slice of certain predictable size to reduce allocations made by append. + hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te + hfLen += len(authData) + len(callAuthData) + headerFields := make([]hpack.HeaderField, 0, hfLen) + headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) + headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) if callHdr.SendCompress != "" { - hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) } if dl, ok := ctx.Deadline(); ok { // Send out timeout regardless its value. The server can detect timeout context by itself. + // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. timeout := dl.Sub(time.Now()) - hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) } - for k, v := range authData { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } for k, v := range callAuthData { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } - var ( - endHeaders bool - ) if b := stats.OutgoingTags(ctx); b != nil { - hEnc.WriteField(hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)}) } if b := stats.OutgoingTrace(ctx); b != nil { - hEnc.WriteField(hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)}) } if md, ok := metadata.FromOutgoingContext(ctx); ok { for k, vv := range md { @@ -443,7 +446,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea continue } for _, v := range vv { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } } @@ -453,7 +456,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea continue } for _, v := range vv { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } } @@ -482,34 +485,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea default: } } - first := true - bufLen := hBuf.Len() - // Sends the headers in a single batch even when they span multiple frames. - for !endHeaders { - size := hBuf.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - if first { - // Sends a HeadersFrame to server to start a new stream. - p := http2.HeadersFrameParam{ - StreamID: s.id, - BlockFragment: hBuf.Next(size), - EndStream: false, - EndHeaders: endHeaders, - } - // Do a force flush for the buffered frames iff it is the last headers frame - // and there is header metadata to be sent. Otherwise, there is flushing until - // the corresponding data frame is written. - t.controlBuf.put(&headerFrame{p}) - first = false - } else { - // Sends Continuation frames for the leftover headers. - t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: hBuf.Next(size)}) - } - } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) t.mu.Unlock() s.mu.Lock() @@ -519,7 +499,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if t.statsHandler != nil { outHeader := &stats.OutHeader{ Client: true, - WireLength: bufLen, FullMethod: callHdr.Method, RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, @@ -770,7 +749,7 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) { return } if w := s.fc.maybeAdjust(n); w > 0 { - // Piggyback conneciton's window update along. + // Piggyback connection's window update along. if cw := t.fc.resetPendingUpdate(); cw > 0 { t.controlBuf.put(&windowUpdate{0, cw}) } @@ -1200,6 +1179,9 @@ func (t *http2Client) applySettings(ss []http2.Setting) { } } +// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) +// is duplicated between the client and the server. +// The transport layer needs to be refactored to take care of this. func (t *http2Client) itemHandler(i item) error { var err error defer func() { @@ -1214,9 +1196,38 @@ func (t *http2Client) itemHandler(i item) error { i.f() } case *headerFrame: - err = t.framer.fr.WriteHeaders(i.p) - case *continuationFrame: - err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment) + t.hBuf.Reset() + for _, f := range i.hf { + t.hEnc.WriteField(f) + } + endHeaders := false + first := true + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + first = false + err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: i.streamID, + BlockFragment: t.hBuf.Next(size), + EndStream: i.endStream, + EndHeaders: endHeaders, + }) + } else { + err = t.framer.fr.WriteContinuation( + i.streamID, + endHeaders, + t.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } case *windowUpdate: err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: diff --git a/transport/http2_server.go b/transport/http2_server.go index 4f62cba5..0f0e7599 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -63,6 +63,8 @@ type http2Server struct { // blocking forever after Close. shutdownChan chan struct{} framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder // The max number of concurrent streams. maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window @@ -175,6 +177,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime } + var buf bytes.Buffer t := &http2Server{ ctx: context.Background(), conn: conn, @@ -182,6 +185,8 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err localAddr: conn.LocalAddr(), authInfo: config.AuthInfo, framer: framer, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), maxStreams: maxStreams, inTapHandle: config.InTapHandle, controlBuf: newControlBuffer(), @@ -639,7 +644,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { t.mu.Unlock() if ns < 1 && !t.kep.PermitWithoutStream { // Keepalive shouldn't be active thus, this new ping should - // have come after atleast defaultPingTimeout. + // have come after at least defaultPingTimeout. if t.lastPingAt.Add(defaultPingTimeout).After(now) { t.pingStrikes++ } @@ -669,34 +674,6 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { } } -func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error { - first := true - endHeaders := false - // Sends the headers in a single batch. - for !endHeaders { - size := b.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - if first { - p := http2.HeadersFrameParam{ - StreamID: s.id, - BlockFragment: b.Next(size), - EndStream: endStream, - EndHeaders: endHeaders, - } - t.controlBuf.put(&headerFrame{p}) - first = false - } else { - t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: b.Next(size)}) - } - } - atomic.StoreUint32(&t.resetPingStrikes, 1) - return nil -} - // WriteHeader sends the header metedata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { select { @@ -722,13 +699,13 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } md = s.header s.mu.Unlock() - - hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory later. - hEnc := hpack.NewEncoder(hBuf) - hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) if s.sendCompress != "" { - hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } for k, vv := range md { if isReservedHeader(k) { @@ -736,16 +713,17 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { continue } for _, v := range vv { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } - bufLen := hBuf.Len() - if err := t.writeHeaders(s, hBuf, false); err != nil { - return err - } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) if t.stats != nil { outHeader := &stats.OutHeader{ - WireLength: bufLen, + //WireLength: // TODO(mmukhi): Revisit this later, if needed. } t.stats.HandleRPC(s.Context(), outHeader) } @@ -782,18 +760,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headersSent = true } - hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory. - hEnc := hpack.NewEncoder(hBuf) + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. if !headersSent { - hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) } - hEnc.WriteField( - hpack.HeaderField{ - Name: "grpc-status", - Value: strconv.Itoa(int(st.Code())), - }) - hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) if p := st.Proto(); p != nil && len(p.Details) > 0 { stBytes, err := proto.Marshal(p) @@ -802,7 +777,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { panic(err) } - hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) } // Attach the trailer metadata. @@ -812,19 +787,16 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { continue } for _, v := range vv { - hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } - bufLen := hBuf.Len() - if err := t.writeHeaders(s, hBuf, true); err != nil { - t.Close() - return err - } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: true, + }) if t.stats != nil { - outTrailer := &stats.OutTrailer{ - WireLength: bufLen, - } - t.stats.HandleRPC(s.Context(), outTrailer) + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) } t.closeStream(s) return nil @@ -904,7 +876,6 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( atomic.StoreUint32(&t.resetPingStrikes, 1) success := func() { t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() { - //fmt.Println("Adding quota back to localEendQuota", ps) s.localSendQuota.add(ps) }}) if ps < sq { @@ -1007,6 +978,9 @@ func (t *http2Server) keepalive() { var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} +// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) +// is duplicated between the client and the server. +// The transport layer needs to be refactored to take care of this. func (t *http2Server) itemHandler(i item) error { var err error defer func() { @@ -1022,9 +996,39 @@ func (t *http2Server) itemHandler(i item) error { i.f() } case *headerFrame: - err = t.framer.fr.WriteHeaders(i.p) - case *continuationFrame: - err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment) + t.hBuf.Reset() + for _, f := range i.hf { + t.hEnc.WriteField(f) + } + first := true + endHeaders := false + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + first = false + err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: i.streamID, + BlockFragment: t.hBuf.Next(size), + EndStream: i.endStream, + EndHeaders: endHeaders, + }) + } else { + err = t.framer.fr.WriteContinuation( + i.streamID, + endHeaders, + t.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + atomic.StoreUint32(&t.resetPingStrikes, 1) case *windowUpdate: err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: diff --git a/transport/transport_test.go b/transport/transport_test.go index 1d12b172..f30ebc6d 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -163,12 +163,13 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s * } func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { - hBuf := bytes.NewBuffer([]byte{}) - hEnc := hpack.NewEncoder(hBuf) - hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) - if err := h.t.writeHeaders(s, hBuf, false); err != nil { - t.Fatalf("Failed to write headers: %v", err) - } + headerFields := []hpack.HeaderField{} + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) + h.t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) } func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {