transport: refactor of error/cancellation paths (#1533)

- The transport is now responsible for closing its own connection when an error
  occurs or when the context given to it in NewClientTransport() is canceled.

- Remove client/server shutdown channels -- add cancel function to allow
  self-cancellation.

- Plumb the clientConn's context into the client transport to allow for the
  transport to be canceled even after it has been removed from the ac (due to
  graceful close) when the ClientConn is closed.
This commit is contained in:
dfawley
2017-10-02 11:56:31 -07:00
committed by GitHub
parent 4bbdf230d7
commit 5a82377e69
6 changed files with 135 additions and 208 deletions

View File

@ -678,7 +678,7 @@ func (ac *addrConn) connect(block bool) error {
ac.mu.Unlock() ac.mu.Unlock()
if block { if block {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(); err != nil {
if err != errConnClosing { if err != errConnClosing {
ac.tearDown(err) ac.tearDown(err)
} }
@ -692,7 +692,7 @@ func (ac *addrConn) connect(block bool) error {
} else { } else {
// Start a goroutine connecting to the server asynchronously. // Start a goroutine connecting to the server asynchronously.
go func() { go func() {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(); err != nil {
grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err) grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err)
if err != errConnClosing { if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down. // Keep this ac in cc.conns, to get the reason it's torn down.
@ -867,12 +867,10 @@ func (ac *addrConn) errorf(format string, a ...interface{}) {
} }
} }
// resetTransport recreates a transport to the address for ac. // resetTransport recreates a transport to the address for ac. The old
// For the old transport: // transport will close itself on error or when the clientconn is closed.
// - if drain is true, it will be gracefully closed.
// - otherwise, it will be closed.
// TODO(bar) make sure all state transitions are valid. // TODO(bar) make sure all state transitions are valid.
func (ac *addrConn) resetTransport(drain bool) error { func (ac *addrConn) resetTransport() error {
ac.mu.Lock() ac.mu.Lock()
if ac.state == connectivity.Shutdown { if ac.state == connectivity.Shutdown {
ac.mu.Unlock() ac.mu.Unlock()
@ -888,13 +886,9 @@ func (ac *addrConn) resetTransport(drain bool) error {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
t := ac.transport
ac.transport = nil ac.transport = nil
ac.curAddr = resolver.Address{} ac.curAddr = resolver.Address{}
ac.mu.Unlock() ac.mu.Unlock()
if t != nil && !drain {
t.Close()
}
ac.cc.mu.RLock() ac.cc.mu.RLock()
ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.dopts.copts.KeepaliveParams = ac.cc.mkp
ac.cc.mu.RUnlock() ac.cc.mu.RUnlock()
@ -931,17 +925,12 @@ func (ac *addrConn) resetTransport(drain bool) error {
return errConnClosing return errConnClosing
} }
ac.mu.Unlock() ac.mu.Unlock()
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
sinfo := transport.TargetInfo{ sinfo := transport.TargetInfo{
Addr: addr.Addr, Addr: addr.Addr,
Metadata: addr.Metadata, Metadata: addr.Metadata,
} }
newTransport, err := transport.NewClientTransport(ctx, sinfo, copts) newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout)
// Don't call cancel in success path due to a race in Go 1.6:
// https://github.com/golang/go/issues/15078.
if err != nil { if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err return err
} }
@ -1012,48 +1001,19 @@ func (ac *addrConn) transportMonitor() {
ac.mu.Lock() ac.mu.Lock()
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
// Block until we receive a goaway or an error occurs.
select { select {
// This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight).
case <-ac.ctx.Done():
select {
case <-t.Error():
t.Close()
default:
}
return
case <-t.GoAway(): case <-t.GoAway():
ac.adjustParams(t.GetGoAwayReason())
// If GoAway happens without any network I/O error, the underlying transport
// will be gracefully closed, and a new transport will be created.
// (The transport will be closed when all the pending RPCs finished or failed.)
// If GoAway and some network I/O error happen concurrently, the underlying transport
// will be closed, and a new transport will be created.
var drain bool
select {
case <-t.Error(): case <-t.Error():
default:
drain = true
} }
if err := ac.resetTransport(drain); err != nil { // If a GoAway happened, regardless of error, adjust our keepalive
grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) // parameters as appropriate.
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
case <-t.Error():
select { select {
case <-ac.ctx.Done():
t.Close()
return
case <-t.GoAway(): case <-t.GoAway():
ac.adjustParams(t.GetGoAwayReason()) ac.adjustParams(t.GetGoAwayReason())
default: default:
} }
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(); err != nil {
grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err)
ac.mu.Lock() ac.mu.Lock()
ac.printf("transport exiting: %v", err) ac.printf("transport exiting: %v", err)
ac.mu.Unlock() ac.mu.Unlock()
@ -1066,7 +1026,6 @@ func (ac *addrConn) transportMonitor() {
} }
} }
} }
}
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
// iv) transport is in connectivity.TransientFailure and there is a balancer/failfast is true. // iv) transport is in connectivity.TransientFailure and there is a balancer/failfast is true.
@ -1137,7 +1096,6 @@ func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) {
// tearDown doesn't remove ac from ac.cc.conns. // tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) { func (ac *addrConn) tearDown(err error) {
ac.cancel() ac.cancel()
ac.mu.Lock() ac.mu.Lock()
ac.curAddr = resolver.Address{} ac.curAddr = resolver.Address{}
defer ac.mu.Unlock() defer ac.mu.Unlock()
@ -1166,9 +1124,6 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil && err != errConnDrain {
ac.transport.Close()
}
return return
} }

View File

@ -43,6 +43,7 @@ import (
// http2Client implements the ClientTransport interface with HTTP2. // http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct { type http2Client struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
target string // server name/addr target string // server name/addr
userAgent string userAgent string
md interface{} md interface{}
@ -52,13 +53,6 @@ type http2Client struct {
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
nextID uint32 // the next stream ID to be used nextID uint32 // the next stream ID to be used
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
// blocking forever after Close.
// TODO(zhaoq): Maybe have a channel context?
shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport. // that the server sent GoAway on this transport.
goAway chan struct{} goAway chan struct{}
@ -149,9 +143,20 @@ func isTemporary(err error) bool {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) {
scheme := "http" scheme := "http"
conn, err := dial(ctx, opts.Dialer, addr.Addr) ctx, cancel := context.WithCancel(ctx)
connectCtx, connectCancel := context.WithTimeout(ctx, timeout)
defer func() {
if err != nil {
cancel()
// Don't call connectCancel in success path due to a race in Go 1.6:
// https://github.com/golang/go/issues/15078.
connectCancel()
}
}()
conn, err := dial(connectCtx, opts.Dialer, addr.Addr)
if err != nil { if err != nil {
if opts.FailOnNonTempDialError { if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
@ -170,7 +175,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
) )
if creds := opts.TransportCredentials; creds != nil { if creds := opts.TransportCredentials; creds != nil {
scheme = "https" scheme = "https"
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn)
if err != nil { if err != nil {
// Credentials handshake errors are typically considered permanent // Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates. // to avoid retrying on e.g. bad certificates.
@ -204,6 +209,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
} }
t := &http2Client{ t := &http2Client{
ctx: ctx, ctx: ctx,
cancel: cancel,
target: addr.Addr, target: addr.Addr,
userAgent: opts.UserAgent, userAgent: opts.UserAgent,
md: addr.Metadata, md: addr.Metadata,
@ -213,8 +219,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
authInfo: authInfo, authInfo: authInfo,
// The client initiated stream id is odd starting from 1. // The client initiated stream id is odd starting from 1.
nextID: 1, nextID: 1,
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}),
goAway: make(chan struct{}), goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1), awakenKeepalive: make(chan struct{}, 1),
hBuf: &buf, hBuf: &buf,
@ -292,7 +296,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
} }
} }
t.framer.writer.Flush() t.framer.writer.Flush()
go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
}()
if t.kp.Time != infinity { if t.kp.Time != infinity {
go t.keepalive() go t.keepalive()
} }
@ -404,7 +411,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, ErrConnClosing return nil, ErrConnClosing
} }
t.mu.Unlock() t.mu.Unlock()
sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) sq, err := wait(ctx, t.ctx, nil, nil, t.streamsQuota.acquire())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -583,12 +590,9 @@ func (t *http2Client) Close() (err error) {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) t.cancel()
err = t.conn.Close() err = t.conn.Close()
t.mu.Lock() t.mu.Lock()
streams := t.activeStreams streams := t.activeStreams
@ -610,23 +614,18 @@ func (t *http2Client) Close() (err error) {
} }
t.statsHandler.HandleConn(t.ctx, connEnd) t.statsHandler.HandleConn(t.ctx, connEnd)
} }
return return err
} }
// GracefulClose sets the state to draining, which prevents new streams from
// being created and causes the transport to be closed when the last active
// stream is closed. If there are no active streams, the transport is closed
// immediately. This does nothing if the transport is already draining or
// closing.
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
switch t.state { switch t.state {
case unreachable: case closing, draining:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock()
return nil
}
if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
@ -645,7 +644,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
case <-t.shutdownChan: case <-t.ctx.Done():
return ErrConnClosing return ErrConnClosing
default: default:
} }
@ -667,12 +666,12 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
size := http2MaxFrameLen size := http2MaxFrameLen
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, quotaChan) sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan)
if err != nil { if err != nil {
return err return err
} }
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
@ -692,7 +691,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
t.sendQuotaPool.add(tq - ps) t.sendQuotaPool.add(tq - ps)
} }
// Acquire local send quota to be able to write to the controlBuf. // Acquire local send quota to be able to write to the controlBuf.
ltq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.localSendQuota.acquire()) ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
if err != nil { if err != nil {
if _, ok := err.(ConnectionError); !ok { if _, ok := err.(ConnectionError); !ok {
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
@ -828,7 +827,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} else { } else {
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(connectionErrorf(true, err, "%v", err)) t.Close()
return return
} }
if w := t.fc.onRead(uint32(size)); w > 0 { if w := t.fc.onRead(uint32(size)); w > 0 {
@ -945,7 +944,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID id := f.LastStreamID
if id > 0 && id%2 != 1 { if id > 0 && id%2 != 1 {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) t.Close()
return return
} }
// A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387).
@ -959,7 +958,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// If there are multiple GoAways the first one should always have an ID greater than the following ones. // If there are multiple GoAways the first one should always have an ID greater than the following ones.
if id > t.prevGoAwayID { if id > t.prevGoAwayID {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) t.Close()
return return
} }
default: default:
@ -1105,13 +1104,13 @@ func (t *http2Client) reader() {
// Check the validity of server preface. // Check the validity of server preface.
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
if err != nil { if err != nil {
t.notifyError(err) t.Close()
return return
} }
atomic.CompareAndSwapUint32(&t.activity, 0, 1) atomic.CompareAndSwapUint32(&t.activity, 0, 1)
sf, ok := frame.(*http2.SettingsFrame) sf, ok := frame.(*http2.SettingsFrame)
if !ok { if !ok {
t.notifyError(err) t.Close()
return return
} }
t.handleSettings(sf) t.handleSettings(sf)
@ -1135,7 +1134,7 @@ func (t *http2Client) reader() {
continue continue
} else { } else {
// Transport error. // Transport error.
t.notifyError(err) t.Close()
return return
} }
} }
@ -1192,11 +1191,6 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
// The transport layer needs to be refactored to take care of this. // The transport layer needs to be refactored to take care of this.
func (t *http2Client) itemHandler(i item) error { func (t *http2Client) itemHandler(i item) error {
var err error var err error
defer func() {
if err != nil {
t.notifyError(err)
}
}()
switch i := i.(type) { switch i := i.(type) {
case *dataFrame: case *dataFrame:
err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d)
@ -1287,7 +1281,7 @@ func (t *http2Client) keepalive() {
case <-t.awakenKeepalive: case <-t.awakenKeepalive:
// If the control gets here a ping has been sent // If the control gets here a ping has been sent
// need to reset the timer with keepalive.Timeout. // need to reset the timer with keepalive.Timeout.
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
} else { } else {
@ -1306,13 +1300,13 @@ func (t *http2Client) keepalive() {
} }
t.Close() t.Close()
return return
case <-t.shutdownChan: case <-t.ctx.Done():
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
return return
} }
case <-t.shutdownChan: case <-t.ctx.Done():
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
@ -1322,25 +1316,9 @@ func (t *http2Client) keepalive() {
} }
func (t *http2Client) Error() <-chan struct{} { func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.ctx.Done()
} }
func (t *http2Client) GoAway() <-chan struct{} { func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway return t.goAway
} }
func (t *http2Client) notifyError(err error) {
t.mu.Lock()
// make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable {
t.state = unreachable
close(t.errorChan)
infof("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
}
t.mu.Unlock()
}

