server: apply deadline to new connections until all handshaking is completed (#1646)
This commit is contained in:
@ -87,10 +87,14 @@ type TransportCredentials interface {
|
|||||||
// (io.EOF, context.DeadlineExceeded or err.Temporary() == true).
|
// (io.EOF, context.DeadlineExceeded or err.Temporary() == true).
|
||||||
// If the returned error is a wrapper error, implementations should make sure that
|
// If the returned error is a wrapper error, implementations should make sure that
|
||||||
// the error implements Temporary() to have the correct retry behaviors.
|
// the error implements Temporary() to have the correct retry behaviors.
|
||||||
|
//
|
||||||
|
// If the returned net.Conn is closed, it MUST close the net.Conn provided.
|
||||||
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
|
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// ServerHandshake does the authentication handshake for servers. It returns
|
// ServerHandshake does the authentication handshake for servers. It returns
|
||||||
// the authenticated connection and the corresponding auth information about
|
// the authenticated connection and the corresponding auth information about
|
||||||
// the connection.
|
// the connection.
|
||||||
|
//
|
||||||
|
// If the returned net.Conn is closed, it MUST close the net.Conn provided.
|
||||||
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
|
39
server.go
39
server.go
@ -126,11 +126,13 @@ type options struct {
|
|||||||
initialConnWindowSize int32
|
initialConnWindowSize int32
|
||||||
writeBufferSize int
|
writeBufferSize int
|
||||||
readBufferSize int
|
readBufferSize int
|
||||||
|
connectionTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultServerOptions = options{
|
var defaultServerOptions = options{
|
||||||
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
|
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
|
||||||
maxSendMessageSize: defaultServerMaxSendMessageSize,
|
maxSendMessageSize: defaultServerMaxSendMessageSize,
|
||||||
|
connectionTimeout: 120 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
|
// A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
|
||||||
@ -303,6 +305,16 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectionTimeout returns a ServerOption that sets the timeout for
|
||||||
|
// connection establishment (up to and including HTTP/2 handshaking) for all
|
||||||
|
// new connections. If this is not set, the default is 120 seconds. A zero or
|
||||||
|
// negative value will result in an immediate timeout.
|
||||||
|
func ConnectionTimeout(d time.Duration) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.connectionTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer creates a gRPC server which has no service registered and has not
|
// NewServer creates a gRPC server which has no service registered and has not
|
||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
@ -519,16 +531,18 @@ func (s *Server) Serve(lis net.Listener) error {
|
|||||||
// handleRawConn is run in its own goroutine and handles a just-accepted
|
// handleRawConn is run in its own goroutine and handles a just-accepted
|
||||||
// connection that has not had any I/O performed on it yet.
|
// connection that has not had any I/O performed on it yet.
|
||||||
func (s *Server) handleRawConn(rawConn net.Conn) {
|
func (s *Server) handleRawConn(rawConn net.Conn) {
|
||||||
|
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
|
||||||
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
|
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
|
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
|
grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
|
||||||
// If serverHandShake returns ErrConnDispatched, keep rawConn open.
|
// If serverHandshake returns ErrConnDispatched, keep rawConn open.
|
||||||
if err != credentials.ErrConnDispatched {
|
if err != credentials.ErrConnDispatched {
|
||||||
rawConn.Close()
|
rawConn.Close()
|
||||||
}
|
}
|
||||||
|
rawConn.SetDeadline(time.Time{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -541,18 +555,21 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
if s.opts.useHandlerImpl {
|
if s.opts.useHandlerImpl {
|
||||||
|
rawConn.SetDeadline(time.Time{})
|
||||||
s.serveUsingHandler(conn)
|
s.serveUsingHandler(conn)
|
||||||
} else {
|
} else {
|
||||||
s.serveHTTP2Transport(conn, authInfo)
|
st := s.newHTTP2Transport(conn, authInfo)
|
||||||
|
if st == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawConn.SetDeadline(time.Time{})
|
||||||
|
s.serveStreams(st)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveHTTP2Transport sets up a http/2 transport (using the
|
// newHTTP2Transport sets up a http/2 transport (using the
|
||||||
// gRPC http2 server transport in transport/http2_server.go) and
|
// gRPC http2 server transport in transport/http2_server.go).
|
||||||
// serves streams on it.
|
func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport {
|
||||||
// This is run in its own goroutine (it does network I/O in
|
|
||||||
// transport.NewServerTransport).
|
|
||||||
func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
|
|
||||||
config := &transport.ServerConfig{
|
config := &transport.ServerConfig{
|
||||||
MaxStreams: s.opts.maxConcurrentStreams,
|
MaxStreams: s.opts.maxConcurrentStreams,
|
||||||
AuthInfo: authInfo,
|
AuthInfo: authInfo,
|
||||||
@ -572,13 +589,13 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo)
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
c.Close()
|
c.Close()
|
||||||
grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
|
grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
if !s.addConn(st) {
|
if !s.addConn(st) {
|
||||||
st.Close()
|
st.Close()
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
s.serveStreams(st)
|
return st
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) serveStreams(st transport.ServerTransport) {
|
func (s *Server) serveStreams(st transport.ServerTransport) {
|
||||||
|
@ -155,12 +155,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||||||
Val: uint32(iwz)})
|
Val: uint32(iwz)})
|
||||||
}
|
}
|
||||||
if err := framer.fr.WriteSettings(isettings...); err != nil {
|
if err := framer.fr.WriteSettings(isettings...); err != nil {
|
||||||
return nil, connectionErrorf(true, err, "transport: %v", err)
|
return nil, connectionErrorf(false, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
// Adjust the connection flow control window if needed.
|
// Adjust the connection flow control window if needed.
|
||||||
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
|
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
|
||||||
if err := framer.fr.WriteWindowUpdate(0, delta); err != nil {
|
if err := framer.fr.WriteWindowUpdate(0, delta); err != nil {
|
||||||
return nil, connectionErrorf(true, err, "transport: %v", err)
|
return nil, connectionErrorf(false, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
kp := config.KeepaliveParams
|
kp := config.KeepaliveParams
|
||||||
@ -227,6 +227,31 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||||||
t.stats.HandleConn(t.ctx, connBegin)
|
t.stats.HandleConn(t.ctx, connBegin)
|
||||||
}
|
}
|
||||||
t.framer.writer.Flush()
|
t.framer.writer.Flush()
|
||||||
|
|
||||||
|
// Check the validity of client preface.
|
||||||
|
preface := make([]byte, len(clientPreface))
|
||||||
|
if _, err := io.ReadFull(t.conn, preface); err != nil {
|
||||||
|
return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(preface, clientPreface) {
|
||||||
|
return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface)
|
||||||
|
}
|
||||||
|
|
||||||
|
frame, err := t.framer.fr.ReadFrame()
|
||||||
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
t.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
|
||||||
|
}
|
||||||
|
atomic.StoreUint32(&t.activity, 1)
|
||||||
|
sf, ok := frame.(*http2.SettingsFrame)
|
||||||
|
if !ok {
|
||||||
|
return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame)
|
||||||
|
}
|
||||||
|
t.handleSettings(sf)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
|
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
|
||||||
t.Close()
|
t.Close()
|
||||||
@ -361,41 +386,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
// typically run in a separate goroutine.
|
// typically run in a separate goroutine.
|
||||||
// traceCtx attaches trace to ctx and returns the new context.
|
// traceCtx attaches trace to ctx and returns the new context.
|
||||||
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
|
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
|
||||||
// Check the validity of client preface.
|
|
||||||
preface := make([]byte, len(clientPreface))
|
|
||||||
if _, err := io.ReadFull(t.conn, preface); err != nil {
|
|
||||||
// Only log if it isn't a simple tcp accept check (ie: tcp balancer doing open/close socket)
|
|
||||||
if err != io.EOF {
|
|
||||||
errorf("transport: http2Server.HandleStreams failed to receive the preface from client: %v", err)
|
|
||||||
}
|
|
||||||
t.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(preface, clientPreface) {
|
|
||||||
errorf("transport: http2Server.HandleStreams received bogus greeting from client: %q", preface)
|
|
||||||
t.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
frame, err := t.framer.fr.ReadFrame()
|
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
|
||||||
t.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
errorf("transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
|
|
||||||
t.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
atomic.StoreUint32(&t.activity, 1)
|
|
||||||
sf, ok := frame.(*http2.SettingsFrame)
|
|
||||||
if !ok {
|
|
||||||
errorf("transport: http2Server.HandleStreams saw invalid preface type %T from client", frame)
|
|
||||||
t.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.handleSettings(sf)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
frame, err := t.framer.fr.ReadFrame()
|
frame, err := t.framer.fr.ReadFrame()
|
||||||
atomic.StoreUint32(&t.activity, 1)
|
atomic.StoreUint32(&t.activity, 1)
|
||||||
|
@ -533,8 +533,18 @@ func TestKeepaliveServer(t *testing.T) {
|
|||||||
t.Fatalf("Failed to dial: %v", err)
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
// Set read deadline on client conn so that it doesn't block forever in errorsome cases.
|
// Set read deadline on client conn so that it doesn't block forever in errorsome cases.
|
||||||
client.SetReadDeadline(time.Now().Add(10 * time.Second))
|
client.SetDeadline(time.Now().Add(10 * time.Second))
|
||||||
|
|
||||||
|
if n, err := client.Write(clientPreface); err != nil || n != len(clientPreface) {
|
||||||
|
t.Fatalf("Error writing client preface; n=%v, err=%v", n, err)
|
||||||
|
}
|
||||||
|
framer := newFramer(client, defaultWriteBufSize, defaultReadBufSize)
|
||||||
|
if err := framer.fr.WriteSettings(http2.Setting{}); err != nil {
|
||||||
|
t.Fatal("Error writing settings frame:", err)
|
||||||
|
}
|
||||||
|
framer.writer.Flush()
|
||||||
// Wait for keepalive logic to close the connection.
|
// Wait for keepalive logic to close the connection.
|
||||||
time.Sleep(4 * time.Second)
|
time.Sleep(4 * time.Second)
|
||||||
b := make([]byte, 24)
|
b := make([]byte, 24)
|
||||||
|
Reference in New Issue
Block a user