diff --git a/transport/control.go b/transport/control.go index e734ed30..8dd9a0ac 100644 --- a/transport/control.go +++ b/transport/control.go @@ -78,6 +78,13 @@ func (resetStream) isItem() bool { return true } +type flushIO struct { +} + +func (flushIO) isItem() bool { + return true +} + // quotaPool is a pool which accumulates the quota and sends it to acquire() // when it is available. type quotaPool struct { diff --git a/transport/http2_client.go b/transport/http2_client.go index dbd78a87..b4b6ef8b 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -69,7 +69,7 @@ type http2Client struct { // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} - framer *http2.Framer + framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding hEnc *hpack.Encoder // HPACK encoder @@ -132,8 +132,8 @@ func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err erro if n != len(clientPreface) { return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } - framer := http2.NewFramer(conn, conn) - if err := framer.WriteSettings(); err != nil { + framer := newFramer(conn) + if err := framer.writeSettings(true); err != nil { return nil, ConnectionErrorf("transport: %v", err) } var buf bytes.Buffer @@ -229,13 +229,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea for k, v := range authData { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) } + var ( + hasMD bool + endHeaders bool + ) if md, ok := metadata.FromContext(ctx); ok { + hasMD = true for k, v := range md { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) } } first := true - endHeaders := false streamID := t.nextID t.nextID += 2 // Sends the headers in a single batch even when they span multiple frames. @@ -254,11 +258,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea EndStream: false, EndHeaders: endHeaders, } - err = t.framer.WriteHeaders(p) + // Do a force flush for the buffered frames iff it is the last headers frame + // and there is header metadata to be sent. Otherwise, there is flushing until + // the corresponding data frame is written. + err = t.framer.writeHeaders(hasMD && endHeaders, p) first = false } else { // Sends Continuation frames for the leftover headers. - err = t.framer.WriteContinuation(t.nextID, endHeaders, t.hBuf.Next(size)) + err = t.framer.writeContinuation(hasMD && endHeaders, t.nextID, endHeaders, t.hBuf.Next(size)) } if err != nil { t.notifyError(err) @@ -380,21 +387,41 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { t.sendQuotaPool.add(tq - ps) } } - var endStream bool + var ( + endStream bool + forceFlush bool + ) if opts.Last && r.Len() == 0 { endStream = true } + // Indicate there is a writer who is about to write a data frame. + t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if t.framer.adjustNumWriters(-1) == 0 { + // This writer is the last one in this batch and has the + // responsibility to flush the buffered frames. It queues + // a flush request to controlBuf instead of flushing directly + // in order to avoid the race with other writing or flushing. + t.controlBuf.put(&flushIO{}) + } return err } + if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { + // Do a force flush iff this is last frame for the entire gRPC message + // and the caller is the only writer at this moment. + forceFlush = true + } // If WriteData fails, all the pending streams will be handled // by http2Client.Close(). No explicit CloseStream() needs to be // invoked. - if err := t.framer.WriteData(s.id, endStream, p); err != nil { + if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) return ConnectionErrorf("transport: %v", err) } + if t.framer.adjustNumWriters(-1) == 0 { + t.framer.flushWrite() + } t.writableChan <- 0 if r.Len() == 0 { break @@ -560,7 +587,7 @@ func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame header // TODO(zhaoq): Check the validity of the incoming frame sequence. func (t *http2Client) reader() { // Check the validity of server preface. - frame, err := t.framer.ReadFrame() + frame, err := t.framer.readFrame() if err != nil { t.notifyError(err) return @@ -576,7 +603,7 @@ func (t *http2Client) reader() { var curStream *Stream // loop to keep reading incoming messages on this transport. for { - frame, err := t.framer.ReadFrame() + frame, err := t.framer.readFrame() if err != nil { t.notifyError(err) return @@ -623,11 +650,13 @@ func (t *http2Client) controller() { case <-t.writableChan: switch i := i.(type) { case *windowUpdate: - t.framer.WriteWindowUpdate(i.streamID, i.increment) + t.framer.writeWindowUpdate(true, i.streamID, i.increment) case *settings: - t.framer.WriteSettings(http2.Setting{i.id, i.val}) + t.framer.writeSettings(true, http2.Setting{i.id, i.val}) case *resetStream: - t.framer.WriteRSTStream(i.streamID, i.code) + t.framer.writeRSTStream(true, i.streamID, i.code) + case *flushIO: + t.framer.flushWrite() default: log.Printf("transport: http2Client.controller got unexpected item type %v\n", i) } diff --git a/transport/http2_server.go b/transport/http2_server.go index d9f833d7..d9f5c811 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -66,7 +66,7 @@ type http2Server struct { // Blocking operations should select on shutdownChan to avoid // blocking forever after Close. shutdownChan chan struct{} - framer *http2.Framer + framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding hEnc *hpack.Encoder // HPACK encoder @@ -88,15 +88,15 @@ type http2Server struct { // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // returned if something goes wrong. func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err error) { - framer := http2.NewFramer(conn, conn) + framer := newFramer(conn) // Send initial settings as connection preface to client. // TODO(zhaoq): Have a better way to signal "no limit" because 0 is // permitted in the HTTP2 spec. if maxStreams == 0 { - err = framer.WriteSettings() + err = framer.writeSettings(true) maxStreams = math.MaxUint32 } else { - err = framer.WriteSettings(http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) + err = framer.writeSettings(true, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) } if err != nil { return @@ -203,7 +203,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { return } - frame, err := t.framer.ReadFrame() + frame, err := t.framer.readFrame() if err != nil { log.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) t.Close() @@ -222,7 +222,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { var wg sync.WaitGroup defer wg.Wait() for { - frame, err := t.framer.ReadFrame() + frame, err := t.framer.readFrame() if err != nil { t.Close() return @@ -380,10 +380,10 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e EndStream: endStream, EndHeaders: endHeaders, } - err = t.framer.WriteHeaders(p) + err = t.framer.writeHeaders(endHeaders, p) first = false } else { - err = t.framer.WriteContinuation(s.id, endHeaders, b.Next(size)) + err = t.framer.writeContinuation(endHeaders, s.id, endHeaders, b.Next(size)) } if err != nil { t.Close() @@ -475,7 +475,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { BlockFragment: t.hBuf.Bytes(), EndHeaders: true, } - if err := t.framer.WriteHeaders(p); err != nil { + if err := t.framer.writeHeaders(false, p); err != nil { t.Close() return ConnectionErrorf("transport: %v", err) } @@ -518,15 +518,30 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { // Overbooked transport quota. Return it back. t.sendQuotaPool.add(tq - ps) } + t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the // transport. if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if t.framer.adjustNumWriters(-1) == 0 { + // This writer is the last one in this batch and has the + // responsibility to flush the buffered frames. It queues + // a flush request to controlBuf instead of flushing directly + // in order to avoid the race with other writing or flushing. + t.controlBuf.put(&flushIO{}) + } return err } - if err := t.framer.WriteData(s.id, false, p); err != nil { + var forceFlush bool + if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { + forceFlush = true + } + if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() return ConnectionErrorf("transport: %v", err) } + if t.framer.adjustNumWriters(-1) == 0 { + t.framer.flushWrite() + } t.writableChan <- 0 } @@ -543,11 +558,13 @@ func (t *http2Server) controller() { case <-t.writableChan: switch i := i.(type) { case *windowUpdate: - t.framer.WriteWindowUpdate(i.streamID, i.increment) + t.framer.writeWindowUpdate(true, i.streamID, i.increment) case *settings: - t.framer.WriteSettings(http2.Setting{i.id, i.val}) + t.framer.writeSettings(true, http2.Setting{i.id, i.val}) case *resetStream: - t.framer.WriteRSTStream(i.streamID, i.code) + t.framer.writeRSTStream(true, i.streamID, i.code) + case *flushIO: + t.framer.flushWrite() default: log.Printf("transport: http2Server.controller got unexpected item type %v\n", i) } diff --git a/transport/http_util.go b/transport/http_util.go index 30fbd076..f97c16ae 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -34,9 +34,13 @@ package transport import ( + "bufio" "fmt" + "io" "log" + "net" "strconv" + "sync/atomic" "time" "github.com/bradfitz/http2" @@ -50,6 +54,8 @@ const ( http2MaxFrameLen = 16384 // 16KB frame // http://http2.github.io/http2-spec/#SettingValues http2InitHeaderTableSize = 4096 + // http2IOBufSize specifies the buffer size for sending frames. + http2IOBufSize = 32 * 1024 ) var ( @@ -288,3 +294,144 @@ func timeoutDecode(s string) (time.Duration, error) { } return d * time.Duration(t), nil } + +type framer struct { + numWriters int32 + reader io.Reader + writer *bufio.Writer + fr *http2.Framer +} + +func newFramer(conn net.Conn) *framer { + f := &framer{ + reader: conn, + writer: bufio.NewWriterSize(conn, http2IOBufSize), + } + f.fr = http2.NewFramer(f.writer, f.reader) + return f +} + +func (f *framer) adjustNumWriters(i int32) int32 { + return atomic.AddInt32(&f.numWriters, i) +} + +// The following writeXXX functions can only be called when the caller gets +// unblocked from writableChan channel (i.e., owns the privilege to write). + +func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error { + if err := f.fr.WriteData(streamID, endStream, data); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error { + if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error { + if err := f.fr.WriteHeaders(p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error { + if err := f.fr.WritePing(ack, data); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error { + if err := f.fr.WritePriority(streamID, p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error { + if err := f.fr.WritePushPromise(p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error { + if err := f.fr.WriteRSTStream(streamID, code); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error { + if err := f.fr.WriteSettings(settings...); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeSettingsAck(forceFlush bool) error { + if err := f.fr.WriteSettingsAck(); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error { + if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) flushWrite() error { + return f.writer.Flush() +} + +func (f *framer) readFrame() (http2.Frame, error) { + return f.fr.ReadFrame() +}