Refactored the inbound flow control and fixed a couple of issues.

The per-stream inFlow no longer keeps a back-pointer to the connection-level
inFlow; the client and server transports now account for the connection-level
and stream-level receive windows separately (onData/onRead/getPendingData).
Fixed along the way: NewStream honors an already-cancelled context via
ctx.Err(), Write returns its acquired send quota when the stream is cancelled
while waiting to write, and data arriving for unknown or closed streams is
credited back to the connection window.

commit 963ee99c99 (parent c1db6d8439)
Author: iamqizhao
Date:   2016-04-14 14:16:39 -07:00
4 changed files with 106 additions and 141 deletions

transport/control.go

@@ -162,10 +162,6 @@ func (qb *quotaPool) acquire() <-chan int {
 type inFlow struct {
 	// The inbound flow control limit for pending data.
 	limit uint32
-	// conn points to the shared connection-level inFlow that is shared
-	// by all streams on that conn. It is nil for the inFlow on the conn
-	// directly.
-	conn *inFlow
 
 	mu sync.Mutex
 	// pendingData is the overall data which have been received but not been
@@ -176,97 +172,39 @@ type inFlow struct {
 	pendingUpdate uint32
 }
 
-// onData is invoked when some data frame is received. It increments not only its
-// own pendingData but also that of the associated connection-level flow.
-func (f *inFlow) onData(n uint32) error {
-	if n == 0 {
-		return nil
-	}
-	f.mu.Lock()
-	defer f.mu.Unlock()
-	if f.pendingData+f.pendingUpdate+n > f.limit {
-		return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit)
-	}
-	if f.conn != nil {
-		if err := f.conn.onData(n); err != nil {
-			return ConnectionErrorf("%v", err)
-		}
-	}
-	f.pendingData += n
-	return nil
-}
-
-// adjustConnPendingUpdate increments the connection level pending updates by n.
-// This is called to make the proper connection level window updates when
-// receiving data frame targeting the canceled RPCs.
-func (f *inFlow) adjustConnPendingUpdate(n uint32) (uint32, error) {
-	if n == 0 || f.conn != nil {
-		return 0, nil
-	}
-	f.mu.Lock()
-	defer f.mu.Unlock()
-	if f.pendingData+f.pendingUpdate+n > f.limit {
-		return 0, ConnectionErrorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit)
-	}
-	f.pendingUpdate += n
-	if f.pendingUpdate >= f.limit/4 {
-		ret := f.pendingUpdate
-		f.pendingUpdate = 0
-		return ret, nil
-	}
-	return 0, nil
-}
-
-// connOnRead updates the connection level states when the application consumes data.
-func (f *inFlow) connOnRead(n uint32) uint32 {
-	if n == 0 || f.conn != nil {
-		return 0
-	}
-	f.mu.Lock()
-	defer f.mu.Unlock()
-	f.pendingData -= n
-	f.pendingUpdate += n
-	if f.pendingUpdate >= f.limit/4 {
-		ret := f.pendingUpdate
-		f.pendingUpdate = 0
-		return ret
-	}
-	return 0
-}
-
-// onRead is invoked when the application reads the data. It returns the window updates
-// for both stream and connection level.
-func (f *inFlow) onRead(n uint32) (swu, cwu uint32) {
-	if n == 0 {
-		return
-	}
-	f.mu.Lock()
-	defer f.mu.Unlock()
-	if f.pendingData == 0 {
-		// pendingData has been adjusted by restoreConn.
-		return
-	}
-	f.pendingData -= n
-	f.pendingUpdate += n
-	if f.pendingUpdate >= f.limit/4 {
-		swu = f.pendingUpdate
-		f.pendingUpdate = 0
-	}
-	cwu = f.conn.connOnRead(n)
-	return
-}
-
-// restoreConn is invoked when a stream is terminated. It removes its stake in
-// the connection-level flow and resets its own state.
-func (f *inFlow) restoreConn() uint32 {
-	if f.conn == nil {
-		return 0
-	}
-	f.mu.Lock()
-	defer f.mu.Unlock()
-	n := f.pendingData
-	f.pendingData = 0
-	f.pendingUpdate = 0
-	return f.conn.connOnRead(n)
-}
+// onData is invoked when some data frame is received. It updates pendingData.
+func (f *inFlow) onData(n uint32) error {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.pendingData += n
+	if f.pendingData+f.pendingUpdate > f.limit {
+		return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
+	}
+	return nil
+}
+
+// onRead is invoked when the application reads the data. It returns the window size
+// to be sent to the peer.
+func (f *inFlow) onRead(n uint32) uint32 {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	if f.pendingData == 0 {
+		return 0
+	}
+	f.pendingData -= n
+	f.pendingUpdate += n
+	if f.pendingUpdate >= f.limit/4 {
+		wu := f.pendingUpdate
+		f.pendingUpdate = 0
+		return wu
+	}
+	return 0
+}
+
+func (f *inFlow) getPendingData() uint32 {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	n := f.pendingData
+	f.pendingData = 0
+	return n
+}
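
The refactor shrinks inFlow to three operations: onData charges received bytes
against the window, onRead credits bytes the application has consumed and
batches window updates (an update is only emitted once at least a quarter of
the window has accumulated, so the transport is not sending a WINDOW_UPDATE
per read), and getPendingData drains a terminated stream's unread bytes. A
minimal sketch of that life cycle, using the inFlow type exactly as it stands
after this diff (64KB stands in for initialWindowSize; the comments mark what
the caller, not inFlow, would do):

	fc := &inFlow{limit: 64 * 1024}

	// A 16KB DATA frame arrives and is charged against the window.
	if err := fc.onData(16 * 1024); err != nil {
		// The peer overran the advertised window: reset the stream,
		// or tear down the connection for the conn-level inFlow.
	}

	// The application consumes the 16KB. That equals limit/4, so the
	// batching threshold is met and onRead returns a 16KB update.
	if w := fc.onRead(16 * 1024); w > 0 {
		// queue windowUpdate{streamID, w} for the peer
	}

	// Had the stream died before the read, the unread bytes would be
	// drained and credited back to the connection-level window instead:
	if q := fc.getPendingData(); q > 0 {
		// connFC.onRead(q), then queue windowUpdate{0, ...}
	}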

transport/http2_client.go

@@ -202,17 +202,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
 }
 
 func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
-	fc := &inFlow{
-		limit: initialWindowSize,
-		conn:  t.fc,
-	}
 	// TODO(zhaoq): Handle uint32 overflow of Stream.id.
 	s := &Stream{
 		id:            t.nextID,
 		method:        callHdr.Method,
 		sendCompress:  callHdr.SendCompress,
 		buf:           newRecvBuffer(),
-		fc:            fc,
+		fc:            &inFlow{limit: initialWindowSize},
 		sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
 		headerChan:    make(chan struct{}),
 	}
@@ -236,9 +232,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
 	var timeout time.Duration
 	if dl, ok := ctx.Deadline(); ok {
 		timeout = dl.Sub(time.Now())
-		if timeout <= 0 {
-			return nil, ContextErr(context.DeadlineExceeded)
-		}
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, ContextErr(err)
 	}
 	pr := &peer.Peer{
 		Addr: t.conn.RemoteAddr(),
@@ -404,8 +400,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 	// other goroutines.
 	s.cancel()
 	s.mu.Lock()
-	if q := s.fc.restoreConn(); q > 0 {
-		t.controlBuf.put(&windowUpdate{0, q})
+	if q := s.fc.getPendingData(); q > 0 {
+		if n := t.fc.onRead(q); n > 0 {
+			t.controlBuf.put(&windowUpdate{0, n})
+		}
 	}
 	if s.state == streamDone {
 		s.mu.Unlock()
@@ -505,6 +503,10 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 		t.framer.adjustNumWriters(1)
 		// Got some quota. Try to acquire writing privilege on the transport.
 		if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+			if _, ok := err.(StreamError); ok {
+				// Return the connection quota back.
+				t.sendQuotaPool.add(len(p))
+			}
 			if t.framer.adjustNumWriters(-1) == 0 {
 				// This writer is the last one in this batch and has the
 				// responsibility to flush the buffered frames. It queues
@@ -514,6 +516,14 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 			}
 			return err
 		}
+		if s.ctx.Err() != nil {
+			t.sendQuotaPool.add(len(p))
+			if t.framer.adjustNumWriters(-1) == 0 {
+				t.controlBuf.put(&flushIO{})
+			}
+			t.writableChan <- 0
+			return ContextErr(s.ctx.Err())
+		}
 		if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 {
 			// Do a force flush iff this is last frame for the entire gRPC message
 			// and the caller is the only writer at this moment.
@@ -560,41 +570,39 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
 // Window updates will deliver to the controller for sending when
 // the cumulative quota exceeds the corresponding threshold.
 func (t *http2Client) updateWindow(s *Stream, n uint32) {
-	swu, cwu := s.fc.onRead(n)
-	if swu > 0 {
-		t.controlBuf.put(&windowUpdate{s.id, swu})
+	if w := t.fc.onRead(n); w > 0 {
+		t.controlBuf.put(&windowUpdate{0, w})
 	}
-	if cwu > 0 {
-		t.controlBuf.put(&windowUpdate{0, cwu})
+	if w := s.fc.onRead(n); w > 0 {
+		t.controlBuf.put(&windowUpdate{s.id, w})
 	}
 }
 
 func (t *http2Client) handleData(f *http2.DataFrame) {
 	// Select the right stream to dispatch.
 	size := len(f.Data())
+	if err := t.fc.onData(uint32(size)); err != nil {
+		t.notifyError(ConnectionErrorf("%v", err))
+		return
+	}
 	s, ok := t.getStream(f)
 	if !ok {
-		cwu, err := t.fc.adjustConnPendingUpdate(uint32(size))
-		if err != nil {
-			t.notifyError(err)
-			return
-		}
-		if cwu > 0 {
-			t.controlBuf.put(&windowUpdate{0, cwu})
+		if w := t.fc.onRead(uint32(size)); w > 0 {
+			t.controlBuf.put(&windowUpdate{0, w})
 		}
 		return
 	}
 	if size > 0 {
+		s.mu.Lock()
+		if s.state == streamDone {
+			s.mu.Unlock()
+			// The stream has been closed. Release the corresponding quota.
+			if w := t.fc.onRead(uint32(size)); w > 0 {
+				t.controlBuf.put(&windowUpdate{0, w})
+			}
+			return
+		}
 		if err := s.fc.onData(uint32(size)); err != nil {
-			if _, ok := err.(ConnectionError); ok {
-				t.notifyError(err)
-				return
-			}
-			s.mu.Lock()
-			if s.state == streamDone {
-				s.mu.Unlock()
-				return
-			}
 			s.state = streamDone
 			s.statusCode = codes.Internal
 			s.statusDesc = err.Error()
@@ -603,6 +611,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
 			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
 			return
 		}
+		s.mu.Unlock()
 		// TODO(bradfitz, zhaoq): A copy is required here because there is no
 		// guarantee f.Data() is consumed before the arrival of next frame.
 		// Can this copy be eliminated?
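
One fix in this file is easy to miss in the Write path: send quota acquired
from sendQuotaPool would leak if the stream were cancelled while the goroutine
waits for the writable channel, which is why both new branches call
t.sendQuotaPool.add(...) before bailing out. A self-contained sketch of that
pattern under toy names (pool, send, writable are illustrative stand-ins, not
the transport's quotaPool/wait API):

	package main

	import (
		"context"
		"fmt"
		"time"
	)

	// pool keeps the available send quota in a one-slot channel, in the
	// spirit of the transport's channel-based quota pool.
	type pool struct{ c chan int }

	func newPool(n int) *pool {
		p := &pool{c: make(chan int, 1)}
		p.c <- n
		return p
	}

	// add merges n units back into the pool.
	func (p *pool) add(n int) {
		select {
		case cur := <-p.c:
			p.c <- cur + n
		default:
			p.c <- n
		}
	}

	// send takes quota for an n-byte frame, then waits to become
	// writable; on cancellation it must hand the quota back.
	func send(ctx context.Context, p *pool, n int, writable chan struct{}) error {
		got := <-p.c
		if got > n {
			p.add(got - n) // keep only what this frame needs
		}
		select {
		case <-writable:
		case <-ctx.Done():
			p.add(n) // the fix: return the quota on cancellation
			return ctx.Err()
		}
		// ... write the frame, then release the writable token ...
		writable <- struct{}{}
		return nil
	}

	func main() {
		p := newPool(1024)
		writable := make(chan struct{}) // never signaled: a stalled writer
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
		defer cancel()
		fmt.Println(send(ctx, p, 512, writable)) // context deadline exceeded
		fmt.Println(<-p.c)                       // 1024: nothing leaked
	}

Without the add in the cancellation branch, the pool would be left at 512 and
every cancelled RPC would permanently shrink the connection's send quota.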

transport/http2_server.go

@@ -139,15 +139,11 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) {
 // operateHeader takes action on the decoded headers.
 func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
 	buf := newRecvBuffer()
-	fc := &inFlow{
-		limit: initialWindowSize,
-		conn:  t.fc,
-	}
 	s := &Stream{
 		id:  frame.Header().StreamID,
 		st:  t,
 		buf: buf,
-		fc:  fc,
+		fc:  &inFlow{limit: initialWindowSize},
 	}
 
 	var state decodeState
@@ -307,42 +303,46 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
 // Window updates will deliver to the controller for sending when
 // the cumulative quota exceeds the corresponding threshold.
 func (t *http2Server) updateWindow(s *Stream, n uint32) {
-	swu, cwu := s.fc.onRead(n)
-	if swu > 0 {
-		t.controlBuf.put(&windowUpdate{s.id, swu})
+	if w := t.fc.onRead(n); w > 0 {
+		t.controlBuf.put(&windowUpdate{0, w})
 	}
-	if cwu > 0 {
-		t.controlBuf.put(&windowUpdate{0, cwu})
+	if w := s.fc.onRead(n); w > 0 {
+		t.controlBuf.put(&windowUpdate{s.id, w})
 	}
 }
 
 func (t *http2Server) handleData(f *http2.DataFrame) {
 	// Select the right stream to dispatch.
 	size := len(f.Data())
+	if err := t.fc.onData(uint32(size)); err != nil {
+		grpclog.Printf("transport: http2Server %v", err)
+		t.Close()
+		return
+	}
 	s, ok := t.getStream(f)
 	if !ok {
-		cwu, err := t.fc.adjustConnPendingUpdate(uint32(size))
-		if err != nil {
-			grpclog.Printf("transport: http2Server %v", err)
-			t.Close()
-			return
-		}
-		if cwu > 0 {
-			t.controlBuf.put(&windowUpdate{0, cwu})
+		if w := t.fc.onRead(uint32(size)); w > 0 {
+			t.controlBuf.put(&windowUpdate{0, w})
 		}
 		return
 	}
 	if size > 0 {
+		s.mu.Lock()
+		if s.state == streamDone {
+			s.mu.Unlock()
+			// The stream has been closed. Release the corresponding quota.
+			if w := t.fc.onRead(uint32(size)); w > 0 {
+				t.controlBuf.put(&windowUpdate{0, w})
+			}
+			return
+		}
 		if err := s.fc.onData(uint32(size)); err != nil {
-			if _, ok := err.(ConnectionError); ok {
-				grpclog.Printf("transport: http2Server %v", err)
-				t.Close()
-				return
-			}
+			s.mu.Unlock()
 			t.closeStream(s)
 			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
 			return
 		}
+		s.mu.Unlock()
 		// TODO(bradfitz, zhaoq): A copy is required here because there is no
 		// guarantee f.Data() is consumed before the arrival of next frame.
 		// Can this copy be eliminated?
@@ -516,6 +516,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 	// TODO(zhaoq): Support multi-writers for a single stream.
 	var writeHeaderFrame bool
 	s.mu.Lock()
+	if s.state == streamDone {
+		s.mu.Unlock()
+		return StreamErrorf(codes.Unknown, "the stream has been done")
+	}
 	if !s.headerOk {
 		writeHeaderFrame = true
 		s.headerOk = true
@@ -583,6 +587,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 		// Got some quota. Try to acquire writing privilege on the
 		// transport.
 		if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+			if _, ok := err.(StreamError); ok {
+				// Return the connection quota back.
+				t.sendQuotaPool.add(ps)
+			}
 			if t.framer.adjustNumWriters(-1) == 0 {
 				// This writer is the last one in this batch and has the
 				// responsibility to flush the buffered frames. It queues
@@ -592,6 +600,14 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 			}
 			return err
 		}
+		if s.ctx.Err() != nil {
+			t.sendQuotaPool.add(ps)
+			if t.framer.adjustNumWriters(-1) == 0 {
+				t.controlBuf.put(&flushIO{})
+			}
+			t.writableChan <- 0
+			return ContextErr(s.ctx.Err())
+		}
 		var forceFlush bool
 		if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last {
 			forceFlush = true
@@ -689,20 +705,22 @@ func (t *http2Server) closeStream(s *Stream) {
 	t.mu.Lock()
 	delete(t.activeStreams, s.id)
 	t.mu.Unlock()
-	if q := s.fc.restoreConn(); q > 0 {
-		t.controlBuf.put(&windowUpdate{0, q})
-	}
+	// In case stream sending and receiving are invoked in separate
+	// goroutines (e.g., bi-directional streaming), cancel needs to be
+	// called to interrupt the potential blocking on other goroutines.
+	s.cancel()
 	s.mu.Lock()
+	if q := s.fc.getPendingData(); q > 0 {
+		if w := t.fc.onRead(q); w > 0 {
+			t.controlBuf.put(&windowUpdate{0, w})
+		}
+	}
 	if s.state == streamDone {
 		s.mu.Unlock()
 		return
 	}
 	s.state = streamDone
 	s.mu.Unlock()
-	// In case stream sending and receiving are invoked in separate
-	// goroutines (e.g., bi-directional streaming), cancel needs to be
-	// called to interrupt the potential blocking on other goroutines.
-	s.cancel()
 }
 
 func (t *http2Server) RemoteAddr() net.Addr {
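
Taken together, the client and server changes make the two flow-control levels
explicit instead of chaining them through the removed conn pointer. A
condensed, illustrative sketch of the accounting both transports now perform
(connFC and streamFC are hypothetical stand-ins for t.fc and s.fc; size and n
come from the frame and the application read):

	// handleData: charge the connection for every DATA frame first. A
	// frame that cannot be delivered (unknown or already-closed stream)
	// is immediately "read" back so the connection window cannot leak.
	if err := connFC.onData(size); err != nil { /* connection error */ }
	if err := streamFC.onData(size); err != nil { /* reset the stream */ }

	// updateWindow: when the application consumes n bytes, credit both
	// levels; each returns a batched update once its threshold is hit.
	if w := connFC.onRead(n); w > 0 { /* queue windowUpdate{0, w} */ }
	if w := streamFC.onRead(n); w > 0 { /* queue windowUpdate{s.id, w} */ }

	// closeStream / CloseStream: drain what the stream received but the
	// application never read, and credit it back to the connection.
	if q := streamFC.getPendingData(); q > 0 {
		if w := connFC.onRead(q); w > 0 { /* queue windowUpdate{0, w} */ }
	}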

transport/transport_test.go

@@ -584,8 +584,8 @@ func TestServerWithMisbehavedClient(t *testing.T) {
 		t.Fatalf("%v got err %v with statusCode %d, want err <EOF> with statusCode %d", s, err, s.statusCode, code)
 	}
-	if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != initialWindowSize {
-		t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, initialWindowSize)
+	if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate <= initialWindowSize {
+		t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, initialWindowSize)
 	}
 	ct.CloseStream(s, nil)
 	// Test server behavior for violation of connection flow control window size restriction.
@@ -631,15 +631,15 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 			break
 		}
 	}
-	if s.fc.pendingData != initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != initialWindowSize || conn.fc.pendingUpdate != 0 {
-		t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0)
+	if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData <= initialWindowSize || conn.fc.pendingUpdate != 0 {
+		t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, >%d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0)
 	}
 	if err != io.EOF || s.statusCode != codes.Internal {
 		t.Fatalf("Got err %v and the status code %d, want <EOF> and the code %d", err, s.statusCode, codes.Internal)
 	}
 	conn.CloseStream(s, err)
-	if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != initialWindowSize {
-		t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize)
+	if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate <= initialWindowSize {
+		t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize)
 	}
 	// Test the logic for the violation of the connection flow control window size restriction.
 	//