Keepalive client-side implementation

This commit is contained in:
Mahak Mukhi
2016-11-17 17:50:52 -08:00
parent b13ef79499
commit e58450b5d3
5 changed files with 213 additions and 1 deletions

View File

@ -45,6 +45,7 @@ import (
"golang.org/x/net/trace"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/transport"
)
@ -230,6 +231,13 @@ func WithUserAgent(s string) DialOption {
}
}
// WithKeepaliveParams returns a DialOption that specifies a user agent string for all the RPCs.
func WithKeepaliveParams(k keepalive.KeepaliveParams) DialOption {
return func(o *dialOptions) {
o.copts.KParams = k
}
}
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
@ -277,6 +285,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
}
// Set defaults.
if cc.dopts.copts.KParams == (keepalive.KeepaliveParams{}) {
cc.dopts.copts.KParams = keepalive.DefaultKParams
}
if cc.dopts.codec == nil {
cc.dopts.codec = protoCodec{}
}

22
keepalive/keepalive.go Normal file
View File

@ -0,0 +1,22 @@
package keepalive
import (
"time"
)
type KeepaliveParams struct {
// After a duration of this time the client pings the server to see if the transport is still alive.
Ktime time.Duration
// After having pinged fot keepalive check, the client waits for a duration of keepalive_timeout before closing the transport.
Ktimeout time.Duration
//If true, client runs keepalive checks even with no active RPCs.
KNoStream bool
}
var DefaultKParams KeepaliveParams = KeepaliveParams{
Ktime: time.Duration(290 * 365 * 24 * 60 * 60 * 1000 * 1000 * 1000), // default to infinite
Ktimeout: time.Duration(20 * 1000 * 1000 * 1000), // default to 20 seconds
KNoStream: false,
}
var Enabled = false

View File

