diff --git a/clientconn.go b/clientconn.go index 885eedd9..9b9c78d0 100644 --- a/clientconn.go +++ b/clientconn.go @@ -61,10 +61,13 @@ var ( // being set for ClientConn. Users should either set one or explicitly // call WithInsecure DialOption to disable security. errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") - // errCredentialsMisuse indicates that users want to transmit security information - // (e.g., oauth2 token) which requires secure connection on an insecure + // errTransportCredentialsMissing indicates that users want to transmit security + // information (e.g., oauth2 token) which requires secure connection on an insecure // connection. - errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") + errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)") + // errCredentialsConflict indicates that grpc.WithTransportCredentials() + // and grpc.WithInsecure() are both called for a connection. + errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") // errNetworkIP indicates that the connection is down due to some network I/O error. errNetworkIO = errors.New("grpc: failed with network I/O error") // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. @@ -170,17 +173,17 @@ func WithInsecure() DialOption { // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). -func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { +func WithTransportCredentials(creds credentials.TransportCredentials) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.TransportCredentials = creds } } // WithPerRPCCredentials returns a DialOption which sets // credentials which will place auth state on each outbound RPC. -func WithPerRPCCredentials(creds credentials.Credentials) DialOption { +func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds) } } @@ -369,19 +372,16 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } if !ac.dopts.insecure { - var ok bool - for _, cd := range ac.dopts.copts.AuthOptions { - if _, ok = cd.(credentials.TransportAuthenticator); ok { - break - } - } - if !ok { + if ac.dopts.copts.TransportCredentials == nil { return errNoTransportSecurity } } else { - for _, cd := range ac.dopts.copts.AuthOptions { + if ac.dopts.copts.TransportCredentials != nil { + return errCredentialsConflict + } + for _, cd := range ac.dopts.copts.PerRPCCredentials { if cd.RequireTransportSecurity() { - return errCredentialsMisuse + return errTransportCredentialsMissing } } } diff --git a/clientconn_test.go b/clientconn_test.go index d60a3aee..29db8bfc 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -38,6 +38,7 @@ import ( "time" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/oauth" ) const tlsDir = "testdata/" @@ -67,17 +68,21 @@ func TestTLSDialTimeout(t *testing.T) { } func TestCredentialsMisuse(t *testing.T) { - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + if err != nil { + t.Fatalf("Failed to create authenticator %v", err) + } + // Two conflicting credential configurations + if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict { + t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) + } + rpcCreds, err := oauth.NewJWTAccessFromKey(nil) if err != nil { t.Fatalf("Failed to create credentials %v", err) } - // Two conflicting credential configurations - if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { - t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) - } // security info on insecure connection - if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { - t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) + if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { + t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) } } diff --git a/credentials/credentials.go b/credentials/credentials.go index 681f64e4..4481a26e 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -54,9 +54,9 @@ var ( alpnProtoStr = []string{"h2"} ) -// Credentials defines the common interface all supported credentials must -// implement. -type Credentials interface { +// PerRPCCredentials defines the common interface for the credentials which need to +// attach security information to every RPC (e.g., oauth2). +type PerRPCCredentials interface { // GetRequestMetadata gets the current request metadata, refreshing // tokens if required. This should be called by the transport layer on // each request, and the data should be populated in headers or other @@ -87,9 +87,9 @@ type AuthInfo interface { AuthType() string } -// TransportAuthenticator defines the common interface for all the live gRPC wire +// TransportCredentials defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). -type TransportAuthenticator interface { +type TransportCredentials interface { // ClientHandshake does the authentication handshake specified by the corresponding // authentication protocol on rawConn for clients. It returns the authenticated // connection and the corresponding auth information about the connection. @@ -98,9 +98,8 @@ type TransportAuthenticator interface { // the authenticated connection and the corresponding auth information about // the connection. ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) - // Info provides the ProtocolInfo of this TransportAuthenticator. + // Info provides the ProtocolInfo of this TransportCredentials. Info() ProtocolInfo - Credentials } // TLSInfo contains the auth information for a TLS authenticated connection. @@ -109,6 +108,7 @@ type TLSInfo struct { State tls.ConnectionState } +// AuthType returns the type of TLSInfo as a string. func (t TLSInfo) AuthType() string { return "tls" } @@ -185,20 +185,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) return conn, TLSInfo{conn.ConnectionState()}, nil } -// NewTLS uses c to construct a TransportAuthenticator based on TLS. -func NewTLS(c *tls.Config) TransportAuthenticator { +// NewTLS uses c to construct a TransportCredentials based on TLS. +func NewTLS(c *tls.Config) TransportCredentials { tc := &tlsCreds{*c} tc.config.NextProtos = alpnProtoStr return tc } // NewClientTLSFromCert constructs a TLS from the input certificate for client. -func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { +func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials { return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. -func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) { +func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) { b, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -211,13 +211,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, } // NewServerTLSFromCert constructs a TLS from the input certificate for server. -func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { +func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials { return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}}) } // NewServerTLSFromFile constructs a TLS from the input certificate file and key // file for server. -func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) { +func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index 04943fdf..d54a3158 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -45,7 +45,7 @@ import ( "google.golang.org/grpc/credentials" ) -// TokenSource supplies credentials from an oauth2.TokenSource. +// TokenSource supplies PerRPCCredentials from an oauth2.TokenSource. type TokenSource struct { oauth2.TokenSource } @@ -61,6 +61,7 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma }, nil } +// RequireTransportSecurity indicates whether the credentails requires transport security. func (ts TokenSource) RequireTransportSecurity() bool { return true } @@ -69,7 +70,8 @@ type jwtAccess struct { jsonKey []byte } -func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { +// NewJWTAccessFromFile creates PerRPCCredentials from the given keyFile. +func NewJWTAccessFromFile(keyFile string) (credentials.PerRPCCredentials, error) { jsonKey, err := ioutil.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) @@ -77,7 +79,8 @@ func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { return NewJWTAccessFromKey(jsonKey) } -func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) { +// NewJWTAccessFromKey creates PerRPCCredentials from the given jsonKey. +func NewJWTAccessFromKey(jsonKey []byte) (credentials.PerRPCCredentials, error) { return jwtAccess{jsonKey}, nil } @@ -99,13 +102,13 @@ func (j jwtAccess) RequireTransportSecurity() bool { return true } -// oauthAccess supplies credentials from a given token. +// oauthAccess supplies PerRPCCredentials from a given token. type oauthAccess struct { token oauth2.Token } -// NewOauthAccess constructs the credentials using a given token. -func NewOauthAccess(token *oauth2.Token) credentials.Credentials { +// NewOauthAccess constructs the PerRPCCredentials using a given token. +func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials { return oauthAccess{token: *token} } @@ -119,15 +122,15 @@ func (oa oauthAccess) RequireTransportSecurity() bool { return true } -// NewComputeEngine constructs the credentials that fetches access tokens from +// NewComputeEngine constructs the PerRPCCredentials that fetches access tokens from // Google Compute Engine (GCE)'s metadata server. It is only valid to use this // if your program is running on a GCE instance. // TODO(dsymonds): Deprecate and remove this. -func NewComputeEngine() credentials.Credentials { +func NewComputeEngine() credentials.PerRPCCredentials { return TokenSource{google.ComputeTokenSource("")} } -// serviceAccount represents credentials via JWT signing key. +// serviceAccount represents PerRPCCredentials via JWT signing key. type serviceAccount struct { config *jwt.Config } @@ -146,9 +149,9 @@ func (s serviceAccount) RequireTransportSecurity() bool { return true } -// NewServiceAccountFromKey constructs the credentials using the JSON key slice +// NewServiceAccountFromKey constructs the PerRPCCredentials using the JSON key slice // from a Google Developers service account. -func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Credentials, error) { +func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.PerRPCCredentials, error) { config, err := google.JWTConfigFromJSON(jsonKey, scope...) if err != nil { return nil, err @@ -156,9 +159,9 @@ func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Cred return serviceAccount{config: config}, nil } -// NewServiceAccountFromFile constructs the credentials using the JSON key file +// NewServiceAccountFromFile constructs the PerRPCCredentials using the JSON key file // of a Google Developers service account. -func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Credentials, error) { +func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.PerRPCCredentials, error) { jsonKey, err := ioutil.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) @@ -168,7 +171,7 @@ func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Cre // NewApplicationDefault returns "Application Default Credentials". For more // detail, see https://developers.google.com/accounts/docs/application-default-credentials. -func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.Credentials, error) { +func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.PerRPCCredentials, error) { t, err := google.DefaultTokenSource(ctx, scope...) if err != nil { return nil, err diff --git a/examples/route_guide/client/client.go b/examples/route_guide/client/client.go index a96c0302..f84352c8 100644 --- a/examples/route_guide/client/client.go +++ b/examples/route_guide/client/client.go @@ -164,7 +164,7 @@ func main() { if *serverHostOverride != "" { sn = *serverHostOverride } - var creds credentials.TransportAuthenticator + var creds credentials.TransportCredentials if *caFile != "" { var err error creds, err = credentials.NewClientTLSFromFile(*caFile, sn) diff --git a/interop/client/client.go b/interop/client/client.go index e05854c7..98f6cfec 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -85,7 +85,7 @@ func main() { if *tlsServerName != "" { sn = *tlsServerName } - var creds credentials.TransportAuthenticator + var creds credentials.TransportCredentials if *testCA { var err error creds, err = credentials.NewClientTLSFromFile(testCAFile, sn) diff --git a/server.go b/server.go index 440fe249..1b5ac050 100644 --- a/server.go +++ b/server.go @@ -96,7 +96,7 @@ type Server struct { } type options struct { - creds credentials.Credentials + creds credentials.TransportCredentials codec Codec cp Compressor dc Decompressor @@ -139,7 +139,7 @@ func MaxConcurrentStreams(n uint32) ServerOption { } // Creds returns a ServerOption that sets credentials for server connections. -func Creds(c credentials.Credentials) ServerOption { +func Creds(c credentials.TransportCredentials) ServerOption { return func(o *options) { o.creds = c } @@ -250,11 +250,10 @@ var ( ) func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - creds, ok := s.opts.creds.(credentials.TransportAuthenticator) - if !ok { + if s.opts.creds == nil { return rawConn, nil, nil } - return creds.ServerHandshake(rawConn) + return s.opts.creds.ServerHandshake(rawConn) } // Serve accepts incoming connections on the listener lis, creating a new diff --git a/transport/http2_client.go b/transport/http2_client.go index e624f8da..227686d4 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -88,7 +88,7 @@ type http2Client struct { // The scheme used: https if TLS is on, http otherwise. scheme string - authCreds []credentials.Credentials + creds []credentials.PerRPCCredentials mu sync.Mutex // guard the following variables state transportState // the state of underlying connection @@ -117,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return nil, ConnectionErrorf("transport: %v", connErr) } var authInfo credentials.AuthInfo - for _, c := range opts.AuthOptions { - if ccreds, ok := c.(credentials.TransportAuthenticator); ok { - scheme = "https" - // TODO(zhaoq): Now the first TransportAuthenticator is used if there are - // multiple ones provided. Revisit this if it is not appropriate. Probably - // place the ClientTransport construction into a separate function to make - // things clear. - if timeout > 0 { - timeout -= time.Since(startT) - } - conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) - break + if opts.TransportCredentials != nil { + scheme = "https" + if timeout > 0 { + timeout -= time.Since(startT) } + conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout) } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) @@ -163,7 +156,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e scheme: scheme, state: reachable, activeStreams: make(map[uint32]*Stream), - authCreds: opts.AuthOptions, + creds: opts.PerRPCCredentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, } @@ -248,7 +241,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } ctx = peer.NewContext(ctx, pr) authData := make(map[string]string) - for _, c := range t.authCreds { + for _, c := range t.creds { // Construct URI required to get auth request metadata. var port string if pos := strings.LastIndex(t.target, ":"); pos != -1 { diff --git a/transport/transport.go b/transport/transport.go index 1e9d0c01..d4c220a0 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -336,8 +336,10 @@ type ConnectOptions struct { UserAgent string // Dialer specifies how to dial a network address. Dialer func(string, time.Duration) (net.Conn, error) - // AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. - AuthOptions []credentials.Credentials + // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. + PerRPCCredentials []credentials.PerRPCCredentials + // TransportCredentials stores the Authenticator required to setup a client connection. + TransportCredentials credentials.TransportCredentials // Timeout specifies the timeout for dialing a ClientTransport. Timeout time.Duration } @@ -473,7 +475,7 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } -// Define some common ConnectionErrors. +// ErrConnClosing indicates that the transport is closing. var ErrConnClosing = ConnectionError{Desc: "transport is closing"} // StreamError is an error that only affects one stream within a connection.