View File

@ -52,16 +52,13 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
// http2Server implements the ServerTransport interface with HTTP2. // http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct { type http2Server struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
conn net.Conn conn net.Conn
remoteAddr net.Addr remoteAddr net.Addr
localAddr net.Addr localAddr net.Addr
maxStreamID uint32 // max stream ID ever seen maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
inTapHandle tap.ServerInHandle inTapHandle tap.ServerInHandle
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
// blocking forever after Close.
shutdownChan chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding hBuf *bytes.Buffer // the buffer for HPACK encoding
hEnc *hpack.Encoder // HPACK encoder hEnc *hpack.Encoder // HPACK encoder
@ -186,8 +183,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
kep.MinTime = defaultKeepalivePolicyMinTime kep.MinTime = defaultKeepalivePolicyMinTime
} }
var buf bytes.Buffer var buf bytes.Buffer
ctx, cancel := context.WithCancel(context.Background())
t := &http2Server{ t := &http2Server{
ctx: context.Background(), ctx: ctx,
cancel: cancel,
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(), localAddr: conn.LocalAddr(),
@ -201,7 +200,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
fc: &inFlow{limit: uint32(icwz)}, fc: &inFlow{limit: uint32(icwz)},
sendQuotaPool: newQuotaPool(defaultWindowSize), sendQuotaPool: newQuotaPool(defaultWindowSize),
state: reachable, state: reachable,
shutdownChan: make(chan struct{}),
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
stats: config.StatsHandler, stats: config.StatsHandler,
@ -225,7 +223,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
t.stats.HandleConn(t.ctx, connBegin) t.stats.HandleConn(t.ctx, connBegin)
} }
t.framer.writer.Flush() t.framer.writer.Flush()
go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
}()
go t.keepalive() go t.keepalive()
return t, nil return t, nil
} }
@ -687,7 +688,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
case <-t.shutdownChan: case <-t.ctx.Done():
return ErrConnClosing return ErrConnClosing
default: default:
} }
@ -744,7 +745,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// OK is adopted. // OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
select { select {
case <-t.shutdownChan: case <-t.ctx.Done():
return ErrConnClosing return ErrConnClosing
default: default:
} }
@ -816,7 +817,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
case <-t.shutdownChan: case <-t.ctx.Done():
return ErrConnClosing return ErrConnClosing
default: default:
} }
@ -846,12 +847,12 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
size := http2MaxFrameLen size := http2MaxFrameLen
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
sq, err := wait(s.ctx, nil, nil, t.shutdownChan, quotaChan) sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan)
if err != nil { if err != nil {
return err return err
} }
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
@ -871,7 +872,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
t.sendQuotaPool.add(tq - ps) t.sendQuotaPool.add(tq - ps)
} }
// Acquire local send quota to be able to write to the controlBuf. // Acquire local send quota to be able to write to the controlBuf.
ltq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.localSendQuota.acquire()) ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
if err != nil { if err != nil {
if _, ok := err.(ConnectionError); !ok { if _, ok := err.(ConnectionError); !ok {
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
@ -960,7 +961,7 @@ func (t *http2Server) keepalive() {
t.Close() t.Close()
// Reseting the timer so that the clean-up doesn't deadlock. // Reseting the timer so that the clean-up doesn't deadlock.
maxAge.Reset(infinity) maxAge.Reset(infinity)
case <-t.shutdownChan: case <-t.ctx.Done():
} }
return return
case <-keepalive.C: case <-keepalive.C:
@ -978,7 +979,7 @@ func (t *http2Server) keepalive() {
pingSent = true pingSent = true
t.controlBuf.put(p) t.controlBuf.put(p)
keepalive.Reset(t.kp.Timeout) keepalive.Reset(t.kp.Timeout)
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
} }
@ -990,19 +991,13 @@ var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
// is duplicated between the client and the server. // is duplicated between the client and the server.
// The transport layer needs to be refactored to take care of this. // The transport layer needs to be refactored to take care of this.
func (t *http2Server) itemHandler(i item) error { func (t *http2Server) itemHandler(i item) error {
var err error
defer func() {
if err != nil {
t.Close()
errorf("transport: Error while writing: %v", err)
}
}()
switch i := i.(type) { switch i := i.(type) {
case *dataFrame: case *dataFrame:
err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
if err == nil { return err
i.f()
} }
i.f()
return nil
case *headerFrame: case *headerFrame:
t.hBuf.Reset() t.hBuf.Reset()
for _, f := range i.hf { for _, f := range i.hf {
@ -1017,6 +1012,7 @@ func (t *http2Server) itemHandler(i item) error {
} else { } else {
endHeaders = true endHeaders = true
} }
var err error
if first { if first {
first = false first = false
err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
@ -1037,17 +1033,17 @@ func (t *http2Server) itemHandler(i item) error {
} }
} }
atomic.StoreUint32(&t.resetPingStrikes, 1) atomic.StoreUint32(&t.resetPingStrikes, 1)
return nil
case *windowUpdate: case *windowUpdate:
err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings: case *settings:
if i.ack { if i.ack {
t.applySettings(i.ss) t.applySettings(i.ss)
err = t.framer.fr.WriteSettingsAck() return t.framer.fr.WriteSettingsAck()
} else {
err = t.framer.fr.WriteSettings(i.ss...)
} }
return t.framer.fr.WriteSettings(i.ss...)
case *resetStream: case *resetStream:
err = t.framer.fr.WriteRSTStream(i.streamID, i.code) return t.framer.fr.WriteRSTStream(i.streamID, i.code)
case *goAway: case *goAway:
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
@ -1060,15 +1056,13 @@ func (t *http2Server) itemHandler(i item) error {
// Stop accepting more streams now. // Stop accepting more streams now.
t.state = draining t.state = draining
t.mu.Unlock() t.mu.Unlock()
err = t.framer.fr.WriteGoAway(sid, i.code, i.debugData) if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil {
if err != nil {
return err return err
} }
if i.closeConn { if i.closeConn {
// Abruptly close the connection following the GoAway. // Abruptly close the connection following the GoAway (via
// But flush out what's inside the buffer first. // loopywriter). But flush out what's inside the buffer first.
t.framer.writer.Flush() t.framer.writer.Flush()
t.Close()
return fmt.Errorf("transport: Connection closing") return fmt.Errorf("transport: Connection closing")
} }
return nil return nil
@ -1080,36 +1074,42 @@ func (t *http2Server) itemHandler(i item) error {
// originated before the GoAway reaches the client. // originated before the GoAway reaches the client.
// After getting the ack or timer expiration send out another GoAway this // After getting the ack or timer expiration send out another GoAway this
// time with an ID of the max stream server intends to process. // time with an ID of the max stream server intends to process.
err = t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}) if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
err = t.framer.fr.WritePing(false, goAwayPing.data) return err
}
if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil {
return err
}
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
defer timer.Stop() defer timer.Stop()
select { select {
case <-t.drainChan: case <-t.drainChan:
case <-timer.C: case <-timer.C:
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData})
}() }()
return nil
case *flushIO: case *flushIO:
err = t.framer.writer.Flush() return t.framer.writer.Flush()
case *ping: case *ping:
if !i.ack { if !i.ack {
t.bdpEst.timesnap(i.data) t.bdpEst.timesnap(i.data)
} }
err = t.framer.fr.WritePing(i.ack, i.data) return t.framer.fr.WritePing(i.ack, i.data)
default: default:
errorf("transport: http2Server.controller got unexpected item type %v\n", i) err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t\n", i)
} errorf("%v", err)
return err return err
} }
}
// Close starts shutting down the http2Server transport. // Close starts shutting down the http2Server transport.
// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This
// could cause some resource issue. Revisit this later. // could cause some resource issue. Revisit this later.
func (t *http2Server) Close() (err error) { func (t *http2Server) Close() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
@ -1119,8 +1119,8 @@ func (t *http2Server) Close() (err error) {
streams := t.activeStreams streams := t.activeStreams
t.activeStreams = nil t.activeStreams = nil
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) t.cancel()
err = t.conn.Close() err := t.conn.Close()
// Cancel all active streams. // Cancel all active streams.
for _, s := range streams { for _, s := range streams {
s.cancel() s.cancel()
@ -1129,7 +1129,7 @@ func (t *http2Server) Close() (err error) {
connEnd := &stats.ConnEnd{} connEnd := &stats.ConnEnd{}
t.stats.HandleConn(t.ctx, connEnd) t.stats.HandleConn(t.ctx, connEnd)
} }
return return err
} }
// closeStream clears the footprint of a stream when the stream is not needed // closeStream clears the footprint of a stream when the stream is not needed

