Move Server Credentials Handshake to transport (#4692)

* Move Server Credentials Handshake to transport
This commit is contained in:
Zach Reyes
2021-08-23 19:39:14 -04:00
committed by GitHub
parent 8ab16ef276
commit c361e9ea16
3 changed files with 30 additions and 35 deletions

View File

@ -133,6 +133,20 @@ type http2Server struct {
// underlying conn gets closed before the client preface could be read, it // underlying conn gets closed before the client preface could be read, it
// returns a nil transport and a nil error. // returns a nil transport and a nil error.
func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
var authInfo credentials.AuthInfo
rawConn := conn
if config.Credentials != nil {
var err error
conn, authInfo, err = config.Credentials.ServerHandshake(rawConn)
if err != nil {
// ErrConnDispatched means that the connection was dispatched away from
// gRPC; those connections should be left open.
if err == credentials.ErrConnDispatched {
return nil, err
}
return nil, connectionErrorf(false, err, "ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
}
}
writeBufSize := config.WriteBufferSize writeBufSize := config.WriteBufferSize
readBufSize := config.ReadBufferSize readBufSize := config.ReadBufferSize
maxHeaderListSize := defaultServerMaxHeaderListSize maxHeaderListSize := defaultServerMaxHeaderListSize
@ -215,14 +229,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
if kep.MinTime == 0 { if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime kep.MinTime = defaultKeepalivePolicyMinTime
} }
done := make(chan struct{}) done := make(chan struct{})
t := &http2Server{ t := &http2Server{
ctx: setConnection(context.Background(), conn), ctx: setConnection(context.Background(), rawConn),
done: done, done: done,
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(), localAddr: conn.LocalAddr(),
authInfo: config.AuthInfo, authInfo: authInfo,
framer: framer, framer: framer,
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),

View File

@ -30,6 +30,7 @@ import (
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@ -518,7 +519,8 @@ const (
// ServerConfig consists of all the configurations to establish a server transport. // ServerConfig consists of all the configurations to establish a server transport.
type ServerConfig struct { type ServerConfig struct {
MaxStreams uint32 MaxStreams uint32
AuthInfo credentials.AuthInfo ConnectionTimeout time.Duration
Credentials credentials.TransportCredentials
InTapHandle tap.ServerInHandle InTapHandle tap.ServerInHandle
StatsHandler stats.Handler StatsHandler stats.Handler
KeepaliveParams keepalive.ServerParameters KeepaliveParams keepalive.ServerParameters

View File

@ -710,13 +710,6 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo {
// the server being stopped. // the server being stopped.
var ErrServerStopped = errors.New("grpc: the server has been stopped") var ErrServerStopped = errors.New("grpc: the server has been stopped")
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if s.opts.creds == nil {
return rawConn, nil, nil
}
return s.opts.creds.ServerHandshake(rawConn)
}
type listenSocket struct { type listenSocket struct {
net.Listener net.Listener
channelzID int64 channelzID int64
@ -839,35 +832,14 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return return
} }
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
if err != nil {
// ErrConnDispatched means that the connection was dispatched away from
// gRPC; those connections should be left open.
if err != credentials.ErrConnDispatched {
// In deployments where a gRPC server runs behind a cloud load
// balancer which performs regular TCP level health checks, the
// connection is closed immediately by the latter. Skipping the
// error here will help reduce log clutter.
if err != io.EOF {
s.mu.Lock()
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
}
rawConn.Close()
}
rawConn.SetDeadline(time.Time{})
return
}
// Finish handshaking (HTTP2) // Finish handshaking (HTTP2)
st := s.newHTTP2Transport(conn, authInfo) st := s.newHTTP2Transport(rawConn)
rawConn.SetDeadline(time.Time{})
if st == nil { if st == nil {
conn.Close()
return return
} }
rawConn.SetDeadline(time.Time{})
if !s.addConn(lisAddr, st) { if !s.addConn(lisAddr, st) {
return return
} }
@ -888,10 +860,11 @@ func (s *Server) drainServerTransports(addr string) {
// newHTTP2Transport sets up a http/2 transport (using the // newHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go). // gRPC http2 server transport in transport/http2_server.go).
func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport { func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
config := &transport.ServerConfig{ config := &transport.ServerConfig{
MaxStreams: s.opts.maxConcurrentStreams, MaxStreams: s.opts.maxConcurrentStreams,
AuthInfo: authInfo, ConnectionTimeout: s.opts.connectionTimeout,
Credentials: s.opts.creds,
InTapHandle: s.opts.inTapHandle, InTapHandle: s.opts.inTapHandle,
StatsHandler: s.opts.statsHandler, StatsHandler: s.opts.statsHandler,
KeepaliveParams: s.opts.keepaliveParams, KeepaliveParams: s.opts.keepaliveParams,
@ -909,6 +882,11 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
s.mu.Lock() s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
s.mu.Unlock() s.mu.Unlock()
// ErrConnDispatched means that the connection was dispatched away from
// gRPC; those connections should be left open.
if err != credentials.ErrConnDispatched {
c.Close()
}
channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err)
return nil return nil
} }