diff --git a/credentials/credentials.go b/credentials/credentials.go index 7f59ca18..5f103640 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -78,27 +78,47 @@ type ProtocolInfo struct { SecurityVersion string } +// AuthInfo defines the common interface for the auth information the users are interested in. +type AuthInfo interface { + Type() string +} + +type authInfoKey struct{} + +// NewContext creates a new context with authInfo attached. +func NewContext(ctx context.Context, authInfo AuthInfo) context.Context { + return context.WithValue(ctx, authInfoKey{}, authInfo) +} + +// FromContext returns the authInfo in ctx if it exists. +func FromContext(ctx context.Context) (authInfo AuthInfo, ok bool) { + authInfo, ok = ctx.Value(authInfoKey{}).(AuthInfo) + return +} + // TransportAuthenticator defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). type TransportAuthenticator interface { // ClientHandshake does the authentication handshake specified by the corresponding // authentication protocol on rawConn for clients. It returns the authenticated // connection and the corresponding auth information about the connection. - ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, map[string][]string, error) + ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) // ServerHandshake does the authentication handshake for servers. It returns // the authenticated connection and the corresponding auth information about // the connection. - ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) + ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) // Info provides the ProtocolInfo of this TransportAuthenticator. Info() ProtocolInfo Credentials } -const ( - transportSecurityType = "transport_security_type" - x509CN = "x509_common_name" - x509SAN = "x509_suject_alternative_name" -) +type TLSInfo struct { + state tls.ConnectionState +} + +func (t TLSInfo) Type() string { + return "tls" +} // tlsCreds is the credentials required for authenticating a connection using TLS. type tlsCreds struct { @@ -106,7 +126,7 @@ type tlsCreds struct { config tls.Config } -func (c *tlsCreds) Info() ProtocolInfo { +func (c tlsCreds) Info() ProtocolInfo { return ProtocolInfo{ SecurityProtocol: "tls", SecurityVersion: "1.2", @@ -125,7 +145,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" } func (timeoutError) Timeout() bool { return true } func (timeoutError) Temporary() bool { return true } -func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ map[string][]string, err error) { +func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) { // borrow some code from tls.DialWithDialer var errChannel chan error if timeout != 0 { @@ -159,23 +179,13 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D return conn, nil, nil } -func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) { +func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { conn := tls.Server(rawConn, &c.config) if err := conn.Handshake(); err != nil { rawConn.Close() return nil, nil, err } - info := make(map[string][]string) - info[transportSecurityType] = []string{"tls"} - for _, certs := range conn.ConnectionState().VerifiedChains { - for _, cert := range certs { - info[x509CN] = append(info[x509CN], cert.Subject.CommonName) - for _, san := range cert.DNSNames { - info[x509SAN] = append(info[x509SAN], san) - } - } - } - return conn, info, nil + return conn, &TLSInfo{ conn.ConnectionState() }, nil } // NewTLS uses c to construct a TransportAuthenticator based on TLS. diff --git a/server.go b/server.go index e2c40537..feb29887 100644 --- a/server.go +++ b/server.go @@ -199,7 +199,7 @@ func (s *Server) Serve(lis net.Listener) error { if err != nil { return err } - var authInfo map[string][]string + var authInfo credentials.AuthInfo if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { c, authInfo, err = creds.ServerHandshake(c) if err != nil { diff --git a/transport/http2_client.go b/transport/http2_client.go index 3e60058c..a2966640 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -56,7 +56,7 @@ type http2Client struct { target string // server name/addr userAgent string conn net.Conn // underlying communication channel - authInfo map[string][]string // auth info about the connection + authInfo credentials.AuthInfo // auth info about the connection nextID uint32 // the next stream ID to be used // writableChan synchronizes write access to the transport. @@ -115,7 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } - var authInfo map[string][]string + var authInfo credentials.AuthInfo for _, c := range opts.AuthOptions { if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https" @@ -237,6 +237,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ContextErr(context.DeadlineExceeded) } } + // Attach Auth info if there is any. + if t.authInfo != nil { + ctx = credentials.NewContext(ctx, t.authInfo) + } authData := make(map[string]string) for _, c := range t.authCreds { data, err := c.GetRequestMetadata(ctx) @@ -704,7 +708,7 @@ func (t *http2Client) reader() { } t.handleSettings(sf) - hDec := newHPACKDecoder(t.authInfo) + hDec := newHPACKDecoder() var curStream *Stream // loop to keep reading incoming messages on this transport. for { diff --git a/transport/http2_server.go b/transport/http2_server.go index 5d1bff3b..8856d7f4 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -46,6 +46,7 @@ import ( "github.com/bradfitz/http2/hpack" "golang.org/x/net/context" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" ) @@ -57,8 +58,8 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { conn net.Conn - maxStreamID uint32 // max stream ID ever seen - authInfo map[string][]string // auth info about the connection + maxStreamID uint32 // max stream ID ever seen + authInfo credentials.AuthInfo // auth info about the connection // writableChan synchronizes write access to the transport. // A writer acquires the write lock by sending a value on writableChan // and releases it by receiving from writableChan. @@ -89,7 +90,7 @@ type http2Server struct { // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // returned if something goes wrong. -func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo map[string][]string) (_ ServerTransport, err error) { +func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) { framer := newFramer(conn) // Send initial settings as connection preface to client. var settings []http2.Setting @@ -183,6 +184,10 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header } else { s.ctx, s.cancel = context.WithCancel(context.TODO()) } + // Attach Auth info if there is any. + if t.authInfo != nil { + s.ctx = credentials.NewContext(s.ctx, t.authInfo) + } // Cache the current stream to the context so that the server application // can find out. Required when the server wants to send some metadata // back to the client (unary call only). @@ -236,7 +241,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } t.handleSettings(sf) - hDec := newHPACKDecoder(t.authInfo) + hDec := newHPACKDecoder() var curStream *Stream var wg sync.WaitGroup defer wg.Wait() diff --git a/transport/http_util.go b/transport/http_util.go index 5fe65e37..c442da9d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -63,31 +63,29 @@ const ( var ( clientPreface = []byte(http2.ClientPreface) + http2RSTErrConvTab = map[http2.ErrCode]codes.Code{ + http2.ErrCodeNo: codes.Internal, + http2.ErrCodeProtocol: codes.Internal, + http2.ErrCodeInternal: codes.Internal, + http2.ErrCodeFlowControl: codes.ResourceExhausted, + http2.ErrCodeSettingsTimeout: codes.Internal, + http2.ErrCodeFrameSize: codes.Internal, + http2.ErrCodeRefusedStream: codes.Unavailable, + http2.ErrCodeCancel: codes.Canceled, + http2.ErrCodeCompression: codes.Internal, + http2.ErrCodeConnect: codes.Internal, + http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, + http2.ErrCodeInadequateSecurity: codes.PermissionDenied, + } + statusCodeConvTab = map[codes.Code]http2.ErrCode{ + codes.Internal: http2.ErrCodeInternal, + codes.Canceled: http2.ErrCodeCancel, + codes.Unavailable: http2.ErrCodeRefusedStream, + codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, + codes.PermissionDenied: http2.ErrCodeInadequateSecurity, + } ) -var http2RSTErrConvTab = map[http2.ErrCode]codes.Code{ - http2.ErrCodeNo: codes.Internal, - http2.ErrCodeProtocol: codes.Internal, - http2.ErrCodeInternal: codes.Internal, - http2.ErrCodeFlowControl: codes.ResourceExhausted, - http2.ErrCodeSettingsTimeout: codes.Internal, - http2.ErrCodeFrameSize: codes.Internal, - http2.ErrCodeRefusedStream: codes.Unavailable, - http2.ErrCodeCancel: codes.Canceled, - http2.ErrCodeCompression: codes.Internal, - http2.ErrCodeConnect: codes.Internal, - http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, - http2.ErrCodeInadequateSecurity: codes.PermissionDenied, -} - -var statusCodeConvTab = map[codes.Code]http2.ErrCode{ - codes.Internal: http2.ErrCodeInternal, // pick an arbitrary one which is matched. - codes.Canceled: http2.ErrCodeCancel, - codes.Unavailable: http2.ErrCodeRefusedStream, - codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, - codes.PermissionDenied: http2.ErrCodeInadequateSecurity, -} - // Records the states during HPACK decoding. Must be reset once the // decoding of the entire headers are finished. type decodeState struct { @@ -139,12 +137,8 @@ func isReservedHeader(hdr string) bool { } } -func newHPACKDecoder(mdata map[string][]string) *hpackDecoder { +func newHPACKDecoder() *hpackDecoder { d := &hpackDecoder{} - for k, v := range mdata { - d.mdata = make(map[string][]string) - d.mdata[k] = v - } d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) { switch f.Name { case "grpc-status": diff --git a/transport/transport.go b/transport/transport.go index beaf2b34..2dd38a83 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -308,7 +308,7 @@ const ( // NewServerTransport creates a ServerTransport with conn or non-nil error // if it fails. -func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo map[string][]string) (ServerTransport, error) { +func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (ServerTransport, error) { return newHTTP2Server(conn, maxStreams, authInfo) }