diff --git a/clientconn.go b/clientconn.go index f6dab4b7..99f2c37e 100644 --- a/clientconn.go +++ b/clientconn.go @@ -45,6 +45,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/transport" ) @@ -230,6 +231,13 @@ func WithUserAgent(s string) DialOption { } } +// WithKeepaliveParams returns a DialOption that specifies the keepalive parameters for the client transport. +func WithKeepaliveParams(k keepalive.KeepaliveParams) DialOption { + return func(o *dialOptions) { + o.copts.KParams = k + } +} + // WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { return func(o *dialOptions) { @@ -277,6 +285,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } // Set defaults. + if cc.dopts.copts.KParams == (keepalive.KeepaliveParams{}) { + cc.dopts.copts.KParams = keepalive.DefaultKParams + } if cc.dopts.codec == nil { cc.dopts.codec = protoCodec{} } diff --git a/keepalive/keepalive.go b/keepalive/keepalive.go new file mode 100644 index 00000000..284b6eac --- /dev/null +++ b/keepalive/keepalive.go @@ -0,0 +1,22 @@ +package keepalive + +import ( + "time" +) + +type KeepaliveParams struct { + // After a duration of this time without activity on the transport, the client pings the server to see if the transport is still alive. + Ktime time.Duration + // After having pinged for the keepalive check, the client waits for a duration of Ktimeout before closing the transport. + Ktimeout time.Duration + // If true, client runs keepalive checks even with no active RPCs. 
+ KNoStream bool +} + +var DefaultKParams KeepaliveParams = KeepaliveParams{ + Ktime: time.Duration(290 * 365 * 24 * 60 * 60 * 1000 * 1000 * 1000), // default to about 290 years, i.e. effectively infinite + Ktimeout: time.Duration(20 * 1000 * 1000 * 1000), // default to 20 seconds + KNoStream: false, +} + +var Enabled = false diff --git a/transport/http2_client.go b/transport/http2_client.go index cbd9f326..0e81050d 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -49,6 +49,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -109,6 +110,15 @@ type http2Client struct { goAwayID uint32 // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. prevGoAwayID uint32 + + // lastRecv is incremented by reader each time readFrame returns; controller compares it against a snapshot to detect receive activity. + lastRecv int64 + + // lastSent is incremented on each Write call and each queued ping ack; controller compares it against a snapshot to detect send activity. + lastSent int64 + + // kParams holds the keepalive parameters copied from ConnectOptions.KParams. + kParams keepalive.KeepaliveParams } func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { @@ -206,6 +216,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( creds: opts.PerRPCCredentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, + kParams: opts.KParams, } // Start the reader goroutine for incoming message. Each transport has // a dedicated goroutine which reads HTTP2 frame from network. 
Then it @@ -690,6 +701,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { break } } + // Record send activity for the keepalive check. NOTE(review): plain increment, while controller reads lastSent from its own goroutine; confirm whether this needs sync/atomic. + t.lastSent++ if !opts.Last { return nil } @@ -830,6 +843,8 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { pingAck := &ping{ack: true} copy(pingAck.data[:], f.Data[:]) t.controlBuf.put(pingAck) + // Count the queued ping ack as send activity for the keepalive check. + t.lastSent++ } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { @@ -976,6 +991,8 @@ func (t *http2Client) reader() { // loop to keep reading incoming messages on this transport. for { frame, err := t.framer.readFrame() + // Record receive activity for the keepalive check; this runs before the error check, so a failed read is counted too. + t.lastRecv++ if err != nil { // Abort an active stream if the http2.Framer returns a // http2.StreamError. This can happen only if the server's response @@ -1052,6 +1069,16 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // controller running in a separate goroutine takes charge of sending control // frames (e.g., window update, reset stream, setting, etc.) to the server. 
func (t *http2Client) controller() { + tRCounter := t.lastRecv + tSCounter := t.lastSent + timer := time.NewTimer(t.kParams.Ktime) + if !keepalive.Enabled { + // Keepalive is disabled: stop the timer, draining its channel if it has already fired, so the timer case below never runs. + if !timer.Stop() { + <-timer.C + } + } + isPingSent := false for { select { case i := <-t.controlBuf.get(): @@ -1082,6 +1109,23 @@ func (t *http2Client) controller() { case <-t.shutdownChan: return } + case <-timer.C: + if t.lastRecv > tRCounter || t.lastSent > tSCounter || (!t.kParams.KNoStream && len(t.activeStreams) < 1) { + timer.Reset(t.kParams.Ktime) + isPingSent = false + } else { + if !isPingSent { + // No activity since the last snapshot: send a keepalive ping and allow Ktimeout for activity before closing the transport. + t.framer.writePing(true, false, [8]byte{}) + isPingSent = true + timer.Reset(t.kParams.Ktimeout) + } else { + t.Close() + continue + } + } + tRCounter = t.lastRecv + tSCounter = t.lastSent case <-t.shutdownChan: return } diff --git a/transport/transport.go b/transport/transport.go index 4726bb2c..3648ca14 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -47,6 +47,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/tap" ) @@ -380,6 +381,8 @@ type ConnectOptions struct { PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials + // KParams holds the keepalive parameters for the client transport. + KParams keepalive.KeepaliveParams } // TargetInfo contains the information of the target such as network address and metadata. 
diff --git a/transport/transport_test.go b/transport/transport_test.go index 1ca6eb1a..dd464581 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -49,6 +49,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" ) type server struct { @@ -251,6 +252,10 @@ func (s *server) stop() { } func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { + return setUpWithOptions(t, port, maxStreams, ht, ConnectOptions{}) +} + +func setUpWithOptions(t *testing.T, port int, maxStreams uint32, ht hType, copts ConnectOptions) (*server, ClientTransport) { server := &server{startedErr: make(chan error, 1)} go server.start(t, port, maxStreams, ht) server.wait(t, 2*time.Second) @@ -262,13 +267,140 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{}) + ct, connErr = NewClientTransport(context.Background(), target, copts) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } return server, ct } +func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) *http2Client { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + // Launch a server goroutine that accepts a single connection and hands it to the caller via done without ever reading from or writing to it. 
+ go func() { + defer lis.Close() + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error at server-side while accepting: %v", err) + close(done) + return + } + done <- conn + }() + tr, err := newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + cT := tr.(*http2Client) + // Assert that the client transport is healthy + cT.mu.Lock() + defer cT.mu.Unlock() + if cT.state != reachable { + t.Fatalf("Client transport not healthy") + } + return cT +} + +func 
TestKeepaliveClientClosesIdleTransport(t *testing.T) { + keepalive.Enabled = true + done := make(chan net.Conn, 1) + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec + Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec + KNoStream: true, // run keepalive even with no RPCs + }}, done) + defer cT.Close() + conn, ok := <-done + if !ok { + t.Fatalf("Server didn't return connection object") + } + defer conn.Close() + // Sleep longer than Ktime + Ktimeout (3 sec) so keepalive has time to close the connection + time.Sleep(4 * time.Second) + // Assert that the connection was closed + cT.mu.Lock() + defer cT.mu.Unlock() + if cT.state == reachable { + t.Fatalf("Test Failed: Expected client transport to have closed.") + } +} + +func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { + keepalive.Enabled = true + done := make(chan net.Conn, 1) + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec + Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec + KNoStream: false, // don't run keepalive when there are no RPCs + }}, done) + defer cT.Close() + conn, ok := <-done + if !ok { + t.Fatalf("server didn't reutrn connection object") + } + defer conn.Close() + // Give keepalive some time + time.Sleep(4 * time.Second) + // Assert that the connection is still healthy + cT.mu.Lock() + defer cT.mu.Unlock() + if cT.state != reachable { + t.Fatalf("Test failed: Expected client transport to be healthy.") + } +} + +func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { + keepalive.Enabled = true + done := make(chan net.Conn, 1) + cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ + Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec + Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec + KNoStream: false, // don't run keepalive when there are no RPCs + }}, done) + defer cT.Close() 
+ conn, ok := <-done + if !ok { + t.Fatalf("Server didn't return connection object") + } + defer conn.Close() + // Create a stream so the transport has an active RPC + _, err := cT.NewStream(context.Background(), &CallHdr{}) + if err != nil { + t.Fatalf("Failed to create a new stream: %v", err) + } + // Give keepalive some time + time.Sleep(4 * time.Second) + // Assert that the transport was closed + cT.mu.Lock() + defer cT.mu.Unlock() + if cT.state == reachable { + t.Fatalf("Test failed: Expected client transport to have closed.") + } +} + +func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { + keepalive.Enabled = true + s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KParams: keepalive.KeepaliveParams{ + Ktime: 2 * 1000 * 1000 * 1000, // keepalive time = 2 sec + Ktimeout: 1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec + KNoStream: true, // run keepalive even with no RPCs + }}) + defer s.stop() + defer tr.Close() + // Give keepalive some time + time.Sleep(4 * time.Second) + // Assert that the transport is healthy + cT := tr.(*http2Client) + cT.mu.Lock() + defer cT.mu.Unlock() + if cT.state != reachable { + t.Fatalf("Test failed: Expected client transport to be healthy.") + } +} + func TestClientSendAndReceive(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{