diff --git a/clientconn.go b/clientconn.go index cff53f73..b8e31988 100644 --- a/clientconn.go +++ b/clientconn.go @@ -45,6 +45,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -249,6 +250,13 @@ func WithUserAgent(s string) DialOption { } } +// WithKeepaliveParams returns a DialOption that specifies keepalive paramaters for the client transport. +func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption { + return func(o *dialOptions) { + o.copts.KeepaliveParams = kp + } +} + // WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { return func(o *dialOptions) { diff --git a/keepalive/keepalive.go b/keepalive/keepalive.go new file mode 100644 index 00000000..e3a4068c --- /dev/null +++ b/keepalive/keepalive.go @@ -0,0 +1,18 @@ +package keepalive + +import ( + "time" +) + +// ClientParameters is used to set keepalive parameters on the client-side. +// These configure how the client will actively probe to notice when a connection broken +// and to cause activity so intermediaries are aware the connection is still in use. +type ClientParameters struct { + // After a duration of this time if the client doesn't see any activity it pings the server to see if the transport is still alive. + Time time.Duration // The current default value is infinity. + // After having pinged for keepalive check, the client waits for a duration of Timeout and if no activity is seen even after that + // the connection is closed. + Timeout time.Duration // The current default value is 20 seconds. + // If true, client runs keepalive checks even with no active RPCs. + PermitWithoutStream bool +} diff --git a/transport/control.go b/transport/control.go index 33de7b60..0edbe53a 100644 --- a/transport/control.go +++ b/transport/control.go @@ -35,7 +35,9 @@ package transport import ( "fmt" + "math" "sync" + "time" "golang.org/x/net/http2" ) @@ -46,6 +48,9 @@ const ( // The initial window size for flow control. initialWindowSize = defaultWindowSize // for an RPC initialConnWindowSize = defaultWindowSize * 16 // for a connection + infinity = time.Duration(math.MaxInt64) + defaultKeepaliveTime = infinity + defaultKeepaliveTimeout = time.Duration(20 * time.Second) defaultMaxStreamsClient = 100 ) diff --git a/transport/http2_client.go b/transport/http2_client.go index c69f66ff..c02ee160 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -41,6 +41,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "golang.org/x/net/context" @@ -49,6 +50,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -80,6 +82,8 @@ type http2Client struct { // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} + // awakenKeepalive is used to wake up keepalive when after it has gone dormant. + awakenKeepalive chan struct{} framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -99,6 +103,11 @@ type http2Client struct { creds []credentials.PerRPCCredentials + // Boolean to keep track of reading activity on transport. + // 1 is true and 0 is false. + activity uint32 // Accessed atomically. + kp keepalive.ClientParameters + statsHandler stats.Handler mu sync.Mutex // guard the following variables @@ -182,6 +191,14 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( if opts.UserAgent != "" { ua = opts.UserAgent + " " + ua } + kp := opts.KeepaliveParams + // Validate keepalive parameters. + if kp.Time == 0 { + kp.Time = defaultKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultKeepaliveTimeout + } var buf bytes.Buffer t := &http2Client{ ctx: ctx, @@ -198,6 +215,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( shutdownChan: make(chan struct{}), errorChan: make(chan struct{}), goAway: make(chan struct{}), + awakenKeepalive: make(chan struct{}, 1), framer: newFramer(conn), hBuf: &buf, hEnc: hpack.NewEncoder(&buf), @@ -211,8 +229,12 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( maxStreams: defaultMaxStreamsClient, streamsQuota: newQuotaPool(defaultMaxStreamsClient), streamSendQuota: defaultWindowSize, + kp: kp, statsHandler: opts.StatsHandler, } + // Make sure awakenKeepalive can't be written upon. + // keepalive routine will make it writable, if need be. + t.awakenKeepalive <- struct{}{} if t.statsHandler != nil { t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, @@ -257,6 +279,9 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } } go t.controller() + if t.kp.Time != infinity { + go t.keepalive() + } t.writableChan <- 0 return t, nil } @@ -369,6 +394,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea s := t.newStream(ctx, callHdr) s.clientStatsCtx = userCtx t.activeStreams[s.id] = s + // If the number of active streams change from 0 to 1, then check if keepalive + // has gone dormant. If so, wake it up. + if len(t.activeStreams) == 1 { + select { + case t.awakenKeepalive <- struct{}{}: + t.framer.writePing(false, false, [8]byte{}) + default: + } + } t.mu.Unlock() @@ -992,6 +1026,7 @@ func (t *http2Client) reader() { t.notifyError(err) return } + atomic.CompareAndSwapUint32(&t.activity, 0, 1) sf, ok := frame.(*http2.SettingsFrame) if !ok { t.notifyError(err) @@ -1002,6 +1037,7 @@ func (t *http2Client) reader() { // loop to keep reading incoming messages on this transport. for { frame, err := t.framer.readFrame() + atomic.CompareAndSwapUint32(&t.activity, 0, 1) if err != nil { // Abort an active stream if the http2.Framer returns a // http2.StreamError. This can happen only if the server's response @@ -1114,6 +1150,61 @@ func (t *http2Client) controller() { } } +// keepalive running in a separate goroutune makes sure the connection is alive by sending pings. +func (t *http2Client) keepalive() { + p := &ping{data: [8]byte{}} + timer := time.NewTimer(t.kp.Time) + for { + select { + case <-timer.C: + if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { + timer.Reset(t.kp.Time) + continue + } + // Check if keepalive should go dormant. + t.mu.Lock() + if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream { + // Make awakenKeepalive writable. + <-t.awakenKeepalive + t.mu.Unlock() + select { + case <-t.awakenKeepalive: + // If the control gets here a ping has been sent + // need to reset the timer with keepalive.Timeout. + case <-t.shutdownChan: + return + } + } else { + t.mu.Unlock() + // Send ping. + t.controlBuf.put(p) + } + + // By the time control gets here a ping has been sent one way or the other. + timer.Reset(t.kp.Timeout) + select { + case <-timer.C: + if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { + timer.Reset(t.kp.Time) + continue + } + t.Close() + return + case <-t.shutdownChan: + if !timer.Stop() { + <-timer.C + } + return + } + case <-t.shutdownChan: + if !timer.Stop() { + <-timer.C + } + return + } + } +} + func (t *http2Client) Error() <-chan struct{} { return t.errorChan } diff --git a/transport/transport.go b/transport/transport.go index aed75d59..fed69089 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -47,6 +47,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/tap" @@ -388,6 +389,8 @@ type ConnectOptions struct { PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials + // KeepaliveParams stores the keepalive parameters. + KeepaliveParams keepalive.ClientParameters // StatsHandler stores the handler for stats. StatsHandler stats.Handler } diff --git a/transport/transport_test.go b/transport/transport_test.go index e91fc6ed..2e9674aa 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -49,6 +49,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" ) type server struct { @@ -251,6 +252,10 @@ func (s *server) stop() { } func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { + return setUpWithOptions(t, port, maxStreams, ht, ConnectOptions{}) +} + +func setUpWithOptions(t *testing.T, port int, maxStreams uint32, ht hType, copts ConnectOptions) (*server, ClientTransport) { server := &server{startedErr: make(chan error, 1)} go server.start(t, port, maxStreams, ht) server.wait(t, 2*time.Second) @@ -262,13 +267,134 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{}) + ct, connErr = NewClientTransport(context.Background(), target, copts) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } return server, ct } +func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) ClientTransport { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + // Launch a non responsive server. + go func() { + defer lis.Close() + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error at server-side while accepting: %v", err) + close(done) + return + } + done <- conn + }() + tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) + if err != nil { + // Server clean-up. + if conn, ok := <-done; ok { + conn.Close() + } + t.Fatalf("Failed to dial: %v", err) + } + return tr +} + +func TestKeepaliveClientClosesIdleTransport(t *testing.T) { + done := make(chan net.Conn, 1) + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. + PermitWithoutStream: true, // Run keepalive even with no RPCs. + }}, done) + defer tr.Close() + conn, ok := <-done + if !ok { + t.Fatalf("Server didn't return connection object") + } + defer conn.Close() + // Sleep for keepalive to close the connection. + time.Sleep(4 * time.Second) + // Assert that the connection was closed. + ct := tr.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state == reachable { + t.Fatalf("Test Failed: Expected client transport to have closed.") + } +} + +func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { + done := make(chan net.Conn, 1) + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. + }}, done) + defer tr.Close() + conn, ok := <-done + if !ok { + t.Fatalf("server didn't reutrn connection object") + } + defer conn.Close() + // Give keepalive some time. + time.Sleep(4 * time.Second) + // Assert that connections is still healthy. + ct := tr.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state != reachable { + t.Fatalf("Test failed: Expected client transport to be healthy.") + } +} + +func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { + done := make(chan net.Conn, 1) + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. + }}, done) + defer tr.Close() + conn, ok := <-done + if !ok { + t.Fatalf("Server didn't return connection object") + } + defer conn.Close() + // Create a stream. + _, err := tr.NewStream(context.Background(), &CallHdr{}) + if err != nil { + t.Fatalf("Failed to create a new stream: %v", err) + } + // Give keepalive some time. + time.Sleep(4 * time.Second) + // Assert that transport was closed. + ct := tr.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state == reachable { + t.Fatalf("Test failed: Expected client transport to have closed.") + } +} + +func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { + s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. + PermitWithoutStream: true, // Run keepalive even with no RPCs. + }}) + defer s.stop() + defer tr.Close() + // Give keep alive some time. + time.Sleep(4 * time.Second) + // Assert that transport is healthy. + ct := tr.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state != reachable { + t.Fatalf("Test failed: Expected client transport to be healthy.") + } +} + func TestClientSendAndReceive(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{