diff --git a/keepalive/keepalive.go b/keepalive/keepalive.go index b915a2ad..d492589c 100644 --- a/keepalive/keepalive.go +++ b/keepalive/keepalive.go @@ -41,6 +41,8 @@ import ( // ClientParameters is used to set keepalive parameters on the client-side. // These configure how the client will actively probe to notice when a connection broken // and to cause activity so intermediaries are aware the connection is still in use. +// Make sure these parameters are set in coordination with the keepalive policy on the server, +// as incompatible settings can result in closing of connection. type ClientParameters struct { // After a duration of this time if the client doesn't see any activity it pings the server to see if the transport is still alive. Time time.Duration // The current default value is infinity. @@ -48,22 +50,31 @@ type ClientParameters struct { // the connection is closed. Timeout time.Duration // The current default value is 20 seconds. // If true, client runs keepalive checks even with no active RPCs. - PermitWithoutStream bool + PermitWithoutStream bool // false by default. } // ServerParameters is used to set keepalive and max-age parameters on the server-side. type ServerParameters struct { // MaxConnectionIdle is a duration for the amount of time after which an idle connection would be closed by sending a GoAway. // Idleness duration is defined since the most recent time the number of outstanding RPCs became zero or the connection establishment. - MaxConnectionIdle time.Duration + MaxConnectionIdle time.Duration // The current default value is infinity. // MaxConnectionAge is a duration for the maximum amount of time a connection may exist before it will be closed by sending a GoAway. // A random jitter of +/-10% will be added to MaxConnectionAge to spread out connection storms. - MaxConnectionAge time.Duration + MaxConnectionAge time.Duration // The current default value is infinity. // MaxConnectinoAgeGrace is an additive period after MaxConnectionAge after which the connection will be forcibly closed. - MaxConnectionAgeGrace time.Duration + MaxConnectionAgeGrace time.Duration // The current default value is infinity. // After a duration of this time if the server doesn't see any activity it pings the client to see if the transport is still alive. - Time time.Duration + Time time.Duration // The current default value is 2 hours. // After having pinged for keepalive check, the server waits for a duration of Timeout and if no activity is seen even after that // the connection is closed. - Timeout time.Duration + Timeout time.Duration // The current default value is 20 seconds. +} + +// EnforcementPolicy is used to set keepalive enforcement policy on the server-side. +// Server will close connection with a client that violates this policy. +type EnforcementPolicy struct { + // MinTime is the minimum amount of time a client should wait before sending a keepalive ping. + MinTime time.Duration // The current default value is 5 minutes. + // If true, server expects keepalive pings even when there are no active streams(RPCs). + PermitWithoutStream bool // false by default. } diff --git a/server.go b/server.go index b19a3c4a..32be4437 100644 --- a/server.go +++ b/server.go @@ -119,6 +119,7 @@ type options struct { useHandlerImpl bool // use http.Handler-based server unknownStreamDesc *StreamDesc keepaliveParams keepalive.ServerParameters + keepalivePolicy keepalive.EnforcementPolicy } var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit @@ -133,6 +134,13 @@ func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { } } +// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. +func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { + return func(o *options) { + o.keepalivePolicy = kep + } +} + // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. func CustomCodec(codec Codec) ServerOption { return func(o *options) { @@ -479,6 +487,7 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) InTapHandle: s.opts.inTapHandle, StatsHandler: s.opts.statsHandler, KeepaliveParams: s.opts.keepaliveParams, + KeepalivePolicy: s.opts.keepalivePolicy, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { diff --git a/transport/control.go b/transport/control.go index 64d22f84..8d29aee5 100644 --- a/transport/control.go +++ b/transport/control.go @@ -57,6 +57,7 @@ const ( defaultMaxConnectionAgeGrace = infinity defaultServerKeepaliveTime = time.Duration(2 * time.Hour) defaultServerKeepaliveTimeout = time.Duration(20 * time.Second) + defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute) ) // The following defines various control items which could flow through @@ -84,6 +85,8 @@ type resetStream struct { func (*resetStream) item() {} type goAway struct { + code http2.ErrCode + debugData []byte } func (*goAway) item() {} diff --git a/transport/http2_client.go b/transport/http2_client.go index 550fe713..d6e2998b 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -893,6 +893,9 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { + if f.ErrCode == http2.ErrCodeEnhanceYourCalm { + grpclog.Printf("Client received GoAway with http2.ErrCodeEnhanceYourCalm.") + } t.mu.Lock() if t.state == reachable || t.state == draining { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { diff --git a/transport/http2_server.go b/transport/http2_server.go index e810d195..f3bc569d 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -100,6 +100,17 @@ type http2Server struct { // Keepalive and max-age parameters for the server. kp keepalive.ServerParameters + // Keepalive enforcement policy. + kep keepalive.EnforcementPolicy + // The time instance last ping was received. + lastPingAt time.Time + // Number of times the client has violated keepalive ping policy so far. + pingStrikes uint8 + // Flag to signify that number of ping strikes should be reset to 0. + // This is set whenever data or header frames are sent. + // 1 means yes. + resetPingStrikes uint32 // Accessed atomically. + mu sync.Mutex // guard the following state transportState activeStreams map[uint32]*Stream @@ -161,6 +172,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err if kp.Timeout == 0 { kp.Timeout = defaultServerKeepaliveTimeout } + kep := config.KeepalivePolicy + if kep.MinTime == 0 { + kep.MinTime = defaultKeepalivePolicyMinTime + } var buf bytes.Buffer t := &http2Server{ ctx: context.Background(), @@ -184,6 +199,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err stats: config.StatsHandler, kp: kp, idle: time.Now(), + kep: kep, } if t.stats != nil { t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ @@ -504,6 +520,11 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) { t.controlBuf.put(&settings{ack: true, ss: ss}) } +const ( + maxPingStrikes = 2 + defaultPingTimeout = 2 * time.Hour +) + func (t *http2Server) handlePing(f *http2.PingFrame) { if f.IsAck() { // Do nothing. return @@ -511,6 +532,38 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { pingAck := &ping{ack: true} copy(pingAck.data[:], f.Data[:]) t.controlBuf.put(pingAck) + + now := time.Now() + defer func() { + t.lastPingAt = now + }() + // A reset ping strikes means that we don't need to check for policy + // violation for this ping and the pingStrikes counter should be set + // to 0. + if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) { + t.pingStrikes = 0 + return + } + t.mu.Lock() + ns := len(t.activeStreams) + t.mu.Unlock() + if ns < 1 && !t.kep.PermitWithoutStream { + // Keepalive shouldn't be active thus, this new ping should + // have come after atleast defaultPingTimeout. + if t.lastPingAt.Add(defaultPingTimeout).After(now) { + t.pingStrikes++ + } + } else { + // Check if keepalive policy is respected. + if t.lastPingAt.Add(t.kep.MinTime).After(now) { + t.pingStrikes++ + } + } + + if t.pingStrikes > maxPingStrikes { + // Send goaway and close the connection. + t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings")}) + } } func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -529,6 +582,13 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e first := true endHeaders := false var err error + defer func() { + if err == nil { + // Reset ping strikes when seding headers since that might cause the + // peer to send ping. + atomic.StoreUint32(&t.resetPingStrikes, 1) + } + }() // Sends the headers in a single batch. for !endHeaders { size := t.hBuf.Len() @@ -672,7 +732,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). -func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { +func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { // TODO(zhaoq): Support multi-writers for a single stream. var writeHeaderFrame bool s.mu.Lock() @@ -687,6 +747,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { if writeHeaderFrame { t.WriteHeader(s, nil) } + defer func() { + if err == nil { + // Reset ping strikes when sending data since this might cause + // the peer to send ping. + atomic.StoreUint32(&t.resetPingStrikes, 1) + } + }() r := bytes.NewBuffer(data) for { if r.Len() == 0 { @@ -892,7 +959,10 @@ func (t *http2Server) controller() { sid := t.maxStreamID t.state = draining t.mu.Unlock() - t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil) + t.framer.writeGoAway(true, sid, i.code, i.debugData) + if i.code == http2.ErrCodeEnhanceYourCalm { + t.Close() + } case *flushIO: t.framer.flushWrite() case *ping: @@ -972,7 +1042,7 @@ func (t *http2Server) RemoteAddr() net.Addr { } func (t *http2Server) Drain() { - t.controlBuf.put(&goAway{}) + t.controlBuf.put(&goAway{code: http2.ErrCodeNo}) } var rgen = rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/transport/transport.go b/transport/transport.go index 51d71e39..51716803 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -370,6 +370,7 @@ type ServerConfig struct { InTapHandle tap.ServerInHandle StatsHandler stats.Handler KeepaliveParams keepalive.ServerParameters + KeepalivePolicy keepalive.EnforcementPolicy } // NewServerTransport creates a ServerTransport with conn or non-nil error diff --git a/transport/transport_test.go b/transport/transport_test.go index 62038784..3108b98c 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -323,6 +323,9 @@ func TestMaxConnectionIdle(t *testing.T) { timeout := time.NewTimer(time.Second * 4) select { case <-client.GoAway(): + if !timeout.Stop() { + <-timeout.C + } case <-timeout.C: t.Fatalf("Test timed out, expected a GoAway from the server.") } @@ -345,6 +348,9 @@ func TestMaxConnectionIdleNegative(t *testing.T) { timeout := time.NewTimer(time.Second * 4) select { case <-client.GoAway(): + if !timeout.Stop() { + <-timeout.C + } t.Fatalf("A non-idle client received a GoAway.") case <-timeout.C: } @@ -369,6 +375,9 @@ func TestMaxConnectionAge(t *testing.T) { timeout := time.NewTimer(4 * time.Second) select { case <-client.GoAway(): + if !timeout.Stop() { + <-timeout.C + } case <-timeout.C: t.Fatalf("Test timer out, expected a GoAway from the server.") } @@ -523,6 +532,138 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { } } +func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { + serverConfig := &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + MinTime: 2 * time.Second, + }, + } + clientOptions := ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 50 * time.Millisecond, + Timeout: 50 * time.Millisecond, + PermitWithoutStream: true, + }, + } + server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) + defer server.stop() + defer client.Close() + + timeout := time.NewTimer(2 * time.Second) + select { + case <-client.GoAway(): + if !timeout.Stop() { + <-timeout.C + } + case <-timeout.C: + t.Fatalf("Test failed: Expected a GoAway from server.") + } + time.Sleep(500 * time.Millisecond) + ct := client.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state == reachable { + t.Fatalf("Test failed: Expected the connection to be closed.") + } +} + +func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { + serverConfig := &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + MinTime: 2 * time.Second, + }, + } + clientOptions := ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 50 * time.Millisecond, + Timeout: 50 * time.Millisecond, + }, + } + server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) + defer server.stop() + defer client.Close() + + if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil { + t.Fatalf("Client failed to create stream.") + } + timeout := time.NewTimer(2 * time.Second) + select { + case <-client.GoAway(): + if !timeout.Stop() { + <-timeout.C + } + case <-timeout.C: + t.Fatalf("Test failed: Expected a GoAway from server.") + } + time.Sleep(500 * time.Millisecond) + ct := client.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state == reachable { + t.Fatalf("Test failed: Expected the connection to be closed.") + } +} + +func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { + serverConfig := &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + MinTime: 100 * time.Millisecond, + PermitWithoutStream: true, + }, + } + clientOptions := ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 101 * time.Millisecond, + Timeout: 50 * time.Millisecond, + PermitWithoutStream: true, + }, + } + server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) + defer server.stop() + defer client.Close() + + // Give keepalive enough time. + time.Sleep(2 * time.Second) + // Assert that connection is healthy. + ct := client.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state != reachable { + t.Fatalf("Test failed: Expected connection to be healthy.") + } +} + +func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { + serverConfig := &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + MinTime: 100 * time.Millisecond, + }, + } + clientOptions := ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 101 * time.Millisecond, + Timeout: 50 * time.Millisecond, + }, + } + server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) + defer server.stop() + defer client.Close() + + if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil { + t.Fatalf("Client failed to create stream.") + } + + // Give keepalive enough time. + time.Sleep(2 * time.Second) + // Assert that connection is healthy. + ct := client.(*http2Client) + ct.mu.Lock() + defer ct.mu.Unlock() + if ct.state != reachable { + t.Fatalf("Test failed: Expected connection to be healthy.") + } +} + func TestClientSendAndReceive(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{