From 6404c491926de01531a12d8ff99217853bfca37d Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 6 Jun 2016 16:24:46 -0700 Subject: [PATCH] Make TransportAuthenticator not embed Credentials --- clientconn.go | 19 ++++++++----------- clientconn_test.go | 11 ++++++++--- credentials/credentials.go | 2 +- server.go | 11 +++++------ transport/http2_client.go | 28 +++++++++++++--------------- transport/transport.go | 8 +++++--- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/clientconn.go b/clientconn.go index c38110da..688180cc 100644 --- a/clientconn.go +++ b/clientconn.go @@ -170,9 +170,9 @@ func WithInsecure() DialOption { // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). -func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { +func WithTransportCredentials(auth credentials.TransportAuthenticator) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.Authenticators = append(o.copts.Authenticators, auth) } } @@ -180,7 +180,7 @@ func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOpti // credentials which will place auth state on each outbound RPC. func WithPerRPCCredentials(creds credentials.Credentials) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.Credentials = append(o.copts.Credentials, creds) } } @@ -369,17 +369,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } if !ac.dopts.insecure { - var ok bool - for _, cd := range ac.dopts.copts.AuthOptions { - if _, ok = cd.(credentials.TransportAuthenticator); ok { - break - } - } - if !ok { + if len(ac.dopts.copts.Authenticators) == 0 { return errNoTransportSecurity } } else { - for _, cd := range ac.dopts.copts.AuthOptions { + if len(ac.dopts.copts.Authenticators) > 0 { + return errCredentialsMisuse + } + for _, cd := range ac.dopts.copts.Credentials { if cd.RequireTransportSecurity() { return errCredentialsMisuse } diff --git a/clientconn_test.go b/clientconn_test.go index d60a3aee..9e113437 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -38,6 +38,7 @@ import ( "time" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/oauth" ) const tlsDir = "testdata/" @@ -67,14 +68,18 @@ func TestTLSDialTimeout(t *testing.T) { } func TestCredentialsMisuse(t *testing.T) { - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + auth, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { - t.Fatalf("Failed to create credentials %v", err) + t.Fatalf("Failed to create authenticator %v", err) } // Two conflicting credential configurations - if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { + if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(auth), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) } + creds, err := oauth.NewJWTAccessFromKey(nil) + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } // security info on insecure connection if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) diff --git a/credentials/credentials.go b/credentials/credentials.go index 681f64e4..e930fec5 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -100,7 +100,6 @@ type TransportAuthenticator interface { ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) // Info provides the ProtocolInfo of this TransportAuthenticator. Info() ProtocolInfo - Credentials } // TLSInfo contains the auth information for a TLS authenticated connection. @@ -109,6 +108,7 @@ type TLSInfo struct { State tls.ConnectionState } +// AuthType returns the type of TLSInfo as a string. func (t TLSInfo) AuthType() string { return "tls" } diff --git a/server.go b/server.go index bfb9c606..262f2e41 100644 --- a/server.go +++ b/server.go @@ -95,7 +95,7 @@ type Server struct { } type options struct { - creds credentials.Credentials + auth credentials.TransportAuthenticator codec Codec cp Compressor dc Decompressor @@ -138,9 +138,9 @@ func MaxConcurrentStreams(n uint32) ServerOption { } // Creds returns a ServerOption that sets credentials for server connections. -func Creds(c credentials.Credentials) ServerOption { +func Creds(c credentials.TransportAuthenticator) ServerOption { return func(o *options) { - o.creds = c + o.auth = c } } @@ -249,11 +249,10 @@ var ( ) func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - creds, ok := s.opts.creds.(credentials.TransportAuthenticator) - if !ok { + if s.opts.auth == nil { return rawConn, nil, nil } - return creds.ServerHandshake(rawConn) + return s.opts.auth.ServerHandshake(rawConn) } // Serve accepts incoming connections on the listener lis, creating a new diff --git a/transport/http2_client.go b/transport/http2_client.go index e624f8da..23ae7897 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -88,7 +88,7 @@ type http2Client struct { // The scheme used: https if TLS is on, http otherwise. scheme string - authCreds []credentials.Credentials + creds []credentials.Credentials mu sync.Mutex // guard the following variables state transportState // the state of underlying connection @@ -117,19 +117,17 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return nil, ConnectionErrorf("transport: %v", connErr) } var authInfo credentials.AuthInfo - for _, c := range opts.AuthOptions { - if ccreds, ok := c.(credentials.TransportAuthenticator); ok { - scheme = "https" - // TODO(zhaoq): Now the first TransportAuthenticator is used if there are - // multiple ones provided. Revisit this if it is not appropriate. Probably - // place the ClientTransport construction into a separate function to make - // things clear. - if timeout > 0 { - timeout -= time.Since(startT) - } - conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) - break + for _, auth := range opts.Authenticators { + scheme = "https" + // TODO(zhaoq): Now the first TransportAuthenticator is used if there are + // multiple ones provided. Revisit this if it is not appropriate. Probably + // place the ClientTransport construction into a separate function to make + // things clear. + if timeout > 0 { + timeout -= time.Since(startT) } + conn, authInfo, connErr = auth.ClientHandshake(addr, conn, timeout) + break } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) @@ -163,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e scheme: scheme, state: reachable, activeStreams: make(map[uint32]*Stream), - authCreds: opts.AuthOptions, + creds: opts.Credentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, } @@ -248,7 +246,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } ctx = peer.NewContext(ctx, pr) authData := make(map[string]string) - for _, c := range t.authCreds { + for _, c := range t.creds { // Construct URI required to get auth request metadata. var port string if pos := strings.LastIndex(t.target, ":"); pos != -1 { diff --git a/transport/transport.go b/transport/transport.go index 1e9d0c01..8ec537d4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -336,8 +336,10 @@ type ConnectOptions struct { UserAgent string // Dialer specifies how to dial a network address. Dialer func(string, time.Duration) (net.Conn, error) - // AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. - AuthOptions []credentials.Credentials + // Credentials stores the credentials required to issue RPCs. + Credentials []credentials.Credentials + // Authenticators stores the Authenticators required to setup a client connection. + Authenticators []credentials.TransportAuthenticator // Timeout specifies the timeout for dialing a ClientTransport. Timeout time.Duration } @@ -473,7 +475,7 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } -// Define some common ConnectionErrors. +// ErrConnClosing indicates that the transport is closing. var ErrConnClosing = ConnectionError{Desc: "transport is closing"} // StreamError is an error that only affects one stream within a connection.