From d720ab346fd09f63ecd5b34fcf3696d3d345f938 Mon Sep 17 00:00:00 2001
From: Easwar Swaminathan <easwars@google.com>
Date: Tue, 19 Nov 2019 14:43:22 -0800
Subject: [PATCH] server: Keepalive pings should be sent every [Time] period
 (#3172)

This PR contains the server-side changes corresponding to the client-side
changes made in https://github.com/grpc/grpc-go/pull/3102.

Apart from the fix for the issue mentioned in
https://github.com/grpc/grpc-go/issues/2638, this PR also performs some
minor code cleanup and fixes the channelz test for the keepalive count.
---
 internal/transport/http2_client.go |  3 +-
 internal/transport/http2_server.go | 96 ++++++++++++++++--------------
 test/channelz_test.go              | 19 +++++-
 3 files changed, 69 insertions(+), 49 deletions(-)

diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index faba7bde..c26e71cc 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -47,7 +47,7 @@ import (
 
 // http2Client implements the ClientTransport interface with HTTP2.
 type http2Client struct {
-	lastRead   int64 // keep this field 64-bit aligned
+	lastRead   int64 // Keep this field 64-bit aligned. Accessed atomically.
 	ctx        context.Context
 	cancel     context.CancelFunc
 	ctxDone    <-chan struct{} // Cache the ctx.Done() chan.
@@ -1374,7 +1374,6 @@ func (t *http2Client) keepalive() {
 			// acked).
 			sleepDuration := minTime(t.kp.Time, timeoutLeft)
 			timeoutLeft -= sleepDuration
-			prevNano = lastRead
 			timer.Reset(sleepDuration)
 		case <-t.ctx.Done():
 			if !timer.Stop() {
diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go
index 3368e6aa..1c5ce276 100644
--- a/internal/transport/http2_server.go
+++ b/internal/transport/http2_server.go
@@ -64,6 +64,7 @@ var (
 
 // http2Server implements the ServerTransport interface with HTTP2.
 type http2Server struct {
+	lastRead    int64 // Keep this field 64-bit aligned. Accessed atomically.
 	ctx         context.Context
 	done        chan struct{}
 	conn        net.Conn
@@ -83,12 +84,8 @@ type http2Server struct {
 	controlBuf *controlBuffer
 	fc         *trInFlow
 	stats      stats.Handler
-	// Flag to keep track of reading activity on transport.
-	// 1 is true and 0 is false.
-	activity uint32 // Accessed atomically.
 	// Keepalive and max-age parameters for the server.
 	kp keepalive.ServerParameters
-
 	// Keepalive enforcement policy.
 	kep keepalive.EnforcementPolicy
 	// The time instance last ping was received.
@@ -277,7 +274,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
 	if err != nil {
 		return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
 	}
-	atomic.StoreUint32(&t.activity, 1)
+	atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
 	sf, ok := frame.(*http2.SettingsFrame)
 	if !ok {
 		return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame)
@@ -450,7 +447,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
 	for {
 		t.controlBuf.throttle()
 		frame, err := t.framer.fr.ReadFrame()
-		atomic.StoreUint32(&t.activity, 1)
+		atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
 		if err != nil {
 			if se, ok := err.(http2.StreamError); ok {
 				warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
@@ -937,32 +934,35 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
 // after an additional duration of keepalive.Timeout.
 func (t *http2Server) keepalive() {
 	p := &ping{}
-	var pingSent bool
-	maxIdle := time.NewTimer(t.kp.MaxConnectionIdle)
-	maxAge := time.NewTimer(t.kp.MaxConnectionAge)
-	keepalive := time.NewTimer(t.kp.Time)
-	// NOTE: All exit paths of this function should reset their
-	// respective timers. A failure to do so will cause the
-	// following clean-up to deadlock and eventually leak.
+	// True iff a ping has been sent, and no data has been received since then.
+	outstandingPing := false
+	// Amount of time remaining within which we should receive an ACK for the
+	// last sent ping.
+	kpTimeoutLeft := time.Duration(0)
+	// Records the last value of t.lastRead before we block on the timer.
+	// This is required to check for read activity since then.
+	prevNano := time.Now().UnixNano()
+	// Initialize the different timers to their default values.
+	idleTimer := time.NewTimer(t.kp.MaxConnectionIdle)
+	ageTimer := time.NewTimer(t.kp.MaxConnectionAge)
+	kpTimer := time.NewTimer(t.kp.Time)
 	defer func() {
-		if !maxIdle.Stop() {
-			<-maxIdle.C
-		}
-		if !maxAge.Stop() {
-			<-maxAge.C
-		}
-		if !keepalive.Stop() {
-			<-keepalive.C
-		}
+		// We need to drain the underlying channel in these timers after a call
+		// to Stop(), only if we are interested in resetting them. Clearly we
+		// are not interested in resetting them here.
+		idleTimer.Stop()
+		ageTimer.Stop()
+		kpTimer.Stop()
 	}()
+
 	for {
 		select {
-		case <-maxIdle.C:
+		case <-idleTimer.C:
 			t.mu.Lock()
 			idle := t.idle
 			if idle.IsZero() { // The connection is non-idle.
 				t.mu.Unlock()
-				maxIdle.Reset(t.kp.MaxConnectionIdle)
+				idleTimer.Reset(t.kp.MaxConnectionIdle)
 				continue
 			}
 			val := t.kp.MaxConnectionIdle - time.Since(idle)
@@ -971,43 +971,51 @@ func (t *http2Server) keepalive() {
 				// The connection has been idle for a duration of keepalive.MaxConnectionIdle or more.
 				// Gracefully close the connection.
 				t.drain(http2.ErrCodeNo, []byte{})
-				// Resetting the timer so that the clean-up doesn't deadlock.
-				maxIdle.Reset(infinity)
 				return
 			}
-			maxIdle.Reset(val)
-		case <-maxAge.C:
+			idleTimer.Reset(val)
+		case <-ageTimer.C:
 			t.drain(http2.ErrCodeNo, []byte{})
-			maxAge.Reset(t.kp.MaxConnectionAgeGrace)
+			ageTimer.Reset(t.kp.MaxConnectionAgeGrace)
 			select {
-			case <-maxAge.C:
+			case <-ageTimer.C:
 				// Close the connection after grace period.
 				infof("transport: closing server transport due to maximum connection age.")
 				t.Close()
-				// Resetting the timer so that the clean-up doesn't deadlock.
-				maxAge.Reset(infinity)
 			case <-t.done:
 			}
 			return
-		case <-keepalive.C:
-			if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
-				pingSent = false
-				keepalive.Reset(t.kp.Time)
+		case <-kpTimer.C:
+			lastRead := atomic.LoadInt64(&t.lastRead)
+			if lastRead > prevNano {
+				// There has been read activity since the last time we were
+				// here. Set up the timer to fire a kp.Time duration after
+				// lastRead and continue.
+				outstandingPing = false
+				kpTimer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano()))
+				prevNano = lastRead
 				continue
 			}
-			if pingSent {
+			if outstandingPing && kpTimeoutLeft <= 0 {
 				infof("transport: closing server transport due to idleness.")
 				t.Close()
-				// Resetting the timer so that the clean-up doesn't deadlock.
-				keepalive.Reset(infinity)
 				return
 			}
-			pingSent = true
-			if channelz.IsOn() {
-				atomic.AddInt64(&t.czData.kpCount, 1)
+			if !outstandingPing {
+				if channelz.IsOn() {
+					atomic.AddInt64(&t.czData.kpCount, 1)
+				}
+				t.controlBuf.put(p)
+				kpTimeoutLeft = t.kp.Timeout
+				outstandingPing = true
 			}
-			t.controlBuf.put(p)
-			keepalive.Reset(t.kp.Timeout)
+			// The amount of time to sleep here is the minimum of kp.Time and
+			// kpTimeoutLeft. This will ensure that we wait only for kp.Time
+			// before sending out the next ping (for cases where the ping is
+			// acked).
+			sleepDuration := minTime(t.kp.Time, kpTimeoutLeft)
+			kpTimeoutLeft -= sleepDuration
+			kpTimer.Reset(sleepDuration)
 		case <-t.done:
 			return
 		}
diff --git a/test/channelz_test.go b/test/channelz_test.go
index 8ce0a677..45021022 100644
--- a/test/channelz_test.go
+++ b/test/channelz_test.go
@@ -1273,11 +1273,23 @@ func (s) TestCZServerSocketMetricsKeepAlive(t *testing.T) {
 	defer czCleanupWrapper(czCleanup, t)
 	e := tcpClearRREnv
 	te := newTest(t, e)
-	te.customServerOptions = append(te.customServerOptions, grpc.KeepaliveParams(keepalive.ServerParameters{Time: time.Second, Timeout: 500 * time.Millisecond}))
+	// We set up the server keepalive parameters to send one keepalive every
+	// second, and verify that the actual number of keepalives is very close to
+	// the number of seconds elapsed in the test.  We had a bug wherein the
+	// server was sending one keepalive every [Time+Timeout] instead of every
+	// [Time] period, and since Timeout is configured to a low value here, we
+	// should be able to verify that the fix works with the above mentioned
+	// logic.
+	kpOption := grpc.KeepaliveParams(keepalive.ServerParameters{
+		Time:    time.Second,
+		Timeout: 100 * time.Millisecond,
+	})
+	te.customServerOptions = append(te.customServerOptions, kpOption)
 	te.startServer(&testServer{security: e.security})
 	defer te.tearDown()
 	cc := te.clientConn()
 	tc := testpb.NewTestServiceClient(cc)
+	start := time.Now()
 	doIdleCallToInvokeKeepAlive(tc, t)
 
 	if err := verifyResultWithDelay(func() (bool, error) {
@@ -1289,8 +1301,9 @@ func (s) TestCZServerSocketMetricsKeepAlive(t *testing.T) {
 		if len(ns) != 1 {
 			return false, fmt.Errorf("there should be one server normal socket, not %d", len(ns))
 		}
-		if ns[0].SocketData.KeepAlivesSent != 2 { // doIdleCallToInvokeKeepAlive func is set up to send 2 KeepAlives.
-			return false, fmt.Errorf("there should be 2 KeepAlives sent, not %d", ns[0].SocketData.KeepAlivesSent)
+		wantKeepalivesCount := int64(time.Since(start).Seconds()) - 1
+		if gotKeepalivesCount := ns[0].SocketData.KeepAlivesSent; gotKeepalivesCount != wantKeepalivesCount {
+			return false, fmt.Errorf("got keepalivesCount: %v, want keepalivesCount: %v", gotKeepalivesCount, wantKeepalivesCount)
 		}
 		return true, nil
 	}); err != nil {