Keepalive client-side implementation
This commit is contained in:
		| @ -45,6 +45,7 @@ import ( | |||||||
| 	"golang.org/x/net/trace" | 	"golang.org/x/net/trace" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
| 	"google.golang.org/grpc/grpclog" | 	"google.golang.org/grpc/grpclog" | ||||||
|  | 	"google.golang.org/grpc/keepalive" | ||||||
| 	"google.golang.org/grpc/transport" | 	"google.golang.org/grpc/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @ -230,6 +231,13 @@ func WithUserAgent(s string) DialOption { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // WithKeepaliveParams returns a DialOption that specifies a user agent string for all the RPCs. | ||||||
|  | func WithKeepaliveParams(k keepalive.KeepaliveParams) DialOption { | ||||||
|  | 	return func(o *dialOptions) { | ||||||
|  | 		o.copts.KParams = k | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. | // WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. | ||||||
| func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { | func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { | ||||||
| 	return func(o *dialOptions) { | 	return func(o *dialOptions) { | ||||||
| @ -277,6 +285,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Set defaults. | 	// Set defaults. | ||||||
|  | 	if cc.dopts.copts.KParams == (keepalive.KeepaliveParams{}) { | ||||||
|  | 		cc.dopts.copts.KParams = keepalive.DefaultKParams | ||||||
|  | 	} | ||||||
| 	if cc.dopts.codec == nil { | 	if cc.dopts.codec == nil { | ||||||
| 		cc.dopts.codec = protoCodec{} | 		cc.dopts.codec = protoCodec{} | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										22
									
								
								keepalive/keepalive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								keepalive/keepalive.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,22 @@ | |||||||
|  | package keepalive | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type KeepaliveParams struct { | ||||||
|  | 	// After a duration of this time the client pings the server to see if the transport is still alive. | ||||||
|  | 	Ktime time.Duration | ||||||
|  | 	// After having pinged fot keepalive check, the client waits for a duration of keepalive_timeout before closing the transport. | ||||||
|  | 	Ktimeout time.Duration | ||||||
|  | 	//If true, client runs keepalive checks even with no active RPCs. | ||||||
|  | 	KNoStream bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DefaultKParams KeepaliveParams = KeepaliveParams{ | ||||||
|  | 	Ktime:     time.Duration(290 * 365 * 24 * 60 * 60 * 1000 * 1000 * 1000), // default to infinite | ||||||
|  | 	Ktimeout:  time.Duration(20 * 1000 * 1000 * 1000),                       // default to 20 seconds | ||||||
|  | 	KNoStream: false, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var Enabled = false | ||||||
| @ -49,6 +49,7 @@ import ( | |||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
| 	"google.golang.org/grpc/grpclog" | 	"google.golang.org/grpc/grpclog" | ||||||
|  | 	"google.golang.org/grpc/keepalive" | ||||||
| 	"google.golang.org/grpc/metadata" | 	"google.golang.org/grpc/metadata" | ||||||
| 	"google.golang.org/grpc/peer" | 	"google.golang.org/grpc/peer" | ||||||
| 	"google.golang.org/grpc/stats" | 	"google.golang.org/grpc/stats" | ||||||
| @ -109,6 +110,15 @@ type http2Client struct { | |||||||
| 	goAwayID uint32 | 	goAwayID uint32 | ||||||
| 	// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. | 	// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. | ||||||
| 	prevGoAwayID uint32 | 	prevGoAwayID uint32 | ||||||
|  |  | ||||||
|  | 	// lastRecv counts whenever a frame is recieved | ||||||
|  | 	lastRecv int64 | ||||||
|  |  | ||||||
|  | 	// lastSent counts whenever a frame is sent | ||||||
|  | 	lastSent int64 | ||||||
|  |  | ||||||
|  | 	// keepalive parameters | ||||||
|  | 	kParams keepalive.KeepaliveParams | ||||||
| } | } | ||||||
|  |  | ||||||
| func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { | func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { | ||||||
| @ -206,6 +216,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( | |||||||
| 		creds:           opts.PerRPCCredentials, | 		creds:           opts.PerRPCCredentials, | ||||||
| 		maxStreams:      math.MaxInt32, | 		maxStreams:      math.MaxInt32, | ||||||
| 		streamSendQuota: defaultWindowSize, | 		streamSendQuota: defaultWindowSize, | ||||||
|  | 		kParams:         opts.KParams, | ||||||
| 	} | 	} | ||||||
| 	// Start the reader goroutine for incoming message. Each transport has | 	// Start the reader goroutine for incoming message. Each transport has | ||||||
| 	// a dedicated goroutine which reads HTTP2 frame from network. Then it | 	// a dedicated goroutine which reads HTTP2 frame from network. Then it | ||||||
| @ -690,6 +701,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { | |||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	// update last send | ||||||
|  | 	t.lastSent++ | ||||||
| 	if !opts.Last { | 	if !opts.Last { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @ -830,6 +843,8 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { | |||||||
| 	pingAck := &ping{ack: true} | 	pingAck := &ping{ack: true} | ||||||
| 	copy(pingAck.data[:], f.Data[:]) | 	copy(pingAck.data[:], f.Data[:]) | ||||||
| 	t.controlBuf.put(pingAck) | 	t.controlBuf.put(pingAck) | ||||||
|  | 	// Update last sent | ||||||
|  | 	t.lastSent++ | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { | func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { | ||||||
| @ -976,6 +991,8 @@ func (t *http2Client) reader() { | |||||||
| 	// loop to keep reading incoming messages on this transport. | 	// loop to keep reading incoming messages on this transport. | ||||||
| 	for { | 	for { | ||||||
| 		frame, err := t.framer.readFrame() | 		frame, err := t.framer.readFrame() | ||||||
|  | 		// update lastRecv counter | ||||||
|  | 		t.lastRecv++ | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			// Abort an active stream if the http2.Framer returns a | 			// Abort an active stream if the http2.Framer returns a | ||||||
| 			// http2.StreamError. This can happen only if the server's response | 			// http2.StreamError. This can happen only if the server's response | ||||||
| @ -1052,6 +1069,16 @@ func (t *http2Client) applySettings(ss []http2.Setting) { | |||||||
| // controller running in a separate goroutine takes charge of sending control | // controller running in a separate goroutine takes charge of sending control | ||||||
| // frames (e.g., window update, reset stream, setting, etc.) to the server. | // frames (e.g., window update, reset stream, setting, etc.) to the server. | ||||||
| func (t *http2Client) controller() { | func (t *http2Client) controller() { | ||||||
|  | 	tRCounter := t.lastRecv | ||||||
|  | 	tSCounter := t.lastSent | ||||||
|  | 	timer := time.NewTimer(t.kParams.Ktime) | ||||||
|  | 	if !keepalive.Enabled { | ||||||
|  | 		// prevent the timer from firing, ever | ||||||
|  | 		if !timer.Stop() { | ||||||
|  | 			<-timer.C | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	isPingSent := false | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case i := <-t.controlBuf.get(): | 		case i := <-t.controlBuf.get(): | ||||||
| @ -1082,6 +1109,23 @@ func (t *http2Client) controller() { | |||||||
| 			case <-t.shutdownChan: | 			case <-t.shutdownChan: | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  | 		case <-timer.C: | ||||||
|  | 			if t.lastRecv > tRCounter || t.lastSent > tSCounter || (!t.kParams.KNoStream && len(t.activeStreams) < 1) { | ||||||
|  | 				timer.Reset(t.kParams.Ktime) | ||||||
|  | 				isPingSent = false | ||||||
|  | 			} else { | ||||||
|  | 				if !isPingSent { | ||||||
|  | 					// send ping | ||||||
|  | 					t.framer.writePing(true, false, [8]byte{}) | ||||||
|  | 					isPingSent = true | ||||||
|  | 					timer.Reset(t.kParams.Ktimeout) | ||||||
|  | 				} else { | ||||||
|  | 					t.Close() | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			tRCounter = t.lastRecv | ||||||
|  | 			tSCounter = t.lastSent | ||||||
| 		case <-t.shutdownChan: | 		case <-t.shutdownChan: | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -47,6 +47,7 @@ import ( | |||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
|  | 	"google.golang.org/grpc/keepalive" | ||||||
| 	"google.golang.org/grpc/metadata" | 	"google.golang.org/grpc/metadata" | ||||||
| 	"google.golang.org/grpc/tap" | 	"google.golang.org/grpc/tap" | ||||||
| ) | ) | ||||||
| @ -380,6 +381,8 @@ type ConnectOptions struct { | |||||||
| 	PerRPCCredentials []credentials.PerRPCCredentials | 	PerRPCCredentials []credentials.PerRPCCredentials | ||||||
| 	// TransportCredentials stores the Authenticator required to setup a client connection. | 	// TransportCredentials stores the Authenticator required to setup a client connection. | ||||||
| 	TransportCredentials credentials.TransportCredentials | 	TransportCredentials credentials.TransportCredentials | ||||||
|  | 	// Keepalive parameters | ||||||
|  | 	KParams keepalive.KeepaliveParams | ||||||
| } | } | ||||||
|  |  | ||||||
| // TargetInfo contains the information of the target such as network address and metadata. | // TargetInfo contains the information of the target such as network address and metadata. | ||||||
|  | |||||||
| @ -49,6 +49,7 @@ import ( | |||||||
| 	"golang.org/x/net/http2" | 	"golang.org/x/net/http2" | ||||||
| 	"golang.org/x/net/http2/hpack" | 	"golang.org/x/net/http2/hpack" | ||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
|  | 	"google.golang.org/grpc/keepalive" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type server struct { | type server struct { | ||||||
| @ -251,6 +252,10 @@ func (s *server) stop() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { | func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { | ||||||
|  | 	return setUpWithOptions(t, port, maxStreams, ht, ConnectOptions{}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func setUpWithOptions(t *testing.T, port int, maxStreams uint32, ht hType, copts ConnectOptions) (*server, ClientTransport) { | ||||||
| 	server := &server{startedErr: make(chan error, 1)} | 	server := &server{startedErr: make(chan error, 1)} | ||||||
| 	go server.start(t, port, maxStreams, ht) | 	go server.start(t, port, maxStreams, ht) | ||||||
| 	server.wait(t, 2*time.Second) | 	server.wait(t, 2*time.Second) | ||||||
| @ -262,13 +267,140 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client | |||||||
| 	target := TargetInfo{ | 	target := TargetInfo{ | ||||||
| 		Addr: addr, | 		Addr: addr, | ||||||
| 	} | 	} | ||||||
| 	ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{}) | 	ct, connErr = NewClientTransport(context.Background(), target, copts) | ||||||
| 	if connErr != nil { | 	if connErr != nil { | ||||||
| 		t.Fatalf("failed to create transport: %v", connErr) | 		t.Fatalf("failed to create transport: %v", connErr) | ||||||
| 	} | 	} | ||||||
| 	return server, ct | 	return server, ct | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) *http2Client { | ||||||
|  | 	lis, err := net.Listen("tcp", "localhost:0") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to listen: %v", err) | ||||||
|  | 	} | ||||||
|  | 	// launch a non responsive server | ||||||
|  | 	go func() { | ||||||
|  | 		defer lis.Close() | ||||||
|  | 		conn, err := lis.Accept() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Error at server-side while accepting: %v", err) | ||||||
|  | 			close(done) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		done <- conn | ||||||
|  | 	}() | ||||||
|  | 	tr, err := newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to dial: %v", err) | ||||||
|  | 	} | ||||||
|  | 	cT := tr.(*http2Client) | ||||||
|  | 	// Assert client transport is healthy | ||||||
|  | 	cT.mu.Lock() | ||||||
|  | 	defer cT.mu.Unlock() | ||||||
|  | 	if cT.state != reachable { | ||||||
|  | 		t.Fatalf("Client transport not healthy") | ||||||
|  | 	} | ||||||
|  | 	return cT | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestKeepaliveClientClosesIdleTransport(t *testing.T) { | ||||||
|  | 	keepalive.Enabled = true | ||||||
|  | 	done := make(chan net.Conn, 1) | ||||||
|  | 	cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ | ||||||
|  | 		Ktime:     2 * 1000 * 1000 * 1000, // keepalive time = 2 sec | ||||||
|  | 		Ktimeout:  1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec | ||||||
|  | 		KNoStream: true,                   // run keepalive even with no RPCs | ||||||
|  | 	}}, done) | ||||||
|  | 	defer cT.Close() | ||||||
|  | 	conn, ok := <-done | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("Server didn't return connection object") | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  | 	// Sleep for keepalive to close the connection | ||||||
|  | 	time.Sleep(4 * time.Second) | ||||||
|  | 	// Assert that the connection was closed | ||||||
|  | 	cT.mu.Lock() | ||||||
|  | 	defer cT.mu.Unlock() | ||||||
|  | 	if cT.state == reachable { | ||||||
|  | 		t.Fatalf("Test Failed: Expected client transport to have closed.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { | ||||||
|  | 	keepalive.Enabled = true | ||||||
|  | 	done := make(chan net.Conn, 1) | ||||||
|  | 	cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ | ||||||
|  | 		Ktime:     2 * 1000 * 1000 * 1000, // keepalive time = 2 sec | ||||||
|  | 		Ktimeout:  1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec | ||||||
|  | 		KNoStream: false,                  // don't run keepalive even with no RPCs | ||||||
|  | 	}}, done) | ||||||
|  | 	defer cT.Close() | ||||||
|  | 	conn, ok := <-done | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("server didn't reutrn connection object") | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  | 	// Give keepalive some time | ||||||
|  | 	time.Sleep(4 * time.Second) | ||||||
|  | 	// Assert that connections is still healthy | ||||||
|  | 	cT.mu.Lock() | ||||||
|  | 	defer cT.mu.Unlock() | ||||||
|  | 	if cT.state != reachable { | ||||||
|  | 		t.Fatalf("Test failed: Expected client transport to be healthy.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { | ||||||
|  | 	keepalive.Enabled = true | ||||||
|  | 	done := make(chan net.Conn, 1) | ||||||
|  | 	cT := setUpWithNoPingServer(t, ConnectOptions{KParams: keepalive.KeepaliveParams{ | ||||||
|  | 		Ktime:     2 * 1000 * 1000 * 1000, // keepalive time = 2 sec | ||||||
|  | 		Ktimeout:  1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec | ||||||
|  | 		KNoStream: false,                  // don't run keepalive even with no RPCs | ||||||
|  | 	}}, done) | ||||||
|  | 	defer cT.Close() | ||||||
|  | 	conn, ok := <-done | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("Server didn't return connection object") | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  | 	// create a stream | ||||||
|  | 	_, err := cT.NewStream(context.Background(), &CallHdr{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to create a new stream: %v", err) | ||||||
|  | 	} | ||||||
|  | 	// Give keepalive some time | ||||||
|  | 	time.Sleep(4 * time.Second) | ||||||
|  | 	// Asser that transport was closed | ||||||
|  | 	cT.mu.Lock() | ||||||
|  | 	defer cT.mu.Unlock() | ||||||
|  | 	if cT.state == reachable { | ||||||
|  | 		t.Fatalf("Test failed: Expected client transport to have closed.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { | ||||||
|  | 	keepalive.Enabled = true | ||||||
|  | 	s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KParams: keepalive.KeepaliveParams{ | ||||||
|  | 		Ktime:     2 * 1000 * 1000 * 1000, // keepalive time = 2 sec | ||||||
|  | 		Ktimeout:  1 * 1000 * 1000 * 1000, // keepalive timeout = 1 sec | ||||||
|  | 		KNoStream: true,                   // don't run keepalive even with no RPCs | ||||||
|  | 	}}) | ||||||
|  | 	defer s.stop() | ||||||
|  | 	defer tr.Close() | ||||||
|  | 	// Give keep alive some time | ||||||
|  | 	time.Sleep(4 * time.Second) | ||||||
|  | 	// Assert that transport is healthy | ||||||
|  | 	cT := tr.(*http2Client) | ||||||
|  | 	cT.mu.Lock() | ||||||
|  | 	defer cT.mu.Unlock() | ||||||
|  | 	if cT.state != reachable { | ||||||
|  | 		t.Fatalf("Test failed: Expected client transport to be healthy.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestClientSendAndReceive(t *testing.T) { | func TestClientSendAndReceive(t *testing.T) { | ||||||
| 	server, ct := setUp(t, 0, math.MaxUint32, normal) | 	server, ct := setUp(t, 0, math.MaxUint32, normal) | ||||||
| 	callHdr := &CallHdr{ | 	callHdr := &CallHdr{ | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Mahak Mukhi
					Mahak Mukhi