From 5a82377e69857eae7e2de0b451668f3cf18d5a31 Mon Sep 17 00:00:00 2001 From: dfawley Date: Mon, 2 Oct 2017 11:56:31 -0700 Subject: [PATCH] transport: refactor of error/cancellation paths (#1533) - The transport is now responsible for closing its own connection when an error occurs or when the context given to it in NewClientTransport() is canceled. - Remove client/server shutdown channels -- add cancel function to allow self-cancellation. - Plumb the clientConn's context into the client transport to allow for the transport to be canceled even after it has been removed from the ac (due to graceful close) when the ClientConn is closed. --- clientconn.go | 93 ++++++++----------------------- transport/http2_client.go | 108 ++++++++++++++---------------------- transport/http2_server.go | 98 ++++++++++++++++---------------- transport/transport.go | 24 +++----- transport/transport_test.go | 18 +++--- vet.sh | 2 +- 6 files changed, 135 insertions(+), 208 deletions(-) diff --git a/clientconn.go b/clientconn.go index 7a61f9c5..ae64a660 100644 --- a/clientconn.go +++ b/clientconn.go @@ -678,7 +678,7 @@ func (ac *addrConn) connect(block bool) error { ac.mu.Unlock() if block { - if err := ac.resetTransport(false); err != nil { + if err := ac.resetTransport(); err != nil { if err != errConnClosing { ac.tearDown(err) } @@ -692,7 +692,7 @@ func (ac *addrConn) connect(block bool) error { } else { // Start a goroutine connecting to the server asynchronously. go func() { - if err := ac.resetTransport(false); err != nil { + if err := ac.resetTransport(); err != nil { grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err) if err != errConnClosing { // Keep this ac in cc.conns, to get the reason it's torn down. @@ -867,12 +867,10 @@ func (ac *addrConn) errorf(format string, a ...interface{}) { } } -// resetTransport recreates a transport to the address for ac. -// For the old transport: -// - if drain is true, it will be gracefully closed. -// - otherwise, it will be closed. +// resetTransport recreates a transport to the address for ac. The old +// transport will close itself on error or when the clientconn is closed. // TODO(bar) make sure all state transitions are valid. -func (ac *addrConn) resetTransport(drain bool) error { +func (ac *addrConn) resetTransport() error { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() @@ -888,13 +886,9 @@ func (ac *addrConn) resetTransport(drain bool) error { close(ac.ready) ac.ready = nil } - t := ac.transport ac.transport = nil ac.curAddr = resolver.Address{} ac.mu.Unlock() - if t != nil && !drain { - t.Close() - } ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() @@ -931,17 +925,12 @@ func (ac *addrConn) resetTransport(drain bool) error { return errConnClosing } ac.mu.Unlock() - ctx, cancel := context.WithTimeout(ac.ctx, timeout) sinfo := transport.TargetInfo{ Addr: addr.Addr, Metadata: addr.Metadata, } - newTransport, err := transport.NewClientTransport(ctx, sinfo, copts) - // Don't call cancel in success path due to a race in Go 1.6: - // https://github.com/golang/go/issues/15078. + newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) if err != nil { - cancel() - if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { return err } @@ -1012,58 +1001,28 @@ func (ac *addrConn) transportMonitor() { ac.mu.Lock() t := ac.transport ac.mu.Unlock() + // Block until we receive a goaway or an error occurs. + select { + case <-t.GoAway(): + case <-t.Error(): + } + // If a GoAway happened, regardless of error, adjust our keepalive + // parameters as appropriate. select { - // This is needed to detect the teardown when - // the addrConn is idle (i.e., no RPC in flight). - case <-ac.ctx.Done(): - select { - case <-t.Error(): - t.Close() - default: - } - return case <-t.GoAway(): ac.adjustParams(t.GetGoAwayReason()) - // If GoAway happens without any network I/O error, the underlying transport - // will be gracefully closed, and a new transport will be created. - // (The transport will be closed when all the pending RPCs finished or failed.) - // If GoAway and some network I/O error happen concurrently, the underlying transport - // will be closed, and a new transport will be created. - var drain bool - select { - case <-t.Error(): - default: - drain = true - } - if err := ac.resetTransport(drain); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return - } - case <-t.Error(): - select { - case <-ac.ctx.Done(): - t.Close() - return - case <-t.GoAway(): - ac.adjustParams(t.GetGoAwayReason()) - default: - } - if err := ac.resetTransport(false); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - ac.mu.Lock() - ac.printf("transport exiting: %v", err) - ac.mu.Unlock() - grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return + default: + } + if err := ac.resetTransport(); err != nil { + ac.mu.Lock() + ac.printf("transport exiting: %v", err) + ac.mu.Unlock() + grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) + if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. + ac.tearDown(err) } + return } } } @@ -1137,7 +1096,6 @@ func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) { // tearDown doesn't remove ac from ac.cc.conns. func (ac *addrConn) tearDown(err error) { ac.cancel() - ac.mu.Lock() ac.curAddr = resolver.Address{} defer ac.mu.Unlock() @@ -1166,9 +1124,6 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil && err != errConnDrain { - ac.transport.Close() - } return } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3d0b5424..0b4846fd 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -43,6 +43,7 @@ import ( // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { ctx context.Context + cancel context.CancelFunc target string // server name/addr userAgent string md interface{} @@ -52,13 +53,6 @@ type http2Client struct { authInfo credentials.AuthInfo // auth info about the connection nextID uint32 // the next stream ID to be used - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - // TODO(zhaoq): Maybe have a channel context? - shutdownChan chan struct{} - // errorChan is closed to notify the I/O error to the caller. - errorChan chan struct{} // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} @@ -149,9 +143,20 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { +func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) { scheme := "http" - conn, err := dial(ctx, opts.Dialer, addr.Addr) + ctx, cancel := context.WithCancel(ctx) + connectCtx, connectCancel := context.WithTimeout(ctx, timeout) + defer func() { + if err != nil { + cancel() + // Don't call connectCancel in success path due to a race in Go 1.6: + // https://github.com/golang/go/issues/15078. + connectCancel() + } + }() + + conn, err := dial(connectCtx, opts.Dialer, addr.Addr) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) @@ -170,7 +175,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( ) if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) + conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn) if err != nil { // Credentials handshake errors are typically considered permanent // to avoid retrying on e.g. bad certificates. @@ -204,6 +209,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } t := &http2Client{ ctx: ctx, + cancel: cancel, target: addr.Addr, userAgent: opts.UserAgent, md: addr.Metadata, @@ -213,8 +219,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( authInfo: authInfo, // The client initiated stream id is odd starting from 1. nextID: 1, - shutdownChan: make(chan struct{}), - errorChan: make(chan struct{}), goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), hBuf: &buf, @@ -292,7 +296,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } } t.framer.writer.Flush() - go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() if t.kp.Time != infinity { go t.keepalive() } @@ -404,7 +411,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ErrConnClosing } t.mu.Unlock() - sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, t.ctx, nil, nil, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -583,12 +590,9 @@ func (t *http2Client) Close() (err error) { t.mu.Unlock() return } - if t.state == reachable || t.state == draining { - close(t.errorChan) - } t.state = closing t.mu.Unlock() - close(t.shutdownChan) + t.cancel() err = t.conn.Close() t.mu.Lock() streams := t.activeStreams @@ -610,23 +614,18 @@ func (t *http2Client) Close() (err error) { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return + return err } +// GracefulClose sets the state to draining, which prevents new streams from +// being created and causes the transport to be closed when the last active +// stream is closed. If there are no active streams, the transport is closed +// immediately. This does nothing if the transport is already draining or +// closing. func (t *http2Client) GracefulClose() error { t.mu.Lock() switch t.state { - case unreachable: - // The server may close the connection concurrently. t is not available for - // any streams. Close it now. - t.mu.Unlock() - t.Close() - return nil - case closing: - t.mu.Unlock() - return nil - } - if t.state == draining { + case closing, draining: t.mu.Unlock() return nil } @@ -645,7 +644,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -667,12 +666,12 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e size := http2MaxFrameLen // Wait until the stream has some quota to send the data. quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() - sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, quotaChan) + sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan) if err != nil { return err } // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire()) if err != nil { return err } @@ -692,7 +691,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e t.sendQuotaPool.add(tq - ps) } // Acquire local send quota to be able to write to the controlBuf. - ltq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.localSendQuota.acquire()) + ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire()) if err != nil { if _, ok := err.(ConnectionError); !ok { t.sendQuotaPool.add(ps) @@ -828,7 +827,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { t.controlBuf.put(bdpPing) } else { if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(connectionErrorf(true, err, "%v", err)) + t.Close() return } if w := t.fc.onRead(uint32(size)); w > 0 { @@ -945,7 +944,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + t.Close() return } // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). @@ -959,7 +958,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.Close() return } default: @@ -1105,13 +1104,13 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { - t.notifyError(err) + t.Close() return } atomic.CompareAndSwapUint32(&t.activity, 0, 1) sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.notifyError(err) + t.Close() return } t.handleSettings(sf) @@ -1135,7 +1134,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. - t.notifyError(err) + t.Close() return } } @@ -1192,11 +1191,6 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // The transport layer needs to be refactored to take care of this. func (t *http2Client) itemHandler(i item) error { var err error - defer func() { - if err != nil { - t.notifyError(err) - } - }() switch i := i.(type) { case *dataFrame: err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) @@ -1287,7 +1281,7 @@ func (t *http2Client) keepalive() { case <-t.awakenKeepalive: // If the control gets here a ping has been sent // need to reset the timer with keepalive.Timeout. - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } else { @@ -1306,13 +1300,13 @@ func (t *http2Client) keepalive() { } t.Close() return - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } return } - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } @@ -1322,25 +1316,9 @@ func (t *http2Client) keepalive() { } func (t *http2Client) Error() <-chan struct{} { - return t.errorChan + return t.ctx.Done() } func (t *http2Client) GoAway() <-chan struct{} { return t.goAway } - -func (t *http2Client) notifyError(err error) { - t.mu.Lock() - // make sure t.errorChan is closed only once. - if t.state == draining { - t.mu.Unlock() - t.Close() - return - } - if t.state == reachable { - t.state = unreachable - close(t.errorChan) - infof("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) - } - t.mu.Unlock() -} diff --git a/transport/http2_server.go b/transport/http2_server.go index 8763b4c5..307968e1 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -52,19 +52,16 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { ctx context.Context + cancel context.CancelFunc conn net.Conn remoteAddr net.Addr localAddr net.Addr maxStreamID uint32 // max stream ID ever seen authInfo credentials.AuthInfo // auth info about the connection inTapHandle tap.ServerInHandle - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - shutdownChan chan struct{} - framer *framer - hBuf *bytes.Buffer // the buffer for HPACK encoding - hEnc *hpack.Encoder // HPACK encoder + framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder // The max number of concurrent streams. maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window @@ -186,8 +183,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err kep.MinTime = defaultKeepalivePolicyMinTime } var buf bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) t := &http2Server{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -201,7 +200,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err fc: &inFlow{limit: uint32(icwz)}, sendQuotaPool: newQuotaPool(defaultWindowSize), state: reachable, - shutdownChan: make(chan struct{}), activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, stats: config.StatsHandler, @@ -225,7 +223,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err t.stats.HandleConn(t.ctx, connBegin) } t.framer.writer.Flush() - go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() go t.keepalive() return t, nil } @@ -687,7 +688,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -744,7 +745,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { select { - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -816,7 +817,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -846,12 +847,12 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( size := http2MaxFrameLen // Wait until the stream has some quota to send the data. quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() - sq, err := wait(s.ctx, nil, nil, t.shutdownChan, quotaChan) + sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan) if err != nil { return err } // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire()) if err != nil { return err } @@ -871,7 +872,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( t.sendQuotaPool.add(tq - ps) } // Acquire local send quota to be able to write to the controlBuf. - ltq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.localSendQuota.acquire()) + ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire()) if err != nil { if _, ok := err.(ConnectionError); !ok { t.sendQuotaPool.add(ps) @@ -960,7 +961,7 @@ func (t *http2Server) keepalive() { t.Close() // Reseting the timer so that the clean-up doesn't deadlock. maxAge.Reset(infinity) - case <-t.shutdownChan: + case <-t.ctx.Done(): } return case <-keepalive.C: @@ -978,7 +979,7 @@ func (t *http2Server) keepalive() { pingSent = true t.controlBuf.put(p) keepalive.Reset(t.kp.Timeout) - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } @@ -990,19 +991,13 @@ var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} // is duplicated between the client and the server. // The transport layer needs to be refactored to take care of this. func (t *http2Server) itemHandler(i item) error { - var err error - defer func() { - if err != nil { - t.Close() - errorf("transport: Error while writing: %v", err) - } - }() switch i := i.(type) { case *dataFrame: - err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) - if err == nil { - i.f() + if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { + return err } + i.f() + return nil case *headerFrame: t.hBuf.Reset() for _, f := range i.hf { @@ -1017,6 +1012,7 @@ func (t *http2Server) itemHandler(i item) error { } else { endHeaders = true } + var err error if first { first = false err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ @@ -1037,17 +1033,17 @@ func (t *http2Server) itemHandler(i item) error { } } atomic.StoreUint32(&t.resetPingStrikes, 1) + return nil case *windowUpdate: - err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: if i.ack { t.applySettings(i.ss) - err = t.framer.fr.WriteSettingsAck() - } else { - err = t.framer.fr.WriteSettings(i.ss...) + return t.framer.fr.WriteSettingsAck() } + return t.framer.fr.WriteSettings(i.ss...) case *resetStream: - err = t.framer.fr.WriteRSTStream(i.streamID, i.code) + return t.framer.fr.WriteRSTStream(i.streamID, i.code) case *goAway: t.mu.Lock() if t.state == closing { @@ -1060,15 +1056,13 @@ func (t *http2Server) itemHandler(i item) error { // Stop accepting more streams now. t.state = draining t.mu.Unlock() - err = t.framer.fr.WriteGoAway(sid, i.code, i.debugData) - if err != nil { + if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { return err } if i.closeConn { - // Abruptly close the connection following the GoAway. - // But flush out what's inside the buffer first. + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. t.framer.writer.Flush() - t.Close() return fmt.Errorf("transport: Connection closing") } return nil @@ -1080,36 +1074,42 @@ func (t *http2Server) itemHandler(i item) error { // originated before the GoAway reaches the client. // After getting the ack or timer expiration send out another GoAway this // time with an ID of the max stream server intends to process. - err = t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}) - err = t.framer.fr.WritePing(false, goAwayPing.data) + if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return err + } go func() { timer := time.NewTimer(time.Minute) defer timer.Stop() select { case <-t.drainChan: case <-timer.C: - case <-t.shutdownChan: + case <-t.ctx.Done(): return } t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) }() + return nil case *flushIO: - err = t.framer.writer.Flush() + return t.framer.writer.Flush() case *ping: if !i.ack { t.bdpEst.timesnap(i.data) } - err = t.framer.fr.WritePing(i.ack, i.data) + return t.framer.fr.WritePing(i.ack, i.data) default: - errorf("transport: http2Server.controller got unexpected item type %v\n", i) + err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t\n", i) + errorf("%v", err) + return err } - return err } // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. -func (t *http2Server) Close() (err error) { +func (t *http2Server) Close() error { t.mu.Lock() if t.state == closing { t.mu.Unlock() @@ -1119,8 +1119,8 @@ func (t *http2Server) Close() (err error) { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - close(t.shutdownChan) - err = t.conn.Close() + t.cancel() + err := t.conn.Close() // Cancel all active streams. for _, s := range streams { s.cancel() @@ -1129,7 +1129,7 @@ func (t *http2Server) Close() (err error) { connEnd := &stats.ConnEnd{} t.stats.HandleConn(t.ctx, connEnd) } - return + return err } // closeStream clears the footprint of a stream when the stream is not needed diff --git a/transport/transport.go b/transport/transport.go index f8d8faed..e4a35403 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -25,6 +25,7 @@ import ( "io" "net" "sync" + "time" "golang.org/x/net/context" "golang.org/x/net/http2" @@ -457,7 +458,6 @@ type transportState int const ( reachable transportState = iota - unreachable closing draining ) @@ -519,8 +519,8 @@ type TargetInfo struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { - return newHTTP2Client(ctx, target, opts) +func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) { + return newHTTP2Client(ctx, target, opts, timeout) } // Options provides additional hints and information for message @@ -702,14 +702,8 @@ func (e StreamError) Error() string { return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) } -// wait blocks until it can receive from ctx.Done, closing, or proceed. -// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. -// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise -// it return the StreamError for ctx.Err. -// If it receives from goAway, it returns 0, ErrStreamDrain. -// If it receives from closing, it returns 0, ErrConnClosing. -// If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { +// wait blocks until it can receive from one of the provided contexts or channels +func wait(ctx, tctx context.Context, done, goAway <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) @@ -717,7 +711,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <- return 0, io.EOF case <-goAway: return 0, ErrStreamDrain - case <-closing: + case <-tctx.Done(): return 0, ErrConnClosing case i := <-proceed: return i, nil @@ -739,7 +733,7 @@ const ( // loopyWriter is run in a separate go routine. It is the single code path that will // write data on wire. -func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) error) { +func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) { for { select { case i := <-cbuf.get(): @@ -747,7 +741,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err if err := handler(i); err != nil { return } - case <-done: + case <-ctx.Done(): return } hasData: @@ -758,7 +752,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err if err := handler(i); err != nil { return } - case <-done: + case <-ctx.Done(): return default: if err := handler(&flushIO{}); err != nil { diff --git a/transport/transport_test.go b/transport/transport_test.go index 2bb2b711..e1dd080a 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -357,7 +357,7 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, copts) + ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } @@ -380,7 +380,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } done <- conn }() - tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) + tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second) if err != nil { // Server clean-up. lis.Close() @@ -1680,7 +1680,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + serverSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err) } @@ -1702,7 +1702,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) } ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + clientSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) } @@ -1838,7 +1838,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow conrtrol window on client stream is equal to out flow on server stream. ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire()) + serverStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, sstream.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) @@ -1853,7 +1853,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on server stream is equal to out flow on client stream. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire()) + clientStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, cstream.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) @@ -1868,7 +1868,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on client transport is equal to out flow of server transport. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + serverTrSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) @@ -1883,7 +1883,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on server transport is equal to out flow of client transport. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + clientTrSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) @@ -2055,7 +2055,7 @@ func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream wh: wh, } server.start(t, lis) - client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}) + client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second) if err != nil { t.Fatalf("Error creating client. Err: %v", err) } diff --git a/vet.sh b/vet.sh index e478188f..4905ed3d 100755 --- a/vet.sh +++ b/vet.sh @@ -65,7 +65,7 @@ git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":' set +o pipefail # TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. # TODO: Remove clientconn exception once go1.6 support is removed. -go tool vet -all . 2>&1 | grep -vE 'clientconn.go:.*cancel' | grep -vF '.pb.go:' | tee /dev/stderr | (! read) +go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read) set -o pipefail git reset --hard HEAD