From c361e9ea1646283baf7b23a5d060c45fce9a1dea Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon, 23 Aug 2021 19:39:14 -0400 Subject: [PATCH] Move Server Credentials Handshake to transport (#4692) * Move Server Credentials Handshake to transport --- internal/transport/http2_server.go | 19 ++++++++++++-- internal/transport/transport.go | 4 ++- server.go | 42 +++++++----------------------- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 88c37723..cd0ebed9 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -133,6 +133,20 @@ type http2Server struct { // underlying conn gets closed before the client preface could be read, it // returns a nil transport and a nil error. func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { + var authInfo credentials.AuthInfo + rawConn := conn + if config.Credentials != nil { + var err error + conn, authInfo, err = config.Credentials.ServerHandshake(rawConn) + if err != nil { + // ErrConnDispatched means that the connection was dispatched away from + // gRPC; those connections should be left open. + if err == credentials.ErrConnDispatched { + return nil, err + } + return nil, connectionErrorf(false, err, "ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + } + } writeBufSize := config.WriteBufferSize readBufSize := config.ReadBufferSize maxHeaderListSize := defaultServerMaxHeaderListSize @@ -215,14 +229,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime } + done := make(chan struct{}) t := &http2Server{ - ctx: setConnection(context.Background(), conn), + ctx: setConnection(context.Background(), rawConn), done: done, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), - authInfo: config.AuthInfo, + authInfo: authInfo, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 14198126..d3bf65b2 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -30,6 +30,7 @@ import ( "net" "sync" "sync/atomic" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -518,7 +519,8 @@ const ( // ServerConfig consists of all the configurations to establish a server transport. type ServerConfig struct { MaxStreams uint32 - AuthInfo credentials.AuthInfo + ConnectionTimeout time.Duration + Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle StatsHandler stats.Handler KeepaliveParams keepalive.ServerParameters diff --git a/server.go b/server.go index 596384fd..d6155c0e 100644 --- a/server.go +++ b/server.go @@ -710,13 +710,6 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo { // the server being stopped. var ErrServerStopped = errors.New("grpc: the server has been stopped") -func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - if s.opts.creds == nil { - return rawConn, nil, nil - } - return s.opts.creds.ServerHandshake(rawConn) -} - type listenSocket struct { net.Listener channelzID int64 @@ -839,35 +832,14 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { return } rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) - conn, authInfo, err := s.useTransportAuthenticator(rawConn) - if err != nil { - // ErrConnDispatched means that the connection was dispatched away from - // gRPC; those connections should be left open. - if err != credentials.ErrConnDispatched { - // In deployments where a gRPC server runs behind a cloud load - // balancer which performs regular TCP level health checks, the - // connection is closed immediately by the latter. Skipping the - // error here will help reduce log clutter. - if err != io.EOF { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) - s.mu.Unlock() - channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - } - rawConn.Close() - } - rawConn.SetDeadline(time.Time{}) - return - } // Finish handshaking (HTTP2) - st := s.newHTTP2Transport(conn, authInfo) + st := s.newHTTP2Transport(rawConn) + rawConn.SetDeadline(time.Time{}) if st == nil { - conn.Close() return } - rawConn.SetDeadline(time.Time{}) if !s.addConn(lisAddr, st) { return } @@ -888,10 +860,11 @@ func (s *Server) drainServerTransports(addr string) { // newHTTP2Transport sets up a http/2 transport (using the // gRPC http2 server transport in transport/http2_server.go). -func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport { +func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { config := &transport.ServerConfig{ MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, + ConnectionTimeout: s.opts.connectionTimeout, + Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, StatsHandler: s.opts.statsHandler, KeepaliveParams: s.opts.keepaliveParams, @@ -909,6 +882,11 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr s.mu.Lock() s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) s.mu.Unlock() + // ErrConnDispatched means that the connection was dispatched away from + // gRPC; those connections should be left open. + if err != credentials.ErrConnDispatched { + c.Close() + } channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) return nil }