diff --git a/transport/control.go b/transport/control.go index b2e602e6..f6da0ce3 100644 --- a/transport/control.go +++ b/transport/control.go @@ -162,10 +162,6 @@ func (qb *quotaPool) acquire() <-chan int { type inFlow struct { // The inbound flow control limit for pending data. limit uint32 - // conn points to the shared connection-level inFlow that is shared - // by all streams on that conn. It is nil for the inFlow on the conn - // directly. - conn *inFlow mu sync.Mutex // pendingData is the overall data which have been received but not been @@ -176,97 +172,39 @@ type inFlow struct { pendingUpdate uint32 } -// onData is invoked when some data frame is received. It increments not only its -// own pendingData but also that of the associated connection-level flow. +// onData is invoked when some data frame is received. It updates pendingData. func (f *inFlow) onData(n uint32) error { - if n == 0 { - return nil - } f.mu.Lock() defer f.mu.Unlock() - if f.pendingData+f.pendingUpdate+n > f.limit { - return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit) - } - if f.conn != nil { - if err := f.conn.onData(n); err != nil { - return ConnectionErrorf("%v", err) - } - } f.pendingData += n + if f.pendingData+f.pendingUpdate > f.limit { + return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit) + } return nil } -// adjustConnPendingUpdate increments the connection level pending updates by n. -// This is called to make the proper connection level window updates when -// receiving data frame targeting the canceled RPCs. -func (f *inFlow) adjustConnPendingUpdate(n uint32) (uint32, error) { - if n == 0 || f.conn != nil { - return 0, nil - } +// onRead is invoked when the application reads the data. It returns the window size +// to be sent to the peer. +func (f *inFlow) onRead(n uint32) uint32 { f.mu.Lock() defer f.mu.Unlock() - if f.pendingData+f.pendingUpdate+n > f.limit { - return 0, ConnectionErrorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit) - } - f.pendingUpdate += n - if f.pendingUpdate >= f.limit/4 { - ret := f.pendingUpdate - f.pendingUpdate = 0 - return ret, nil - } - return 0, nil - -} - -// connOnRead updates the connection level states when the application consumes data. -func (f *inFlow) connOnRead(n uint32) uint32 { - if n == 0 || f.conn != nil { + if f.pendingData == 0 { return 0 } - f.mu.Lock() - defer f.mu.Unlock() f.pendingData -= n f.pendingUpdate += n if f.pendingUpdate >= f.limit/4 { - ret := f.pendingUpdate + wu := f.pendingUpdate f.pendingUpdate = 0 - return ret + return wu } return 0 } -// onRead is invoked when the application reads the data. It returns the window updates -// for both stream and connection level. -func (f *inFlow) onRead(n uint32) (swu, cwu uint32) { - if n == 0 { - return - } - f.mu.Lock() - defer f.mu.Unlock() - if f.pendingData == 0 { - // pendingData has been adjusted by restoreConn. - return - } - f.pendingData -= n - f.pendingUpdate += n - if f.pendingUpdate >= f.limit/4 { - swu = f.pendingUpdate - f.pendingUpdate = 0 - } - cwu = f.conn.connOnRead(n) - return -} - -// restoreConn is invoked when a stream is terminated. It removes its stake in -// the connection-level flow and resets its own state. -func (f *inFlow) restoreConn() uint32 { - if f.conn == nil { - return 0 - } +func (f *inFlow) getPendingData() uint32 { f.mu.Lock() defer f.mu.Unlock() n := f.pendingData f.pendingData = 0 - f.pendingUpdate = 0 - return f.conn.connOnRead(n) + return n } diff --git a/transport/http2_client.go b/transport/http2_client.go index 76101d7a..69288181 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -202,17 +202,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e } func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { - fc := &inFlow{ - limit: initialWindowSize, - conn: t.fc, - } // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ id: t.nextID, method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), - fc: fc, + fc: &inFlow{limit: initialWindowSize}, sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), headerChan: make(chan struct{}), } @@ -236,9 +232,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea var timeout time.Duration if dl, ok := ctx.Deadline(); ok { timeout = dl.Sub(time.Now()) - if timeout <= 0 { - return nil, ContextErr(context.DeadlineExceeded) - } + } + if err := ctx.Err(); err != nil { + return nil, ContextErr(err) } pr := &peer.Peer{ Addr: t.conn.RemoteAddr(), @@ -404,8 +400,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // other goroutines. s.cancel() s.mu.Lock() - if q := s.fc.restoreConn(); q > 0 { - t.controlBuf.put(&windowUpdate{0, q}) + if q := s.fc.getPendingData(); q > 0 { + if n := t.fc.onRead(q); n > 0 { + t.controlBuf.put(&windowUpdate{0, n}) + } } if s.state == streamDone { s.mu.Unlock() @@ -505,6 +503,10 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, ok := err.(StreamError); ok { + // Return the connection quota back. + t.sendQuotaPool.add(len(p)) + } if t.framer.adjustNumWriters(-1) == 0 { // This writer is the last one in this batch and has the // responsibility to flush the buffered frames. It queues @@ -514,6 +516,14 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { } return err } + if s.ctx.Err() != nil { + t.sendQuotaPool.add(len(p)) + if t.framer.adjustNumWriters(-1) == 0 { + t.controlBuf.put(&flushIO{}) + } + t.writableChan <- 0 + return ContextErr(s.ctx.Err()) + } if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { // Do a force flush iff this is last frame for the entire gRPC message // and the caller is the only writer at this moment. @@ -560,41 +570,39 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { // Window updates will deliver to the controller for sending when // the cumulative quota exceeds the corresponding threshold. func (t *http2Client) updateWindow(s *Stream, n uint32) { - swu, cwu := s.fc.onRead(n) - if swu > 0 { - t.controlBuf.put(&windowUpdate{s.id, swu}) + if w := t.fc.onRead(n); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) } - if cwu > 0 { - t.controlBuf.put(&windowUpdate{0, cwu}) + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&windowUpdate{s.id, w}) } } func (t *http2Client) handleData(f *http2.DataFrame) { // Select the right stream to dispatch. size := len(f.Data()) + if err := t.fc.onData(uint32(size)); err != nil { + t.notifyError(ConnectionErrorf("%v", err)) + return + } s, ok := t.getStream(f) if !ok { - cwu, err := t.fc.adjustConnPendingUpdate(uint32(size)) - if err != nil { - t.notifyError(err) - return - } - if cwu > 0 { - t.controlBuf.put(&windowUpdate{0, cwu}) + if w := t.fc.onRead(uint32(size)); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) } return } if size > 0 { + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + // The stream has been closed. Release the corresponding quota. + if w := t.fc.onRead(uint32(size)); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) + } + return + } if err := s.fc.onData(uint32(size)); err != nil { - if _, ok := err.(ConnectionError); ok { - t.notifyError(err) - return - } - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return - } s.state = streamDone s.statusCode = codes.Internal s.statusDesc = err.Error() @@ -603,6 +611,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) return } + s.mu.Unlock() // TODO(bradfitz, zhaoq): A copy is required here because there is no // guarantee f.Data() is consumed before the arrival of next frame. // Can this copy be eliminated? diff --git a/transport/http2_server.go b/transport/http2_server.go index 68f82033..6e0dce6d 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -139,15 +139,11 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { buf := newRecvBuffer() - fc := &inFlow{ - limit: initialWindowSize, - conn: t.fc, - } s := &Stream{ id: frame.Header().StreamID, st: t, buf: buf, - fc: fc, + fc: &inFlow{limit: initialWindowSize}, } var state decodeState @@ -307,42 +303,46 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { // Window updates will deliver to the controller for sending when // the cumulative quota exceeds the corresponding threshold. func (t *http2Server) updateWindow(s *Stream, n uint32) { - swu, cwu := s.fc.onRead(n) - if swu > 0 { - t.controlBuf.put(&windowUpdate{s.id, swu}) + if w := t.fc.onRead(n); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) } - if cwu > 0 { - t.controlBuf.put(&windowUpdate{0, cwu}) + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&windowUpdate{s.id, w}) } } func (t *http2Server) handleData(f *http2.DataFrame) { // Select the right stream to dispatch. size := len(f.Data()) + if err := t.fc.onData(uint32(size)); err != nil { + grpclog.Printf("transport: http2Server %v", err) + t.Close() + return + } s, ok := t.getStream(f) if !ok { - cwu, err := t.fc.adjustConnPendingUpdate(uint32(size)) - if err != nil { - grpclog.Printf("transport: http2Server %v", err) - t.Close() - return - } - if cwu > 0 { - t.controlBuf.put(&windowUpdate{0, cwu}) + if w := t.fc.onRead(uint32(size)); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) } return } if size > 0 { - if err := s.fc.onData(uint32(size)); err != nil { - if _, ok := err.(ConnectionError); ok { - grpclog.Printf("transport: http2Server %v", err) - t.Close() - return + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + // The stream has been closed. Release the corresponding quota. + if w := t.fc.onRead(uint32(size)); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) } + return + } + if err := s.fc.onData(uint32(size)); err != nil { + s.mu.Unlock() t.closeStream(s) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) return } + s.mu.Unlock() // TODO(bradfitz, zhaoq): A copy is required here because there is no // guarantee f.Data() is consumed before the arrival of next frame. // Can this copy be eliminated? @@ -516,6 +516,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { // TODO(zhaoq): Support multi-writers for a single stream. var writeHeaderFrame bool s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + return StreamErrorf(codes.Unknown, "the stream has been done") + } if !s.headerOk { writeHeaderFrame = true s.headerOk = true @@ -583,6 +587,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { // Got some quota. Try to acquire writing privilege on the // transport. if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, ok := err.(StreamError); ok { + // Return the connection quota back. + t.sendQuotaPool.add(ps) + } if t.framer.adjustNumWriters(-1) == 0 { // This writer is the last one in this batch and has the // responsibility to flush the buffered frames. It queues @@ -592,6 +600,14 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } return err } + if s.ctx.Err() != nil { + t.sendQuotaPool.add(ps) + if t.framer.adjustNumWriters(-1) == 0 { + t.controlBuf.put(&flushIO{}) + } + t.writableChan <- 0 + return ContextErr(s.ctx.Err()) + } var forceFlush bool if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { forceFlush = true @@ -689,20 +705,22 @@ func (t *http2Server) closeStream(s *Stream) { t.mu.Lock() delete(t.activeStreams, s.id) t.mu.Unlock() - if q := s.fc.restoreConn(); q > 0 { - t.controlBuf.put(&windowUpdate{0, q}) - } + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel() s.mu.Lock() + if q := s.fc.getPendingData(); q > 0 { + if w := t.fc.onRead(q); w > 0 { + t.controlBuf.put(&windowUpdate{0, w}) + } + } if s.state == streamDone { s.mu.Unlock() return } s.state = streamDone s.mu.Unlock() - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), cancel needs to be - // called to interrupt the potential blocking on other goroutines. - s.cancel() } func (t *http2Server) RemoteAddr() net.Addr { diff --git a/transport/transport_test.go b/transport/transport_test.go index c9a95328..d63dba31 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -584,8 +584,8 @@ func TestServerWithMisbehavedClient(t *testing.T) { t.Fatalf("%v got err %v with statusCode %d, want err with statusCode %d", s, err, s.statusCode, code) } - if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != initialWindowSize { - t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, initialWindowSize) + if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate <= initialWindowSize { + t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, initialWindowSize) } ct.CloseStream(s, nil) // Test server behavior for violation of connection flow control window size restriction. @@ -631,15 +631,15 @@ func TestClientWithMisbehavedServer(t *testing.T) { break } } - if s.fc.pendingData != initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != initialWindowSize || conn.fc.pendingUpdate != 0 { - t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0) + if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData <= initialWindowSize || conn.fc.pendingUpdate != 0 { + t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, >%d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0) } if err != io.EOF || s.statusCode != codes.Internal { t.Fatalf("Got err %v and the status code %d, want and the code %d", err, s.statusCode, codes.Internal) } conn.CloseStream(s, err) - if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != initialWindowSize { - t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize) + if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate <= initialWindowSize { + t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize) } // Test the logic for the violation of the connection flow control window size restriction. //