@ -49,6 +49,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@ -109,6 +110,15 @@ type http2Client struct {
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
// lastRecv counts whenever a frame is recieved
lastRecv int64
// lastSent counts whenever a frame is sent
lastSent int64
// keepalive parameters
kParams keepalive.KeepaliveParams
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
@ -206,6 +216,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
creds: opts.PerRPCCredentials,
maxStreams: math.MaxInt32,
streamSendQuota: defaultWindowSize,
kParams: opts.KParams,
}
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
@ -690,6 +701,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
break
}
}
// update last send
t.lastSent++
if !opts.Last {
return nil
}
@ -830,6 +843,8 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
pingAck := &ping{ack: true}
copy(pingAck.data[:], f.Data[:])
t.controlBuf.put(pingAck)
// Update last sent
t.lastSent++
}
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
@ -976,6 +991,8 @@ func (t *http2Client) reader() {
// loop to keep reading incoming messages on this transport.
for {
frame, err := t.framer.readFrame()
// update lastRecv counter
t.lastRecv++
if err != nil {
// Abort an active stream if the http2.Framer returns a
// http2.StreamError. This can happen only if the server's response
@ -1052,6 +1069,16 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
// controller running in a separate goroutine takes charge of sending control
// frames (e.g., window update, reset stream, setting, etc.) to the server.
func (t *http2Client) controller() {
tRCounter := t.lastRecv
tSCounter := t.lastSent
timer := time.NewTimer(t.kParams.Ktime)
if !keepalive.Enabled {
// prevent the timer from firing, ever
if !timer.Stop() {
<-timer.C
}
}
isPingSent := false
for {
select {
case i := <-t.controlBuf.get():
@ -1082,6 +1109,23 @@ func (t *http2Client) controller() {
case <-t.shutdownChan:
return
}
case <-timer.C:
if t.lastRecv > tRCounter || t.lastSent > tSCounter || (!t.kParams.KNoStream && len(t.activeStreams) < 1) {
timer.Reset(t.kParams.Ktime)
isPingSent = false
} else {
if !isPingSent {
// send ping
t.framer.writePing(true, false, [8]byte{})
isPingSent = true
timer.Reset(t.kParams.Ktimeout)
} else {
t.Close()
continue
}
}
tRCounter = t.lastRecv
tSCounter = t.lastSent
case <-t.shutdownChan:
return
}

View File

@ -47,6 +47,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/tap"
)
@ -380,6 +381,8 @@ type ConnectOptions struct {
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials
// Keepalive parameters
KParams keepalive.KeepaliveParams
}
// TargetInfo contains the information of the target such as network address and metadata.

View File

@ -49,6 +49,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
)
type server struct {
@ -251,6 +252,10 @@ func (s *server) stop() {
}
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
return setUpWithOptions(t, port, maxStreams, ht, ConnectOptions{})
}
func setUpWithOptions(t *testing.T, port int, maxStreams uint32, ht hType, copts ConnectOptions) (*server, ClientTransport) {
server := &server{startedErr: make(chan error, 1)}
go server.start(t, port, maxStreams, ht)
server.wait(t, 2*time.Second)
@ -262,13 +267,140 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
target := TargetInfo{
Addr: addr,
}
ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{})
ct, connErr = NewClientTransport(context.Background(), target, copts)
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}
return server, ct
}
func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) *http2Client {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
// launch a non responsive server
go func() {
defer lis.Close()
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error at server-side while accepting: %v", err)
close(done)
return
}
done <- conn
}()
tr, err := newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts)
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
cT := tr.(*http2Client)
// Assert client transport is healthy
cT.mu.Lock()
defer cT.mu.Unlock()
if cT.state != reachable {
t.Fatalf("Client transport not healthy")
}
return cT
}
func TestKeepaliveClientClosesIdleTransport(t *testing.T) {
keepalive.Enabled = true
done := make(chan net.Conn, 1)
cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{
Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec
Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec
KNoStream: true, // run keepalive even with no RPCs
}}, done)
defer cT.Close()
conn, ok := <-done
if !ok {
t.Fatalf("Server didn't return connection object")
}
defer conn.Close()
// Sleep for keepalive to close the connection
time.Sleep(4 * time.Second)
// Assert that the connection was closed
cT.mu.Lock()
defer cT.mu.Unlock()
if cT.state == reachable {
t.Fatalf("Test Failed: Expected client transport to have closed.")
}
}
func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) {
keepalive.Enabled = true
done := make(chan net.Conn, 1)
cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{
Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec
Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec
KNoStream: false, // don't run keepalive even with no RPCs
}}, done)
defer cT.Close()
conn, ok := <-done
if !ok {
t.Fatalf("server didn't reutrn connection object")
}
defer conn.Close()
// Give keepalive some time
time.Sleep(4 * time.Second)
// Assert that connections is still healthy
cT.mu.Lock()
defer cT.mu.Unlock()
if cT.state != reachable {
t.Fatalf("Test failed: Expected client transport to be healthy.")
}
}
func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) {
keepalive.Enabled = true
done := make(chan net.Conn, 1)
cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{
Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec
Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec
KNoStream: false, // don't run keepalive even with no RPCs
}}, done)
defer cT.Close()
conn, ok := <-done
if !ok {
t.Fatalf("Server didn't return connection object")
}
defer conn.Close()
// create a stream
_, err := cT.NewStream(context.Background(), &CallHdr{})
if err != nil {
t.Fatalf("Failed to create a new stream: %v", err)
}
// Give keepalive some time
time.Sleep(4 * time.Second)
// Asser that transport was closed
cT.mu.Lock()
defer cT.mu.Unlock()
if cT.state == reachable {
t.Fatalf("Test failed: Expected client transport to have closed.")
}
}
func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) {
keepalive.Enabled = true
s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KParams: keepalive.KeepaliveParams{
Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec
Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec
KNoStream: true, // don't run keepalive even with no RPCs
}})
defer s.stop()
defer tr.Close()
// Give keep alive some time
time.Sleep(4 * time.Second)
// Assert that transport is healthy
cT := tr.(*http2Client)
cT.mu.Lock()
defer cT.mu.Unlock()
if cT.state != reachable {
t.Fatalf("Test failed: Expected client transport to be healthy.")
}
}
func TestClientSendAndReceive(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{