Merge pull request #713 from menghanl/split_authenticator_and_credentials
[API revision] Separate TransportAuthenticator and PerRPCCredentials
This commit is contained in:
@ -61,10 +61,13 @@ var (
|
|||||||
// being set for ClientConn. Users should either set one or explicitly
|
// being set for ClientConn. Users should either set one or explicitly
|
||||||
// call WithInsecure DialOption to disable security.
|
// call WithInsecure DialOption to disable security.
|
||||||
errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
|
errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
|
||||||
// errCredentialsMisuse indicates that users want to transmit security information
|
// errTransportCredentialsMissing indicates that users want to transmit security
|
||||||
// (e.g., oauth2 token) which requires secure connection on an insecure
|
// information (e.g., oauth2 token) which requires secure connection on an insecure
|
||||||
// connection.
|
// connection.
|
||||||
errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
|
errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
|
||||||
|
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
|
||||||
|
// and grpc.WithInsecure() are both called for a connection.
|
||||||
|
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
|
||||||
// errNetworkIP indicates that the connection is down due to some network I/O error.
|
// errNetworkIP indicates that the connection is down due to some network I/O error.
|
||||||
errNetworkIO = errors.New("grpc: failed with network I/O error")
|
errNetworkIO = errors.New("grpc: failed with network I/O error")
|
||||||
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
|
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
|
||||||
@ -170,17 +173,17 @@ func WithInsecure() DialOption {
|
|||||||
|
|
||||||
// WithTransportCredentials returns a DialOption which configures a
|
// WithTransportCredentials returns a DialOption which configures a
|
||||||
// connection level security credentials (e.g., TLS/SSL).
|
// connection level security credentials (e.g., TLS/SSL).
|
||||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
o.copts.TransportCredentials = creds
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPerRPCCredentials returns a DialOption which sets
|
// WithPerRPCCredentials returns a DialOption which sets
|
||||||
// credentials which will place auth state on each outbound RPC.
|
// credentials which will place auth state on each outbound RPC.
|
||||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,19 +372,16 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
|||||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||||
}
|
}
|
||||||
if !ac.dopts.insecure {
|
if !ac.dopts.insecure {
|
||||||
var ok bool
|
if ac.dopts.copts.TransportCredentials == nil {
|
||||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
|
||||||
if _, ok = cd.(credentials.TransportAuthenticator); ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return errNoTransportSecurity
|
return errNoTransportSecurity
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
if ac.dopts.copts.TransportCredentials != nil {
|
||||||
|
return errCredentialsConflict
|
||||||
|
}
|
||||||
|
for _, cd := range ac.dopts.copts.PerRPCCredentials {
|
||||||
if cd.RequireTransportSecurity() {
|
if cd.RequireTransportSecurity() {
|
||||||
return errCredentialsMisuse
|
return errTransportCredentialsMissing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tlsDir = "testdata/"
|
const tlsDir = "testdata/"
|
||||||
@ -67,17 +68,21 @@ func TestTLSDialTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCredentialsMisuse(t *testing.T) {
|
func TestCredentialsMisuse(t *testing.T) {
|
||||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create authenticator %v", err)
|
||||||
|
}
|
||||||
|
// Two conflicting credential configurations
|
||||||
|
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict {
|
||||||
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
|
||||||
|
}
|
||||||
|
rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create credentials %v", err)
|
t.Fatalf("Failed to create credentials %v", err)
|
||||||
}
|
}
|
||||||
// Two conflicting credential configurations
|
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
|
||||||
}
|
|
||||||
// security info on insecure connection
|
// security info on insecure connection
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,9 +54,9 @@ var (
|
|||||||
alpnProtoStr = []string{"h2"}
|
alpnProtoStr = []string{"h2"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Credentials defines the common interface all supported credentials must
|
// PerRPCCredentials defines the common interface for the credentials which need to
|
||||||
// implement.
|
// attach security information to every RPC (e.g., oauth2).
|
||||||
type Credentials interface {
|
type PerRPCCredentials interface {
|
||||||
// GetRequestMetadata gets the current request metadata, refreshing
|
// GetRequestMetadata gets the current request metadata, refreshing
|
||||||
// tokens if required. This should be called by the transport layer on
|
// tokens if required. This should be called by the transport layer on
|
||||||
// each request, and the data should be populated in headers or other
|
// each request, and the data should be populated in headers or other
|
||||||
@ -87,9 +87,9 @@ type AuthInfo interface {
|
|||||||
AuthType() string
|
AuthType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransportAuthenticator defines the common interface for all the live gRPC wire
|
// TransportCredentials defines the common interface for all the live gRPC wire
|
||||||
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
||||||
type TransportAuthenticator interface {
|
type TransportCredentials interface {
|
||||||
// ClientHandshake does the authentication handshake specified by the corresponding
|
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||||
// authentication protocol on rawConn for clients. It returns the authenticated
|
// authentication protocol on rawConn for clients. It returns the authenticated
|
||||||
// connection and the corresponding auth information about the connection.
|
// connection and the corresponding auth information about the connection.
|
||||||
@ -98,9 +98,8 @@ type TransportAuthenticator interface {
|
|||||||
// the authenticated connection and the corresponding auth information about
|
// the authenticated connection and the corresponding auth information about
|
||||||
// the connection.
|
// the connection.
|
||||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
Credentials
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||||
@ -109,6 +108,7 @@ type TLSInfo struct {
|
|||||||
State tls.ConnectionState
|
State tls.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthType returns the type of TLSInfo as a string.
|
||||||
func (t TLSInfo) AuthType() string {
|
func (t TLSInfo) AuthType() string {
|
||||||
return "tls"
|
return "tls"
|
||||||
}
|
}
|
||||||
@ -185,20 +185,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
|||||||
return conn, TLSInfo{conn.ConnectionState()}, nil
|
return conn, TLSInfo{conn.ConnectionState()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
||||||
func NewTLS(c *tls.Config) TransportAuthenticator {
|
func NewTLS(c *tls.Config) TransportCredentials {
|
||||||
tc := &tlsCreds{*c}
|
tc := &tlsCreds{*c}
|
||||||
tc.config.NextProtos = alpnProtoStr
|
tc.config.NextProtos = alpnProtoStr
|
||||||
return tc
|
return tc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
|
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
|
||||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||||
func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) {
|
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
|
||||||
b, err := ioutil.ReadFile(certFile)
|
b, err := ioutil.ReadFile(certFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -211,13 +211,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
|
func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
|
||||||
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
||||||
// file for server.
|
// file for server.
|
||||||
func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) {
|
func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
|
||||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -45,7 +45,7 @@ import (
|
|||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenSource supplies credentials from an oauth2.TokenSource.
|
// TokenSource supplies PerRPCCredentials from an oauth2.TokenSource.
|
||||||
type TokenSource struct {
|
type TokenSource struct {
|
||||||
oauth2.TokenSource
|
oauth2.TokenSource
|
||||||
}
|
}
|
||||||
@ -61,6 +61,7 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequireTransportSecurity indicates whether the credentails requires transport security.
|
||||||
func (ts TokenSource) RequireTransportSecurity() bool {
|
func (ts TokenSource) RequireTransportSecurity() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -69,7 +70,8 @@ type jwtAccess struct {
|
|||||||
jsonKey []byte
|
jsonKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) {
|
// NewJWTAccessFromFile creates PerRPCCredentials from the given keyFile.
|
||||||
|
func NewJWTAccessFromFile(keyFile string) (credentials.PerRPCCredentials, error) {
|
||||||
jsonKey, err := ioutil.ReadFile(keyFile)
|
jsonKey, err := ioutil.ReadFile(keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
||||||
@ -77,7 +79,8 @@ func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) {
|
|||||||
return NewJWTAccessFromKey(jsonKey)
|
return NewJWTAccessFromKey(jsonKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) {
|
// NewJWTAccessFromKey creates PerRPCCredentials from the given jsonKey.
|
||||||
|
func NewJWTAccessFromKey(jsonKey []byte) (credentials.PerRPCCredentials, error) {
|
||||||
return jwtAccess{jsonKey}, nil
|
return jwtAccess{jsonKey}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,13 +102,13 @@ func (j jwtAccess) RequireTransportSecurity() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// oauthAccess supplies credentials from a given token.
|
// oauthAccess supplies PerRPCCredentials from a given token.
|
||||||
type oauthAccess struct {
|
type oauthAccess struct {
|
||||||
token oauth2.Token
|
token oauth2.Token
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOauthAccess constructs the credentials using a given token.
|
// NewOauthAccess constructs the PerRPCCredentials using a given token.
|
||||||
func NewOauthAccess(token *oauth2.Token) credentials.Credentials {
|
func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
|
||||||
return oauthAccess{token: *token}
|
return oauthAccess{token: *token}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,15 +122,15 @@ func (oa oauthAccess) RequireTransportSecurity() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewComputeEngine constructs the credentials that fetches access tokens from
|
// NewComputeEngine constructs the PerRPCCredentials that fetches access tokens from
|
||||||
// Google Compute Engine (GCE)'s metadata server. It is only valid to use this
|
// Google Compute Engine (GCE)'s metadata server. It is only valid to use this
|
||||||
// if your program is running on a GCE instance.
|
// if your program is running on a GCE instance.
|
||||||
// TODO(dsymonds): Deprecate and remove this.
|
// TODO(dsymonds): Deprecate and remove this.
|
||||||
func NewComputeEngine() credentials.Credentials {
|
func NewComputeEngine() credentials.PerRPCCredentials {
|
||||||
return TokenSource{google.ComputeTokenSource("")}
|
return TokenSource{google.ComputeTokenSource("")}
|
||||||
}
|
}
|
||||||
|
|
||||||
// serviceAccount represents credentials via JWT signing key.
|
// serviceAccount represents PerRPCCredentials via JWT signing key.
|
||||||
type serviceAccount struct {
|
type serviceAccount struct {
|
||||||
config *jwt.Config
|
config *jwt.Config
|
||||||
}
|
}
|
||||||
@ -146,9 +149,9 @@ func (s serviceAccount) RequireTransportSecurity() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServiceAccountFromKey constructs the credentials using the JSON key slice
|
// NewServiceAccountFromKey constructs the PerRPCCredentials using the JSON key slice
|
||||||
// from a Google Developers service account.
|
// from a Google Developers service account.
|
||||||
func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Credentials, error) {
|
func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.PerRPCCredentials, error) {
|
||||||
config, err := google.JWTConfigFromJSON(jsonKey, scope...)
|
config, err := google.JWTConfigFromJSON(jsonKey, scope...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -156,9 +159,9 @@ func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Cred
|
|||||||
return serviceAccount{config: config}, nil
|
return serviceAccount{config: config}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServiceAccountFromFile constructs the credentials using the JSON key file
|
// NewServiceAccountFromFile constructs the PerRPCCredentials using the JSON key file
|
||||||
// of a Google Developers service account.
|
// of a Google Developers service account.
|
||||||
func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Credentials, error) {
|
func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.PerRPCCredentials, error) {
|
||||||
jsonKey, err := ioutil.ReadFile(keyFile)
|
jsonKey, err := ioutil.ReadFile(keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
||||||
@ -168,7 +171,7 @@ func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Cre
|
|||||||
|
|
||||||
// NewApplicationDefault returns "Application Default Credentials". For more
|
// NewApplicationDefault returns "Application Default Credentials". For more
|
||||||
// detail, see https://developers.google.com/accounts/docs/application-default-credentials.
|
// detail, see https://developers.google.com/accounts/docs/application-default-credentials.
|
||||||
func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.Credentials, error) {
|
func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.PerRPCCredentials, error) {
|
||||||
t, err := google.DefaultTokenSource(ctx, scope...)
|
t, err := google.DefaultTokenSource(ctx, scope...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -164,7 +164,7 @@ func main() {
|
|||||||
if *serverHostOverride != "" {
|
if *serverHostOverride != "" {
|
||||||
sn = *serverHostOverride
|
sn = *serverHostOverride
|
||||||
}
|
}
|
||||||
var creds credentials.TransportAuthenticator
|
var creds credentials.TransportCredentials
|
||||||
if *caFile != "" {
|
if *caFile != "" {
|
||||||
var err error
|
var err error
|
||||||
creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
|
creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
|
||||||
|
@ -85,7 +85,7 @@ func main() {
|
|||||||
if *tlsServerName != "" {
|
if *tlsServerName != "" {
|
||||||
sn = *tlsServerName
|
sn = *tlsServerName
|
||||||
}
|
}
|
||||||
var creds credentials.TransportAuthenticator
|
var creds credentials.TransportCredentials
|
||||||
if *testCA {
|
if *testCA {
|
||||||
var err error
|
var err error
|
||||||
creds, err = credentials.NewClientTLSFromFile(testCAFile, sn)
|
creds, err = credentials.NewClientTLSFromFile(testCAFile, sn)
|
||||||
|
@ -96,7 +96,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
creds credentials.Credentials
|
creds credentials.TransportCredentials
|
||||||
codec Codec
|
codec Codec
|
||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
@ -139,7 +139,7 @@ func MaxConcurrentStreams(n uint32) ServerOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creds returns a ServerOption that sets credentials for server connections.
|
// Creds returns a ServerOption that sets credentials for server connections.
|
||||||
func Creds(c credentials.Credentials) ServerOption {
|
func Creds(c credentials.TransportCredentials) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.creds = c
|
o.creds = c
|
||||||
}
|
}
|
||||||
@ -250,11 +250,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
|
if s.opts.creds == nil {
|
||||||
if !ok {
|
|
||||||
return rawConn, nil, nil
|
return rawConn, nil, nil
|
||||||
}
|
}
|
||||||
return creds.ServerHandshake(rawConn)
|
return s.opts.creds.ServerHandshake(rawConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve accepts incoming connections on the listener lis, creating a new
|
// Serve accepts incoming connections on the listener lis, creating a new
|
||||||
|
@ -88,7 +88,7 @@ type http2Client struct {
|
|||||||
// The scheme used: https if TLS is on, http otherwise.
|
// The scheme used: https if TLS is on, http otherwise.
|
||||||
scheme string
|
scheme string
|
||||||
|
|
||||||
authCreds []credentials.Credentials
|
creds []credentials.PerRPCCredentials
|
||||||
|
|
||||||
mu sync.Mutex // guard the following variables
|
mu sync.Mutex // guard the following variables
|
||||||
state transportState // the state of underlying connection
|
state transportState // the state of underlying connection
|
||||||
@ -117,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
}
|
}
|
||||||
var authInfo credentials.AuthInfo
|
var authInfo credentials.AuthInfo
|
||||||
for _, c := range opts.AuthOptions {
|
if opts.TransportCredentials != nil {
|
||||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
scheme = "https"
|
||||||
scheme = "https"
|
if timeout > 0 {
|
||||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
timeout -= time.Since(startT)
|
||||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
|
||||||
// place the ClientTransport construction into a separate function to make
|
|
||||||
// things clear.
|
|
||||||
if timeout > 0 {
|
|
||||||
timeout -= time.Since(startT)
|
|
||||||
}
|
|
||||||
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
|
||||||
}
|
}
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
@ -163,7 +156,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
scheme: scheme,
|
scheme: scheme,
|
||||||
state: reachable,
|
state: reachable,
|
||||||
activeStreams: make(map[uint32]*Stream),
|
activeStreams: make(map[uint32]*Stream),
|
||||||
authCreds: opts.AuthOptions,
|
creds: opts.PerRPCCredentials,
|
||||||
maxStreams: math.MaxInt32,
|
maxStreams: math.MaxInt32,
|
||||||
streamSendQuota: defaultWindowSize,
|
streamSendQuota: defaultWindowSize,
|
||||||
}
|
}
|
||||||
@ -248,7 +241,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
ctx = peer.NewContext(ctx, pr)
|
ctx = peer.NewContext(ctx, pr)
|
||||||
authData := make(map[string]string)
|
authData := make(map[string]string)
|
||||||
for _, c := range t.authCreds {
|
for _, c := range t.creds {
|
||||||
// Construct URI required to get auth request metadata.
|
// Construct URI required to get auth request metadata.
|
||||||
var port string
|
var port string
|
||||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||||
|
@ -336,8 +336,10 @@ type ConnectOptions struct {
|
|||||||
UserAgent string
|
UserAgent string
|
||||||
// Dialer specifies how to dial a network address.
|
// Dialer specifies how to dial a network address.
|
||||||
Dialer func(string, time.Duration) (net.Conn, error)
|
Dialer func(string, time.Duration) (net.Conn, error)
|
||||||
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
|
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
|
||||||
AuthOptions []credentials.Credentials
|
PerRPCCredentials []credentials.PerRPCCredentials
|
||||||
|
// TransportCredentials stores the Authenticator required to setup a client connection.
|
||||||
|
TransportCredentials credentials.TransportCredentials
|
||||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
// Timeout specifies the timeout for dialing a ClientTransport.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
@ -473,7 +475,7 @@ func (e ConnectionError) Error() string {
|
|||||||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define some common ConnectionErrors.
|
// ErrConnClosing indicates that the transport is closing.
|
||||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
||||||
|
|
||||||
// StreamError is an error that only affects one stream within a connection.
|
// StreamError is an error that only affects one stream within a connection.
|
||||||
|
Reference in New Issue
Block a user