diff --git a/call.go b/call.go index d6326ea0..27cf6411 100644 --- a/call.go +++ b/call.go @@ -179,7 +179,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + // Retry a non-failfast RPC when + // i) there is a connection error; or + // ii) the server started to drain before this RPC was initiated. + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } diff --git a/clientconn.go b/clientconn.go index 33edfa31..3206d674 100644 --- a/clientconn.go +++ b/clientconn.go @@ -697,14 +697,21 @@ func (ac *addrConn) tearDown(err error) { } ac.cc.mu.Unlock() }() - if ac.state == Shutdown { - return - } - ac.state = Shutdown if ac.down != nil { ac.down(downErrorf(false, false, "%v", err)) ac.down = nil } + if err == errConnDrain && ac.transport != nil { + // GracefulClose(...) may be executed multiple times when + // i) receiving multiple GoAway frames from the server; or + // ii) there are concurrent name resolver/Balancer triggered + // address removal and GoAway. + ac.transport.GracefulClose() + } + if ac.state == Shutdown { + return + } + ac.state = Shutdown ac.stateCV.Broadcast() if ac.events != nil { ac.events.Finish() @@ -714,12 +721,8 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil { - if err == errConnDrain { - ac.transport.GracefulClose() - } else { - ac.transport.Close() - } + if ac.transport != nil && err != errConnDrain { + ac.transport.Close() } if ac.shutdownChan != nil { close(ac.shutdownChan) diff --git a/server.go b/server.go index 2ad331e1..1a250c79 100644 --- a/server.go +++ b/server.go @@ -795,12 +795,16 @@ func (s *Server) Stop() { // connections and RPCs and blocks until all the pending RPCs are finished. func (s *Server) GracefulStop() { s.mu.Lock() + if s.drain == true || s.conns == nil { + s.mu.Lock() + return + } s.drain = true for lis := range s.lis { lis.Close() } for c := range s.conns { - c.(transport.ServerTransport).GoAway() + c.(transport.ServerTransport).Drain() } for len(s.conns) != 0 { s.cv.Wait() diff --git a/test/end2end_test.go b/test/end2end_test.go index a3f022a1..2cc4ccfc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2114,6 +2114,7 @@ func interestingGoroutines() (gs []string) { if stack == "" || strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "testing.tRunner(") || strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") || strings.Contains(stack, "created by google3/base/go/log.init") || diff --git a/transport/handler_server.go b/transport/handler_server.go index 35ccf627..3d7b15e5 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -370,8 +370,8 @@ func (ht *serverHandlerTransport) runStream() { } } -func (ht *serverHandlerTransport) GoAway() { - panic("not implemented") +func (ht *serverHandlerTransport) Drain() { + panic("Drain() is not implemented") } // mapRecvMsgError returns the non-nil err into the appropriate diff --git a/transport/http2_client.go b/transport/http2_client.go index 20eb4bab..ffde50ae 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -102,6 +102,8 @@ type http2Client struct { streamSendQuota uint32 // goAwayID records the Last-Stream-ID in the GoAway frame from the server. goAwayID uint32 + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -483,21 +485,25 @@ func (t *http2Client) GracefulClose() error { t.mu.Unlock() return nil } - if t.state == draining { - t.mu.Unlock() - return nil - } - t.state = draining // Notify the streams which were initiated after the server sent GOAWAY. select { case <-t.goAway: - for i := t.goAwayID + 2; i < t.nextID; i += 2 { + n := t.prevGoAwayID + if n == 0 && t.nextID > 1 { + n = t.nextID - 2 + } + for i := t.goAwayID + 2; i <= n; i += 2 { if s, ok := t.activeStreams[i]; ok { close(s.goAway) } } default: } + if t.state == draining { + t.mu.Unlock() + return nil + } + t.state = draining active := len(t.activeStreams) t.mu.Unlock() if active == 0 { @@ -742,9 +748,21 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - if t.state == reachable { + if t.state == reachable || t.state == draining { + if t.goAwayID > 0 && t.goAwayID < f.LastStreamID { + id := t.goAwayID + t.mu.Unlock() + t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + return + } + t.prevGoAwayID = t.goAwayID t.goAwayID = f.LastStreamID - close(t.goAway) + select { + case <-t.goAway: + // t.goAway has been closed (i.e.,multiple GoAways). + default: + close(t.goAway) + } } t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index f0d1f7f6..42f5d9c4 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -755,6 +755,6 @@ func (t *http2Server) RemoteAddr() net.Addr { return t.conn.RemoteAddr() } -func (t *http2Server) GoAway() { +func (t *http2Server) Drain() { t.controlBuf.put(&goAway{}) } diff --git a/transport/transport.go b/transport/transport.go index f94731ed..9dade654 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -473,8 +473,8 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. RemoteAddr() net.Addr - // GoAway notifies the client this ServerTransport stops accepting new RPCs. - GoAway() + // Drain notifies the client this ServerTransport stops accepting new RPCs. + Drain() } // StreamErrorf creates an StreamError with the specified error code and description.