diff --git a/credentials/credentials.go b/credentials/credentials.go index c1a331e8..11560dec 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -82,15 +82,24 @@ type ProtocolInfo struct { // protocols and supported transport security protocols (e.g., TLS, SSL). type TransportAuthenticator interface { // ClientHandshake does the authentication handshake specified by the corresponding - // authentication protocol on rawConn for clients. - ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) - // ServerHandshake does the authentication handshake for servers. - ServerHandshake(rawConn net.Conn) (net.Conn, error) + // authentication protocol on rawConn for clients. It returns the authenticated + // connection and the corresponding auth information about the connection. + ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, map[string][]string, error) + // ServerHandshake does the authentication handshake for servers. It returns + // the authenticated connection and the corresponding auth information about + // the connection. + ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) // Info provides the ProtocolInfo of this TransportAuthenticator. Info() ProtocolInfo Credentials } +const ( + transportSecurityType = "transport_security_type" + x509CN = "x509_common_name" + x509SAN = "x509_suject_alternative_name" +) + // tlsCreds is the credentials required for authenticating a connection using TLS. type tlsCreds struct { // TLS configuration @@ -116,7 +125,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" } func (timeoutError) Timeout() bool { return true } func (timeoutError) Temporary() bool { return true } -func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { +func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ map[string][]string, err error) { // borrow some code from tls.DialWithDialer var errChannel chan error if timeout != 0 { @@ -143,18 +152,32 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D } if err != nil { rawConn.Close() - return nil, err + return nil, nil, err } - return conn, nil + // TODO(zhaoq): Omit the auth info for client now. It is more for + // information than anything else. + return conn, nil, nil } -func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) { +func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) { conn := tls.Server(rawConn, &c.config) if err := conn.Handshake(); err != nil { rawConn.Close() - return nil, err + return nil, nil, err } - return conn, nil + state := conn.ConnectionState() + info := make(map[string][]string) + info[transportSecurityType] = []string{"tls"} + for _, certs := range state.VerifiedChains { + fmt.Println("DEBUG: reach here") + for _, cert := range certs { + info[x509CN] = append(info[x509CN], cert.Subject.CommonName) + for _, san := range cert.DNSNames { + info[x509SAN] = append(info[x509SAN], san) + } + } + } + return conn, info, nil } // NewTLS uses c to construct a TransportAuthenticator based on TLS. diff --git a/interop/client/client.go b/interop/client/client.go index 8b8ff9c8..9c5ace50 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -566,7 +566,7 @@ func main() { doPerRPCCreds(tc) case "oauth2_auth_token": if !*useTLS { - grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_token_creds test case.") + grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_auth_token test case.") } doOauth2TokenCreds(tc) case "cancel_after_begin": diff --git a/server.go b/server.go index 95d3c00d..e2c40537 100644 --- a/server.go +++ b/server.go @@ -199,8 +199,9 @@ func (s *Server) Serve(lis net.Listener) error { if err != nil { return err } + var authInfo map[string][]string if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { - c, err = creds.ServerHandshake(c) + c, authInfo, err = creds.ServerHandshake(c) if err != nil { grpclog.Println("grpc: Server.Serve failed to complete security handshake.") continue @@ -212,7 +213,7 @@ func (s *Server) Serve(lis net.Listener) error { c.Close() return nil } - st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams) + st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) if err != nil { s.mu.Unlock() c.Close() diff --git a/transport/http2_client.go b/transport/http2_client.go index e01d3f1c..8dd4b7d3 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -124,7 +124,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if timeout > 0 { timeout -= time.Since(startT) } - conn, connErr = ccreds.ClientHandshake(addr, conn, timeout) + conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout) break } } diff --git a/transport/http2_server.go b/transport/http2_server.go index 79e2b200..8df2aa24 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -58,6 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe type http2Server struct { conn net.Conn maxStreamID uint32 // max stream ID ever seen + authInfo map[string][]string // basic auth info about the connection // writableChan synchronizes write access to the transport. // A writer acquires the write lock by sending a value on writableChan // and releases it by receiving from writableChan. @@ -88,7 +89,7 @@ type http2Server struct { // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // returned if something goes wrong. -func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err error) { +func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo map[string][]string) (_ ServerTransport, err error) { framer := newFramer(conn) // Send initial settings as connection preface to client. var settings []http2.Setting @@ -114,6 +115,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err er var buf bytes.Buffer t := &http2Server{ conn: conn, + authInfo: authInfo, framer: framer, hBuf: &buf, hEnc: hpack.NewEncoder(&buf), @@ -235,6 +237,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { t.handleSettings(sf) hDec := newHPACKDecoder() + hDec.state.mdata = t.authInfo var curStream *Stream var wg sync.WaitGroup defer wg.Wait() diff --git a/transport/transport.go b/transport/transport.go index 58436f01..beaf2b34 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -308,8 +308,8 @@ const ( // NewServerTransport creates a ServerTransport with conn or non-nil error // if it fails. -func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (ServerTransport, error) { - return newHTTP2Server(conn, maxStreams) +func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo map[string][]string) (ServerTransport, error) { + return newHTTP2Server(conn, maxStreams, authInfo) } // ConnectOptions covers all relevant options for dialing a server.