allow access of some info of client certificate

This commit is contained in:
iamqizhao
2015-08-21 15:49:53 -07:00
parent 69288679b3
commit d12ff72146
6 changed files with 44 additions and 17 deletions

View File

@ -82,15 +82,24 @@ type ProtocolInfo struct {
// protocols and supported transport security protocols (e.g., TLS, SSL). // protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportAuthenticator interface { type TransportAuthenticator interface {
// ClientHandshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. // authentication protocol on rawConn for clients. It returns the authenticated
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) // connection and the corresponding auth information about the connection.
// ServerHandshake does the authentication handshake for servers. ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, map[string][]string, error)
ServerHandshake(rawConn net.Conn) (net.Conn, error) // ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error)
// Info provides the ProtocolInfo of this TransportAuthenticator. // Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo Info() ProtocolInfo
Credentials Credentials
} }
const (
transportSecurityType = "transport_security_type"
x509CN = "x509_common_name"
x509SAN = "x509_suject_alternative_name"
)
// tlsCreds is the credentials required for authenticating a connection using TLS. // tlsCreds is the credentials required for authenticating a connection using TLS.
type tlsCreds struct { type tlsCreds struct {
// TLS configuration // TLS configuration
@ -116,7 +125,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true } func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true } func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ map[string][]string, err error) {
// borrow some code from tls.DialWithDialer // borrow some code from tls.DialWithDialer
var errChannel chan error var errChannel chan error
if timeout != 0 { if timeout != 0 {
@ -143,18 +152,32 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
} }
if err != nil { if err != nil {
rawConn.Close() rawConn.Close()
return nil, err return nil, nil, err
} }
return conn, nil // TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
} }
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) { func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) {
conn := tls.Server(rawConn, &c.config) conn := tls.Server(rawConn, &c.config)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
rawConn.Close() rawConn.Close()
return nil, err return nil, nil, err
} }
return conn, nil state := conn.ConnectionState()
info := make(map[string][]string)
info[transportSecurityType] = []string{"tls"}
for _, certs := range state.VerifiedChains {
fmt.Println("DEBUG: reach here")
for _, cert := range certs {
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
for _, san := range cert.DNSNames {
info[x509SAN] = append(info[x509SAN], san)
}
}
}
return conn, info, nil
} }
// NewTLS uses c to construct a TransportAuthenticator based on TLS. // NewTLS uses c to construct a TransportAuthenticator based on TLS.

View File

@ -566,7 +566,7 @@ func main() {
doPerRPCCreds(tc) doPerRPCCreds(tc)
case "oauth2_auth_token": case "oauth2_auth_token":
if !*useTLS { if !*useTLS {
grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_token_creds test case.") grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_auth_token test case.")
} }
doOauth2TokenCreds(tc) doOauth2TokenCreds(tc)
case "cancel_after_begin": case "cancel_after_begin":

View File

@ -199,8 +199,9 @@ func (s *Server) Serve(lis net.Listener) error {
if err != nil { if err != nil {
return err return err
} }
var authInfo map[string][]string
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
c, err = creds.ServerHandshake(c) c, authInfo, err = creds.ServerHandshake(c)
if err != nil { if err != nil {
grpclog.Println("grpc: Server.Serve failed to complete security handshake.") grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
continue continue
@ -212,7 +213,7 @@ func (s *Server) Serve(lis net.Listener) error {
c.Close() c.Close()
return nil return nil
} }
st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams) st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
c.Close() c.Close()

View File

@ -124,7 +124,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 { if timeout > 0 {
timeout -= time.Since(startT) timeout -= time.Since(startT)
} }
conn, connErr = ccreds.ClientHandshake(addr, conn, timeout) conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break break
} }
} }

View File

@ -58,6 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
type http2Server struct { type http2Server struct {
conn net.Conn conn net.Conn
maxStreamID uint32 // max stream ID ever seen maxStreamID uint32 // max stream ID ever seen
authInfo map[string][]string // basic auth info about the connection
// writableChan synchronizes write access to the transport. // writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan // A writer acquires the write lock by sending a value on writableChan
// and releases it by receiving from writableChan. // and releases it by receiving from writableChan.
@ -88,7 +89,7 @@ type http2Server struct {
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
// returned if something goes wrong. // returned if something goes wrong.
func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err error) { func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo map[string][]string) (_ ServerTransport, err error) {
framer := newFramer(conn) framer := newFramer(conn)
// Send initial settings as connection preface to client. // Send initial settings as connection preface to client.
var settings []http2.Setting var settings []http2.Setting
@ -114,6 +115,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err er
var buf bytes.Buffer var buf bytes.Buffer
t := &http2Server{ t := &http2Server{
conn: conn, conn: conn,
authInfo: authInfo,
framer: framer, framer: framer,
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
@ -235,6 +237,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.handleSettings(sf) t.handleSettings(sf)
hDec := newHPACKDecoder() hDec := newHPACKDecoder()
hDec.state.mdata = t.authInfo
var curStream *Stream var curStream *Stream
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()

View File

@ -308,8 +308,8 @@ const (
// NewServerTransport creates a ServerTransport with conn or non-nil error // NewServerTransport creates a ServerTransport with conn or non-nil error
// if it fails. // if it fails.
func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (ServerTransport, error) { func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo map[string][]string) (ServerTransport, error) {
return newHTTP2Server(conn, maxStreams) return newHTTP2Server(conn, maxStreams, authInfo)
} }
// ConnectOptions covers all relevant options for dialing a server. // ConnectOptions covers all relevant options for dialing a server.