View File

@ -25,6 +25,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -457,7 +458,6 @@ type transportState int
const ( const (
reachable transportState = iota reachable transportState = iota
unreachable
closing closing
draining draining
) )
@ -519,8 +519,8 @@ type TargetInfo struct {
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) {
return newHTTP2Client(ctx, target, opts) return newHTTP2Client(ctx, target, opts, timeout)
} }
// Options provides additional hints and information for message // Options provides additional hints and information for message
@ -702,14 +702,8 @@ func (e StreamError) Error() string {
return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc)
} }
// wait blocks until it can receive from ctx.Done, closing, or proceed. // wait blocks until it can receive from one of the provided contexts or channels
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. func wait(ctx, tctx context.Context, done, goAway <-chan struct{}, proceed <-chan int) (int, error) {
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ContextErr(ctx.Err()) return 0, ContextErr(ctx.Err())
@ -717,7 +711,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
return 0, io.EOF return 0, io.EOF
case <-goAway: case <-goAway:
return 0, ErrStreamDrain return 0, ErrStreamDrain
case <-closing: case <-tctx.Done():
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed:
return i, nil return i, nil
@ -739,7 +733,7 @@ const (
// loopyWriter is run in a separate go routine. It is the single code path that will // loopyWriter is run in a separate go routine. It is the single code path that will
// write data on wire. // write data on wire.
func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) error) { func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) {
for { for {
select { select {
case i := <-cbuf.get(): case i := <-cbuf.get():
@ -747,7 +741,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err
if err := handler(i); err != nil { if err := handler(i); err != nil {
return return
} }
case <-done: case <-ctx.Done():
return return
} }
hasData: hasData:
@ -758,7 +752,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err
if err := handler(i); err != nil { if err := handler(i); err != nil {
return return
} }
case <-done: case <-ctx.Done():
return return
default: default:
if err := handler(&flushIO{}); err != nil { if err := handler(&flushIO{}); err != nil {

View File

@ -357,7 +357,7 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy
target := TargetInfo{ target := TargetInfo{
Addr: addr, Addr: addr,
} }
ct, connErr = NewClientTransport(context.Background(), target, copts) ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second)
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)
} }
@ -380,7 +380,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con
} }
done <- conn done <- conn
}() }()
tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second)
if err != nil { if err != nil {
// Server clean-up. // Server clean-up.
lis.Close() lis.Close()
@ -1680,7 +1680,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) {
}) })
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) serverSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire())
if err != nil { if err != nil {
t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err) t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err)
} }
@ -1702,7 +1702,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) {
t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize)
} }
ctx, cancel = context.WithTimeout(context.Background(), time.Second) ctx, cancel = context.WithTimeout(context.Background(), time.Second)
clientSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) clientSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire())
if err != nil { if err != nil {
t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err)
} }
@ -1838,7 +1838,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) {
// Check flow conrtrol window on client stream is equal to out flow on server stream. // Check flow conrtrol window on client stream is equal to out flow on server stream.
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire()) serverStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, sstream.sendQuotaPool.acquire())
cancel() cancel()
if err != nil { if err != nil {
return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err)
@ -1853,7 +1853,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) {
// Check flow control window on server stream is equal to out flow on client stream. // Check flow control window on server stream is equal to out flow on client stream.
ctx, cancel = context.WithTimeout(context.Background(), time.Second) ctx, cancel = context.WithTimeout(context.Background(), time.Second)
clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire()) clientStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, cstream.sendQuotaPool.acquire())
cancel() cancel()
if err != nil { if err != nil {
return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err)
@ -1868,7 +1868,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) {
// Check flow control window on client transport is equal to out flow of server transport. // Check flow control window on client transport is equal to out flow of server transport.
ctx, cancel = context.WithTimeout(context.Background(), time.Second) ctx, cancel = context.WithTimeout(context.Background(), time.Second)
serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) serverTrSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire())
cancel() cancel()
if err != nil { if err != nil {
return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err)
@ -1883,7 +1883,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) {
// Check flow control window on server transport is equal to out flow of client transport. // Check flow control window on server transport is equal to out flow of client transport.
ctx, cancel = context.WithTimeout(context.Background(), time.Second) ctx, cancel = context.WithTimeout(context.Background(), time.Second)
clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) clientTrSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire())
cancel() cancel()
if err != nil { if err != nil {
return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err)
@ -2055,7 +2055,7 @@ func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream
wh: wh, wh: wh,
} }
server.start(t, lis) server.start(t, lis)
client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}) client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second)
if err != nil { if err != nil {
t.Fatalf("Error creating client. Err: %v", err) t.Fatalf("Error creating client. Err: %v", err)
} }

2
vet.sh
View File

@ -65,7 +65,7 @@ git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":'
set +o pipefail set +o pipefail
# TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. # TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed.
# TODO: Remove clientconn exception once go1.6 support is removed. # TODO: Remove clientconn exception once go1.6 support is removed.
go tool vet -all . 2>&1 | grep -vE 'clientconn.go:.*cancel' | grep -vF '.pb.go:' | tee /dev/stderr | (! read) go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read)
set -o pipefail set -o pipefail
git reset --hard HEAD git reset --hard HEAD