diff --git a/clientconn.go b/clientconn.go index 99f2c37e..b45b699f 100644 --- a/clientconn.go +++ b/clientconn.go @@ -232,7 +232,7 @@ func WithUserAgent(s string) DialOption { } // WithKeepaliveParams returns a DialOption that specifies a user agent string for all the RPCs. -func WithKeepaliveParams(k keepalive.KeepaliveParams) DialOption { +func WithKeepaliveParams(k keepalive.Params) DialOption { return func(o *dialOptions) { o.copts.KParams = k } @@ -285,9 +285,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } // Set defaults. - if cc.dopts.copts.KParams == (keepalive.KeepaliveParams{}) { - cc.dopts.copts.KParams = keepalive.DefaultKParams - } if cc.dopts.codec == nil { cc.dopts.codec = protoCodec{} } diff --git a/keepalive/keepalive.go b/keepalive/keepalive.go index 284b6eac..686315ee 100644 --- a/keepalive/keepalive.go +++ b/keepalive/keepalive.go @@ -1,10 +1,12 @@ package keepalive import ( + "math" + "sync" "time" ) -type KeepaliveParams struct { +type Params struct { // After a duration of this time the client pings the server to see if the transport is still alive. Ktime time.Duration // After having pinged fot keepalive check, the client waits for a duration of keepalive_timeout before closing the transport. @@ -13,10 +15,13 @@ type KeepaliveParams struct { KNoStream bool } -var DefaultKParams KeepaliveParams = KeepaliveParams{ - Ktime: time.Duration(290 * 365 * 24 * 60 * 60 * 1000 * 1000 * 1000), // default to infinite - Ktimeout: time.Duration(20 * 1000 * 1000 * 1000), // default to 20 seconds +var DefaultKParams Params = Params{ + Ktime: time.Duration(math.MaxInt64), // default to infinite + Ktimeout: time.Duration(20 * 1000 * 1000 * 1000), // default to 20 seconds KNoStream: false, } +// Mutex to protect Enabled variable +var Mu sync.Mutex = sync.Mutex{} + var Enabled = false diff --git a/transport/http2_client.go b/transport/http2_client.go index 23bb91bd..e691e4be 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -103,7 +103,7 @@ type http2Client struct { // activity counter activity *uint64 // keepalive parameters - kParams keepalive.KeepaliveParams + keepaliveParams keepalive.Params mu sync.Mutex // guard the following variables state transportState // the state of underlying connection @@ -186,6 +186,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( if opts.UserAgent != "" { ua = opts.UserAgent + " " + ua } + kp := keepalive.DefaultKParams + if opts.KParams != (keepalive.Params{}) { + kp = opts.KParams + } var buf bytes.Buffer t := &http2Client{ target: addr.Addr, @@ -213,7 +217,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( creds: opts.PerRPCCredentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, - kParams: opts.KParams, + keepaliveParams: kp, activity: new(uint64), } // Start the reader goroutine for incoming message. Each transport has @@ -1069,15 +1073,17 @@ func (t *http2Client) applySettings(ss []http2.Setting) { func (t *http2Client) controller() { // Activity value seen by timer ta := atomic.LoadUint64(t.activity) - timer := time.NewTimer(t.kParams.Ktime) + timer := time.NewTimer(t.keepaliveParams.Ktime) + keepalive.Mu.Lock() if !keepalive.Enabled { // Prevent the timer from firing, ever. if !timer.Stop() { <-timer.C } } + keepalive.Mu.Unlock() isPingSent := false - kPing := &ping{data: [8]byte{}} + keepalivePing := &ping{data: [8]byte{}} for { select { case i := <-t.controlBuf.get(): @@ -1114,15 +1120,15 @@ func (t *http2Client) controller() { t.mu.Unlock() // Global activity value. ga := atomic.LoadUint64(t.activity) - if ga > ta || (!t.kParams.KNoStream && ns < 1) { - timer.Reset(t.kParams.Ktime) + if ga > ta || (!t.keepaliveParams.KNoStream && ns < 1) { + timer.Reset(t.keepaliveParams.Ktime) isPingSent = false } else { if !isPingSent { // send ping - t.controlBuf.put(kPing) + t.controlBuf.put(keepalivePing) isPingSent = true - timer.Reset(t.kParams.Ktimeout) + timer.Reset(t.keepaliveParams.Ktimeout) } else { t.Close() continue diff --git a/transport/transport.go b/transport/transport.go index 3648ca14..20dc914e 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -382,7 +382,7 @@ type ConnectOptions struct { // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials // Keepalive parameters - KParams keepalive.KeepaliveParams + KParams keepalive.Params } // TargetInfo contains the information of the target such as network address and metadata. diff --git a/transport/transport_test.go b/transport/transport_test.go index dd464581..6e8be33b 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -305,9 +305,11 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } func TestKeepaliveClientClosesIdleTransport(t *testing.T) { + keepalive.Mu.Lock() keepalive.Enabled = true + keepalive.Mu.Unlock() done := make(chan net.Conn, 1) - cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.Params{ Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec KNoStream: true, // run keepalive even with no RPCs @@ -329,9 +331,11 @@ func TestKeepaliveClientClosesIdleTransport(t *testing.T) { } func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { + keepalive.Mu.Lock() keepalive.Enabled = true + keepalive.Mu.Unlock() done := make(chan net.Conn, 1) - cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.Params{ Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec KNoStream: false, // don't run keepalive even with no RPCs @@ -353,9 +357,11 @@ func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { } func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { + keepalive.Mu.Lock() keepalive.Enabled = true + keepalive.Mu.Unlock() done := make(chan net.Conn, 1) - cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.Params{ Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec KNoStream: false, // don't run keepalive even with no RPCs @@ -382,8 +388,10 @@ func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { } func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { + keepalive.Mu.Lock() keepalive.Enabled = true - s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KParams: keepalive.KeepaliveParams{ + keepalive.Mu.Unlock() + s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KParams: keepalive.Params{ Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec KNoStream: true, // don't run keepalive even with no RPCs