separate auth info from normal metadata
This commit is contained in:
@ -78,27 +78,47 @@ type ProtocolInfo struct {
|
|||||||
SecurityVersion string
|
SecurityVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthInfo defines the common interface for the auth information the users are interested in.
|
||||||
|
type AuthInfo interface {
|
||||||
|
Type() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type authInfoKey struct{}
|
||||||
|
|
||||||
|
// NewContext creates a new context with authInfo attached.
|
||||||
|
func NewContext(ctx context.Context, authInfo AuthInfo) context.Context {
|
||||||
|
return context.WithValue(ctx, authInfoKey{}, authInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the authInfo in ctx if it exists.
|
||||||
|
func FromContext(ctx context.Context) (authInfo AuthInfo, ok bool) {
|
||||||
|
authInfo, ok = ctx.Value(authInfoKey{}).(AuthInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// TransportAuthenticator defines the common interface for all the live gRPC wire
|
// TransportAuthenticator defines the common interface for all the live gRPC wire
|
||||||
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
||||||
type TransportAuthenticator interface {
|
type TransportAuthenticator interface {
|
||||||
// ClientHandshake does the authentication handshake specified by the corresponding
|
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||||
// authentication protocol on rawConn for clients. It returns the authenticated
|
// authentication protocol on rawConn for clients. It returns the authenticated
|
||||||
// connection and the corresponding auth information about the connection.
|
// connection and the corresponding auth information about the connection.
|
||||||
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, map[string][]string, error)
|
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
|
||||||
// ServerHandshake does the authentication handshake for servers. It returns
|
// ServerHandshake does the authentication handshake for servers. It returns
|
||||||
// the authenticated connection and the corresponding auth information about
|
// the authenticated connection and the corresponding auth information about
|
||||||
// the connection.
|
// the connection.
|
||||||
ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error)
|
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
Credentials
|
Credentials
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
type TLSInfo struct {
|
||||||
transportSecurityType = "transport_security_type"
|
state tls.ConnectionState
|
||||||
x509CN = "x509_common_name"
|
}
|
||||||
x509SAN = "x509_suject_alternative_name"
|
|
||||||
)
|
func (t TLSInfo) Type() string {
|
||||||
|
return "tls"
|
||||||
|
}
|
||||||
|
|
||||||
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
||||||
type tlsCreds struct {
|
type tlsCreds struct {
|
||||||
@ -106,7 +126,7 @@ type tlsCreds struct {
|
|||||||
config tls.Config
|
config tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) Info() ProtocolInfo {
|
func (c tlsCreds) Info() ProtocolInfo {
|
||||||
return ProtocolInfo{
|
return ProtocolInfo{
|
||||||
SecurityProtocol: "tls",
|
SecurityProtocol: "tls",
|
||||||
SecurityVersion: "1.2",
|
SecurityVersion: "1.2",
|
||||||
@ -125,7 +145,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
|||||||
func (timeoutError) Timeout() bool { return true }
|
func (timeoutError) Timeout() bool { return true }
|
||||||
func (timeoutError) Temporary() bool { return true }
|
func (timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ map[string][]string, err error) {
|
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
|
||||||
// borrow some code from tls.DialWithDialer
|
// borrow some code from tls.DialWithDialer
|
||||||
var errChannel chan error
|
var errChannel chan error
|
||||||
if timeout != 0 {
|
if timeout != 0 {
|
||||||
@ -159,23 +179,13 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
|
|||||||
return conn, nil, nil
|
return conn, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) {
|
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
|
||||||
conn := tls.Server(rawConn, &c.config)
|
conn := tls.Server(rawConn, &c.config)
|
||||||
if err := conn.Handshake(); err != nil {
|
if err := conn.Handshake(); err != nil {
|
||||||
rawConn.Close()
|
rawConn.Close()
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
info := make(map[string][]string)
|
return conn, &TLSInfo{ conn.ConnectionState() }, nil
|
||||||
info[transportSecurityType] = []string{"tls"}
|
|
||||||
for _, certs := range conn.ConnectionState().VerifiedChains {
|
|
||||||
for _, cert := range certs {
|
|
||||||
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
|
|
||||||
for _, san := range cert.DNSNames {
|
|
||||||
info[x509SAN] = append(info[x509SAN], san)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return conn, info, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
||||||
|
@ -199,7 +199,7 @@ func (s *Server) Serve(lis net.Listener) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var authInfo map[string][]string
|
var authInfo credentials.AuthInfo
|
||||||
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
|
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
|
||||||
c, authInfo, err = creds.ServerHandshake(c)
|
c, authInfo, err = creds.ServerHandshake(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +56,7 @@ type http2Client struct {
|
|||||||
target string // server name/addr
|
target string // server name/addr
|
||||||
userAgent string
|
userAgent string
|
||||||
conn net.Conn // underlying communication channel
|
conn net.Conn // underlying communication channel
|
||||||
authInfo map[string][]string // auth info about the connection
|
authInfo credentials.AuthInfo // auth info about the connection
|
||||||
nextID uint32 // the next stream ID to be used
|
nextID uint32 // the next stream ID to be used
|
||||||
|
|
||||||
// writableChan synchronizes write access to the transport.
|
// writableChan synchronizes write access to the transport.
|
||||||
@ -115,7 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
}
|
}
|
||||||
var authInfo map[string][]string
|
var authInfo credentials.AuthInfo
|
||||||
for _, c := range opts.AuthOptions {
|
for _, c := range opts.AuthOptions {
|
||||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
@ -237,6 +237,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
return nil, ContextErr(context.DeadlineExceeded)
|
return nil, ContextErr(context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Attach Auth info if there is any.
|
||||||
|
if t.authInfo != nil {
|
||||||
|
ctx = credentials.NewContext(ctx, t.authInfo)
|
||||||
|
}
|
||||||
authData := make(map[string]string)
|
authData := make(map[string]string)
|
||||||
for _, c := range t.authCreds {
|
for _, c := range t.authCreds {
|
||||||
data, err := c.GetRequestMetadata(ctx)
|
data, err := c.GetRequestMetadata(ctx)
|
||||||
@ -704,7 +708,7 @@ func (t *http2Client) reader() {
|
|||||||
}
|
}
|
||||||
t.handleSettings(sf)
|
t.handleSettings(sf)
|
||||||
|
|
||||||
hDec := newHPACKDecoder(t.authInfo)
|
hDec := newHPACKDecoder()
|
||||||
var curStream *Stream
|
var curStream *Stream
|
||||||
// loop to keep reading incoming messages on this transport.
|
// loop to keep reading incoming messages on this transport.
|
||||||
for {
|
for {
|
||||||
|
@ -46,6 +46,7 @@ import (
|
|||||||
"github.com/bradfitz/http2/hpack"
|
"github.com/bradfitz/http2/hpack"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
@ -57,8 +58,8 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
|
|||||||
// http2Server implements the ServerTransport interface with HTTP2.
|
// http2Server implements the ServerTransport interface with HTTP2.
|
||||||
type http2Server struct {
|
type http2Server struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
maxStreamID uint32 // max stream ID ever seen
|
maxStreamID uint32 // max stream ID ever seen
|
||||||
authInfo map[string][]string // auth info about the connection
|
authInfo credentials.AuthInfo // auth info about the connection
|
||||||
// writableChan synchronizes write access to the transport.
|
// writableChan synchronizes write access to the transport.
|
||||||
// A writer acquires the write lock by sending a value on writableChan
|
// A writer acquires the write lock by sending a value on writableChan
|
||||||
// and releases it by receiving from writableChan.
|
// and releases it by receiving from writableChan.
|
||||||
@ -89,7 +90,7 @@ type http2Server struct {
|
|||||||
|
|
||||||
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
|
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
|
||||||
// returned if something goes wrong.
|
// returned if something goes wrong.
|
||||||
func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo map[string][]string) (_ ServerTransport, err error) {
|
func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) {
|
||||||
framer := newFramer(conn)
|
framer := newFramer(conn)
|
||||||
// Send initial settings as connection preface to client.
|
// Send initial settings as connection preface to client.
|
||||||
var settings []http2.Setting
|
var settings []http2.Setting
|
||||||
@ -183,6 +184,10 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
|
|||||||
} else {
|
} else {
|
||||||
s.ctx, s.cancel = context.WithCancel(context.TODO())
|
s.ctx, s.cancel = context.WithCancel(context.TODO())
|
||||||
}
|
}
|
||||||
|
// Attach Auth info if there is any.
|
||||||
|
if t.authInfo != nil {
|
||||||
|
s.ctx = credentials.NewContext(s.ctx, t.authInfo)
|
||||||
|
}
|
||||||
// Cache the current stream to the context so that the server application
|
// Cache the current stream to the context so that the server application
|
||||||
// can find out. Required when the server wants to send some metadata
|
// can find out. Required when the server wants to send some metadata
|
||||||
// back to the client (unary call only).
|
// back to the client (unary call only).
|
||||||
@ -236,7 +241,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
}
|
}
|
||||||
t.handleSettings(sf)
|
t.handleSettings(sf)
|
||||||
|
|
||||||
hDec := newHPACKDecoder(t.authInfo)
|
hDec := newHPACKDecoder()
|
||||||
var curStream *Stream
|
var curStream *Stream
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer wg.Wait()
|
defer wg.Wait()
|
||||||
|
@ -63,31 +63,29 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
clientPreface = []byte(http2.ClientPreface)
|
clientPreface = []byte(http2.ClientPreface)
|
||||||
|
http2RSTErrConvTab = map[http2.ErrCode]codes.Code{
|
||||||
|
http2.ErrCodeNo: codes.Internal,
|
||||||
|
http2.ErrCodeProtocol: codes.Internal,
|
||||||
|
http2.ErrCodeInternal: codes.Internal,
|
||||||
|
http2.ErrCodeFlowControl: codes.ResourceExhausted,
|
||||||
|
http2.ErrCodeSettingsTimeout: codes.Internal,
|
||||||
|
http2.ErrCodeFrameSize: codes.Internal,
|
||||||
|
http2.ErrCodeRefusedStream: codes.Unavailable,
|
||||||
|
http2.ErrCodeCancel: codes.Canceled,
|
||||||
|
http2.ErrCodeCompression: codes.Internal,
|
||||||
|
http2.ErrCodeConnect: codes.Internal,
|
||||||
|
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
|
||||||
|
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
|
||||||
|
}
|
||||||
|
statusCodeConvTab = map[codes.Code]http2.ErrCode{
|
||||||
|
codes.Internal: http2.ErrCodeInternal,
|
||||||
|
codes.Canceled: http2.ErrCodeCancel,
|
||||||
|
codes.Unavailable: http2.ErrCodeRefusedStream,
|
||||||
|
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
|
||||||
|
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var http2RSTErrConvTab = map[http2.ErrCode]codes.Code{
|
|
||||||
http2.ErrCodeNo: codes.Internal,
|
|
||||||
http2.ErrCodeProtocol: codes.Internal,
|
|
||||||
http2.ErrCodeInternal: codes.Internal,
|
|
||||||
http2.ErrCodeFlowControl: codes.ResourceExhausted,
|
|
||||||
http2.ErrCodeSettingsTimeout: codes.Internal,
|
|
||||||
http2.ErrCodeFrameSize: codes.Internal,
|
|
||||||
http2.ErrCodeRefusedStream: codes.Unavailable,
|
|
||||||
http2.ErrCodeCancel: codes.Canceled,
|
|
||||||
http2.ErrCodeCompression: codes.Internal,
|
|
||||||
http2.ErrCodeConnect: codes.Internal,
|
|
||||||
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
|
|
||||||
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
|
|
||||||
}
|
|
||||||
|
|
||||||
var statusCodeConvTab = map[codes.Code]http2.ErrCode{
|
|
||||||
codes.Internal: http2.ErrCodeInternal, // pick an arbitrary one which is matched.
|
|
||||||
codes.Canceled: http2.ErrCodeCancel,
|
|
||||||
codes.Unavailable: http2.ErrCodeRefusedStream,
|
|
||||||
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
|
|
||||||
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Records the states during HPACK decoding. Must be reset once the
|
// Records the states during HPACK decoding. Must be reset once the
|
||||||
// decoding of the entire headers are finished.
|
// decoding of the entire headers are finished.
|
||||||
type decodeState struct {
|
type decodeState struct {
|
||||||
@ -139,12 +137,8 @@ func isReservedHeader(hdr string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHPACKDecoder(mdata map[string][]string) *hpackDecoder {
|
func newHPACKDecoder() *hpackDecoder {
|
||||||
d := &hpackDecoder{}
|
d := &hpackDecoder{}
|
||||||
for k, v := range mdata {
|
|
||||||
d.mdata = make(map[string][]string)
|
|
||||||
d.mdata[k] = v
|
|
||||||
}
|
|
||||||
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
|
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
|
||||||
switch f.Name {
|
switch f.Name {
|
||||||
case "grpc-status":
|
case "grpc-status":
|
||||||
|
@ -308,7 +308,7 @@ const (
|
|||||||
|
|
||||||
// NewServerTransport creates a ServerTransport with conn or non-nil error
|
// NewServerTransport creates a ServerTransport with conn or non-nil error
|
||||||
// if it fails.
|
// if it fails.
|
||||||
func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo map[string][]string) (ServerTransport, error) {
|
func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (ServerTransport, error) {
|
||||||
return newHTTP2Server(conn, maxStreams, authInfo)
|
return newHTTP2Server(conn, maxStreams, authInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user