diff --git a/clientconn.go b/clientconn.go index 7a61f9c5..ae64a660 100644 --- a/clientconn.go +++ b/clientconn.go @@ -678,7 +678,7 @@ func (ac *addrConn) connect(block bool) error { ac.mu.Unlock() if block { - if err := ac.resetTransport(false); err != nil { + if err := ac.resetTransport(); err != nil { if err != errConnClosing { ac.tearDown(err) } @@ -692,7 +692,7 @@ func (ac *addrConn) connect(block bool) error { } else { // Start a goroutine connecting to the server asynchronously. go func() { - if err := ac.resetTransport(false); err != nil { + if err := ac.resetTransport(); err != nil { grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err) if err != errConnClosing { // Keep this ac in cc.conns, to get the reason it's torn down. @@ -867,12 +867,10 @@ func (ac *addrConn) errorf(format string, a ...interface{}) { } } -// resetTransport recreates a transport to the address for ac. -// For the old transport: -// - if drain is true, it will be gracefully closed. -// - otherwise, it will be closed. +// resetTransport recreates a transport to the address for ac. The old +// transport will close itself on error or when the clientconn is closed. // TODO(bar) make sure all state transitions are valid. -func (ac *addrConn) resetTransport(drain bool) error { +func (ac *addrConn) resetTransport() error { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() @@ -888,13 +886,9 @@ func (ac *addrConn) resetTransport(drain bool) error { close(ac.ready) ac.ready = nil } - t := ac.transport ac.transport = nil ac.curAddr = resolver.Address{} ac.mu.Unlock() - if t != nil && !drain { - t.Close() - } ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() @@ -931,17 +925,12 @@ func (ac *addrConn) resetTransport(drain bool) error { return errConnClosing } ac.mu.Unlock() - ctx, cancel := context.WithTimeout(ac.ctx, timeout) sinfo := transport.TargetInfo{ Addr: addr.Addr, Metadata: addr.Metadata, } - newTransport, err := transport.NewClientTransport(ctx, sinfo, copts) - // Don't call cancel in success path due to a race in Go 1.6: - // https://github.com/golang/go/issues/15078. + newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) if err != nil { - cancel() - if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { return err } @@ -1012,58 +1001,28 @@ func (ac *addrConn) transportMonitor() { ac.mu.Lock() t := ac.transport ac.mu.Unlock() + // Block until we receive a goaway or an error occurs. + select { + case <-t.GoAway(): + case <-t.Error(): + } + // If a GoAway happened, regardless of error, adjust our keepalive + // parameters as appropriate. select { - // This is needed to detect the teardown when - // the addrConn is idle (i.e., no RPC in flight). - case <-ac.ctx.Done(): - select { - case <-t.Error(): - t.Close() - default: - } - return case <-t.GoAway(): ac.adjustParams(t.GetGoAwayReason()) - // If GoAway happens without any network I/O error, the underlying transport - // will be gracefully closed, and a new transport will be created. - // (The transport will be closed when all the pending RPCs finished or failed.) - // If GoAway and some network I/O error happen concurrently, the underlying transport - // will be closed, and a new transport will be created. - var drain bool - select { - case <-t.Error(): - default: - drain = true - } - if err := ac.resetTransport(drain); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return - } - case <-t.Error(): - select { - case <-ac.ctx.Done(): - t.Close() - return - case <-t.GoAway(): - ac.adjustParams(t.GetGoAwayReason()) - default: - } - if err := ac.resetTransport(false); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - ac.mu.Lock() - ac.printf("transport exiting: %v", err) - ac.mu.Unlock() - grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return + default: + } + if err := ac.resetTransport(); err != nil { + ac.mu.Lock() + ac.printf("transport exiting: %v", err) + ac.mu.Unlock() + grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) + if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. + ac.tearDown(err) } + return } } } @@ -1137,7 +1096,6 @@ func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) { // tearDown doesn't remove ac from ac.cc.conns. func (ac *addrConn) tearDown(err error) { ac.cancel() - ac.mu.Lock() ac.curAddr = resolver.Address{} defer ac.mu.Unlock() @@ -1166,9 +1124,6 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil && err != errConnDrain { - ac.transport.Close() - } return } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3d0b5424..0b4846fd 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -43,6 +43,7 @@ import ( // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { ctx context.Context + cancel context.CancelFunc target string // server name/addr userAgent string md interface{} @@ -52,13 +53,6 @@ type http2Client struct { authInfo credentials.AuthInfo // auth info about the connection nextID uint32 // the next stream ID to be used - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - // TODO(zhaoq): Maybe have a channel context? - shutdownChan chan struct{} - // errorChan is closed to notify the I/O error to the caller. - errorChan chan struct{} // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} @@ -149,9 +143,20 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { +func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) { scheme := "http" - conn, err := dial(ctx, opts.Dialer, addr.Addr) + ctx, cancel := context.WithCancel(ctx) + connectCtx, connectCancel := context.WithTimeout(ctx, timeout) + defer func() { + if err != nil { + cancel() + // Don't call connectCancel in success path due to a race in Go 1.6: + // https://github.com/golang/go/issues/15078. + connectCancel() + } + }() + + conn, err := dial(connectCtx, opts.Dialer, addr.Addr) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) @@ -170,7 +175,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( ) if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) + conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn) if err != nil { // Credentials handshake errors are typically considered permanent // to avoid retrying on e.g. bad certificates. @@ -204,6 +209,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } t := &http2Client{ ctx: ctx, + cancel: cancel, target: addr.Addr, userAgent: opts.UserAgent, md: addr.Metadata, @@ -213,8 +219,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( authInfo: authInfo, // The client initiated stream id is odd starting from 1. nextID: 1, - shutdownChan: make(chan struct{}), - errorChan: make(chan struct{}), goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), hBuf: &buf, @@ -292,7 +296,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } } t.framer.writer.Flush() - go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() if t.kp.Time != infinity { go t.keepalive() } @@ -404,7 +411,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ErrConnClosing } t.mu.Unlock() - sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, t.ctx, nil, nil, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -583,12 +590,9 @@ func (t *http2Client) Close() (err error) { t.mu.Unlock() return } - if t.state == reachable || t.state == draining { - close(t.errorChan) - } t.state = closing t.mu.Unlock() - close(t.shutdownChan) + t.cancel() err = t.conn.Close() t.mu.Lock() streams := t.activeStreams @@ -610,23 +614,18 @@ func (t *http2Client) Close() (err error) { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return + return err } +// GracefulClose sets the state to draining, which prevents new streams from +// being created and causes the transport to be closed when the last active +// stream is closed. If there are no active streams, the transport is closed +// immediately. This does nothing if the transport is already draining or +// closing. func (t *http2Client) GracefulClose() error { t.mu.Lock() switch t.state { - case unreachable: - // The server may close the connection concurrently. t is not available for - // any streams. Close it now. - t.mu.Unlock() - t.Close() - return nil - case closing: - t.mu.Unlock() - return nil - } - if t.state == draining { + case closing, draining: t.mu.Unlock() return nil } @@ -645,7 +644,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -667,12 +666,12 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e size := http2MaxFrameLen // Wait until the stream has some quota to send the data. quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() - sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, quotaChan) + sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan) if err != nil { return err } // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire()) if err != nil { return err } @@ -692,7 +691,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e t.sendQuotaPool.add(tq - ps) } // Acquire local send quota to be able to write to the controlBuf. - ltq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.localSendQuota.acquire()) + ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire()) if err != nil { if _, ok := err.(ConnectionError); !ok { t.sendQuotaPool.add(ps) @@ -828,7 +827,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { t.controlBuf.put(bdpPing) } else { if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(connectionErrorf(true, err, "%v", err)) + t.Close() return } if w := t.fc.onRead(uint32(size)); w > 0 { @@ -945,7 +944,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + t.Close() return } // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). @@ -959,7 +958,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.Close() return } default: @@ -1105,13 +1104,13 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { - t.notifyError(err) + t.Close() return } atomic.CompareAndSwapUint32(&t.activity, 0, 1) sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.notifyError(err) + t.Close() return } t.handleSettings(sf) @@ -1135,7 +1134,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. - t.notifyError(err) + t.Close() return } } @@ -1192,11 +1191,6 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // The transport layer needs to be refactored to take care of this. func (t *http2Client) itemHandler(i item) error { var err error - defer func() { - if err != nil { - t.notifyError(err) - } - }() switch i := i.(type) { case *dataFrame: err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) @@ -1287,7 +1281,7 @@ func (t *http2Client) keepalive() { case <-t.awakenKeepalive: // If the control gets here a ping has been sent // need to reset the timer with keepalive.Timeout. - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } else { @@ -1306,13 +1300,13 @@ func (t *http2Client) keepalive() { } t.Close() return - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } return } - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } @@ -1322,25 +1316,9 @@ func (t *http2Client) keepalive() { } func (t *http2Client) Error() <-chan struct{} { - return t.errorChan + return t.ctx.Done() } func (t *http2Client) GoAway() <-chan struct{} { return t.goAway } - -func (t *http2Client) notifyError(err error) { - t.mu.Lock() - // make sure t.errorChan is closed only once. - if t.state == draining { - t.mu.Unlock() - t.Close() - return - } - if t.state == reachable { - t.state = unreachable - close(t.errorChan) - infof("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) - } - t.mu.Unlock() -} diff --git a/transport/http2_server.go b/transport/http2_server.go index 8763b4c5..307968e1 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -52,19 +52,16 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { ctx context.Context + cancel context.CancelFunc conn net.Conn remoteAddr net.Addr localAddr net.Addr maxStreamID uint32 // max stream ID ever seen authInfo credentials.AuthInfo // auth info about the connection inTapHandle tap.ServerInHandle - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - shutdownChan chan struct{} - framer *framer - hBuf *bytes.Buffer // the buffer for HPACK encoding - hEnc *hpack.Encoder // HPACK encoder + framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder // The max number of concurrent streams. maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window @@ -186,8 +183,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err kep.MinTime = defaultKeepalivePolicyMinTime } var buf bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) t := &http2Server{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -201,7 +200,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err fc: &inFlow{limit: uint32(icwz)}, sendQuotaPool: newQuotaPool(defaultWindowSize), state: reachable, - shutdownChan: make(chan struct{}), activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, stats: config.StatsHandler, @@ -225,7 +223,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err t.stats.HandleConn(t.ctx, connBegin) } t.framer.writer.Flush() - go loopyWriter(t.controlBuf, t.shutdownChan, t.itemHandler) + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() go t.keepalive() return t, nil } @@ -687,7 +688,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -744,7 +745,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { select { - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -816,7 +817,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) - case <-t.shutdownChan: + case <-t.ctx.Done(): return ErrConnClosing default: } @@ -846,12 +847,12 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( size := http2MaxFrameLen // Wait until the stream has some quota to send the data. quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() - sq, err := wait(s.ctx, nil, nil, t.shutdownChan, quotaChan) + sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan) if err != nil { return err } // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire()) if err != nil { return err } @@ -871,7 +872,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) ( t.sendQuotaPool.add(tq - ps) } // Acquire local send quota to be able to write to the controlBuf. - ltq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.localSendQuota.acquire()) + ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire()) if err != nil { if _, ok := err.(ConnectionError); !ok { t.sendQuotaPool.add(ps) @@ -960,7 +961,7 @@ func (t *http2Server) keepalive() { t.Close() // Reseting the timer so that the clean-up doesn't deadlock. maxAge.Reset(infinity) - case <-t.shutdownChan: + case <-t.ctx.Done(): } return case <-keepalive.C: @@ -978,7 +979,7 @@ func (t *http2Server) keepalive() { pingSent = true t.controlBuf.put(p) keepalive.Reset(t.kp.Timeout) - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } @@ -990,19 +991,13 @@ var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} // is duplicated between the client and the server. // The transport layer needs to be refactored to take care of this. func (t *http2Server) itemHandler(i item) error { - var err error - defer func() { - if err != nil { - t.Close() - errorf("transport: Error while writing: %v", err) - } - }() switch i := i.(type) { case *dataFrame: - err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) - if err == nil { - i.f() + if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { + return err } + i.f() + return nil case *headerFrame: t.hBuf.Reset() for _, f := range i.hf { @@ -1017,6 +1012,7 @@ func (t *http2Server) itemHandler(i item) error { } else { endHeaders = true } + var err error if first { first = false err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ @@ -1037,17 +1033,17 @@ func (t *http2Server) itemHandler(i item) error { } } atomic.StoreUint32(&t.resetPingStrikes, 1) + return nil case *windowUpdate: - err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: if i.ack { t.applySettings(i.ss) - err = t.framer.fr.WriteSettingsAck() - } else { - err = t.framer.fr.WriteSettings(i.ss...) + return t.framer.fr.WriteSettingsAck() } + return t.framer.fr.WriteSettings(i.ss...) case *resetStream: - err = t.framer.fr.WriteRSTStream(i.streamID, i.code) + return t.framer.fr.WriteRSTStream(i.streamID, i.code) case *goAway: t.mu.Lock() if t.state == closing { @@ -1060,15 +1056,13 @@ func (t *http2Server) itemHandler(i item) error { // Stop accepting more streams now. t.state = draining t.mu.Unlock() - err = t.framer.fr.WriteGoAway(sid, i.code, i.debugData) - if err != nil { + if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { return err } if i.closeConn { - // Abruptly close the connection following the GoAway. - // But flush out what's inside the buffer first. + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. t.framer.writer.Flush() - t.Close() return fmt.Errorf("transport: Connection closing") } return nil @@ -1080,36 +1074,42 @@ func (t *http2Server) itemHandler(i item) error { // originated before the GoAway reaches the client. // After getting the ack or timer expiration send out another GoAway this // time with an ID of the max stream server intends to process. - err = t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}) - err = t.framer.fr.WritePing(false, goAwayPing.data) + if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return err + } go func() { timer := time.NewTimer(time.Minute) defer timer.Stop() select { case <-t.drainChan: case <-timer.C: - case <-t.shutdownChan: + case <-t.ctx.Done(): return } t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) }() + return nil case *flushIO: - err = t.framer.writer.Flush() + return t.framer.writer.Flush() case *ping: if !i.ack { t.bdpEst.timesnap(i.data) } - err = t.framer.fr.WritePing(i.ack, i.data) + return t.framer.fr.WritePing(i.ack, i.data) default: - errorf("transport: http2Server.controller got unexpected item type %v\n", i) + err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t\n", i) + errorf("%v", err) + return err } - return err } // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. -func (t *http2Server) Close() (err error) { +func (t *http2Server) Close() error { t.mu.Lock() if t.state == closing { t.mu.Unlock() @@ -1119,8 +1119,8 @@ func (t *http2Server) Close() (err error) { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - close(t.shutdownChan) - err = t.conn.Close() + t.cancel() + err := t.conn.Close() // Cancel all active streams. for _, s := range streams { s.cancel() @@ -1129,7 +1129,7 @@ func (t *http2Server) Close() (err error) { connEnd := &stats.ConnEnd{} t.stats.HandleConn(t.ctx, connEnd) } - return + return err } // closeStream clears the footprint of a stream when the stream is not needed diff --git a/transport/transport.go b/transport/transport.go index f8d8faed..e4a35403 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -25,6 +25,7 @@ import ( "io" "net" "sync" + "time" "golang.org/x/net/context" "golang.org/x/net/http2" @@ -457,7 +458,6 @@ type transportState int const ( reachable transportState = iota - unreachable closing draining ) @@ -519,8 +519,8 @@ type TargetInfo struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { - return newHTTP2Client(ctx, target, opts) +func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) { + return newHTTP2Client(ctx, target, opts, timeout) } // Options provides additional hints and information for message @@ -702,14 +702,8 @@ func (e StreamError) Error() string { return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) } -// wait blocks until it can receive from ctx.Done, closing, or proceed. -// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. -// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise -// it return the StreamError for ctx.Err. -// If it receives from goAway, it returns 0, ErrStreamDrain. -// If it receives from closing, it returns 0, ErrConnClosing. -// If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { +// wait blocks until it can receive from one of the provided contexts or channels +func wait(ctx, tctx context.Context, done, goAway <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) @@ -717,7 +711,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <- return 0, io.EOF case <-goAway: return 0, ErrStreamDrain - case <-closing: + case <-tctx.Done(): return 0, ErrConnClosing case i := <-proceed: return i, nil @@ -739,7 +733,7 @@ const ( // loopyWriter is run in a separate go routine. It is the single code path that will // write data on wire. -func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) error) { +func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) { for { select { case i := <-cbuf.get(): @@ -747,7 +741,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err if err := handler(i); err != nil { return } - case <-done: + case <-ctx.Done(): return } hasData: @@ -758,7 +752,7 @@ func loopyWriter(cbuf *controlBuffer, done chan struct{}, handler func(item) err if err := handler(i); err != nil { return } - case <-done: + case <-ctx.Done(): return default: if err := handler(&flushIO{}); err != nil { diff --git a/transport/transport_test.go b/transport/transport_test.go index 2bb2b711..e1dd080a 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -357,7 +357,7 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, copts) + ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } @@ -380,7 +380,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } done <- conn }() - tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) + tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second) if err != nil { // Server clean-up. lis.Close() @@ -1680,7 +1680,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + serverSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err) } @@ -1702,7 +1702,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) } ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + clientSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) } @@ -1838,7 +1838,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow conrtrol window on client stream is equal to out flow on server stream. ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire()) + serverStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, sstream.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) @@ -1853,7 +1853,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on server stream is equal to out flow on client stream. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire()) + clientStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, cstream.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) @@ -1868,7 +1868,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on client transport is equal to out flow of server transport. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + serverTrSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) @@ -1883,7 +1883,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { // Check flow control window on server transport is equal to out flow of client transport. ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + clientTrSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) cancel() if err != nil { return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) @@ -2055,7 +2055,7 @@ func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream wh: wh, } server.start(t, lis) - client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}) + client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second) if err != nil { t.Fatalf("Error creating client. Err: %v", err) } diff --git a/vet.sh b/vet.sh index e478188f..4905ed3d 100755 --- a/vet.sh +++ b/vet.sh @@ -65,7 +65,7 @@ git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":' set +o pipefail # TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. # TODO: Remove clientconn exception once go1.6 support is removed. -go tool vet -all . 2>&1 | grep -vE 'clientconn.go:.*cancel' | grep -vF '.pb.go:' | tee /dev/stderr | (! read) +go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read) set -o pipefail git reset --hard HEAD