Make TransportAuthenticator not embed Credentials
This commit is contained in:
@ -170,9 +170,9 @@ func WithInsecure() DialOption {
|
|||||||
|
|
||||||
// WithTransportCredentials returns a DialOption which configures a
|
// WithTransportCredentials returns a DialOption which configures a
|
||||||
// connection level security credentials (e.g., TLS/SSL).
|
// connection level security credentials (e.g., TLS/SSL).
|
||||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
func WithTransportCredentials(auth credentials.TransportAuthenticator) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
o.copts.Authenticators = append(o.copts.Authenticators, auth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOpti
|
|||||||
// credentials which will place auth state on each outbound RPC.
|
// credentials which will place auth state on each outbound RPC.
|
||||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
o.copts.Credentials = append(o.copts.Credentials, creds)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,17 +369,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
|||||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||||
}
|
}
|
||||||
if !ac.dopts.insecure {
|
if !ac.dopts.insecure {
|
||||||
var ok bool
|
if len(ac.dopts.copts.Authenticators) == 0 {
|
||||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
|
||||||
if _, ok = cd.(credentials.TransportAuthenticator); ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return errNoTransportSecurity
|
return errNoTransportSecurity
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
if len(ac.dopts.copts.Authenticators) > 0 {
|
||||||
|
return errCredentialsMisuse
|
||||||
|
}
|
||||||
|
for _, cd := range ac.dopts.copts.Credentials {
|
||||||
if cd.RequireTransportSecurity() {
|
if cd.RequireTransportSecurity() {
|
||||||
return errCredentialsMisuse
|
return errCredentialsMisuse
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tlsDir = "testdata/"
|
const tlsDir = "testdata/"
|
||||||
@ -67,14 +68,18 @@ func TestTLSDialTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCredentialsMisuse(t *testing.T) {
|
func TestCredentialsMisuse(t *testing.T) {
|
||||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
auth, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create credentials %v", err)
|
t.Fatalf("Failed to create authenticator %v", err)
|
||||||
}
|
}
|
||||||
// Two conflicting credential configurations
|
// Two conflicting credential configurations
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(auth), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
||||||
}
|
}
|
||||||
|
creds, err := oauth.NewJWTAccessFromKey(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create credentials %v", err)
|
||||||
|
}
|
||||||
// security info on insecure connection
|
// security info on insecure connection
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
||||||
|
@ -100,7 +100,6 @@ type TransportAuthenticator interface {
|
|||||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
Credentials
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||||
@ -109,6 +108,7 @@ type TLSInfo struct {
|
|||||||
State tls.ConnectionState
|
State tls.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthType returns the type of TLSInfo as a string.
|
||||||
func (t TLSInfo) AuthType() string {
|
func (t TLSInfo) AuthType() string {
|
||||||
return "tls"
|
return "tls"
|
||||||
}
|
}
|
||||||
|
11
server.go
11
server.go
@ -95,7 +95,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
creds credentials.Credentials
|
auth credentials.TransportAuthenticator
|
||||||
codec Codec
|
codec Codec
|
||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
@ -138,9 +138,9 @@ func MaxConcurrentStreams(n uint32) ServerOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creds returns a ServerOption that sets credentials for server connections.
|
// Creds returns a ServerOption that sets credentials for server connections.
|
||||||
func Creds(c credentials.Credentials) ServerOption {
|
func Creds(c credentials.TransportAuthenticator) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.creds = c
|
o.auth = c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,11 +249,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
|
if s.opts.auth == nil {
|
||||||
if !ok {
|
|
||||||
return rawConn, nil, nil
|
return rawConn, nil, nil
|
||||||
}
|
}
|
||||||
return creds.ServerHandshake(rawConn)
|
return s.opts.auth.ServerHandshake(rawConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve accepts incoming connections on the listener lis, creating a new
|
// Serve accepts incoming connections on the listener lis, creating a new
|
||||||
|
@ -88,7 +88,7 @@ type http2Client struct {
|
|||||||
// The scheme used: https if TLS is on, http otherwise.
|
// The scheme used: https if TLS is on, http otherwise.
|
||||||
scheme string
|
scheme string
|
||||||
|
|
||||||
authCreds []credentials.Credentials
|
creds []credentials.Credentials
|
||||||
|
|
||||||
mu sync.Mutex // guard the following variables
|
mu sync.Mutex // guard the following variables
|
||||||
state transportState // the state of underlying connection
|
state transportState // the state of underlying connection
|
||||||
@ -117,19 +117,17 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
}
|
}
|
||||||
var authInfo credentials.AuthInfo
|
var authInfo credentials.AuthInfo
|
||||||
for _, c := range opts.AuthOptions {
|
for _, auth := range opts.Authenticators {
|
||||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
scheme = "https"
|
||||||
scheme = "https"
|
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
||||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
// place the ClientTransport construction into a separate function to make
|
||||||
// place the ClientTransport construction into a separate function to make
|
// things clear.
|
||||||
// things clear.
|
if timeout > 0 {
|
||||||
if timeout > 0 {
|
timeout -= time.Since(startT)
|
||||||
timeout -= time.Since(startT)
|
|
||||||
}
|
|
||||||
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
conn, authInfo, connErr = auth.ClientHandshake(addr, conn, timeout)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
@ -163,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
scheme: scheme,
|
scheme: scheme,
|
||||||
state: reachable,
|
state: reachable,
|
||||||
activeStreams: make(map[uint32]*Stream),
|
activeStreams: make(map[uint32]*Stream),
|
||||||
authCreds: opts.AuthOptions,
|
creds: opts.Credentials,
|
||||||
maxStreams: math.MaxInt32,
|
maxStreams: math.MaxInt32,
|
||||||
streamSendQuota: defaultWindowSize,
|
streamSendQuota: defaultWindowSize,
|
||||||
}
|
}
|
||||||
@ -248,7 +246,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
ctx = peer.NewContext(ctx, pr)
|
ctx = peer.NewContext(ctx, pr)
|
||||||
authData := make(map[string]string)
|
authData := make(map[string]string)
|
||||||
for _, c := range t.authCreds {
|
for _, c := range t.creds {
|
||||||
// Construct URI required to get auth request metadata.
|
// Construct URI required to get auth request metadata.
|
||||||
var port string
|
var port string
|
||||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||||
|
@ -336,8 +336,10 @@ type ConnectOptions struct {
|
|||||||
UserAgent string
|
UserAgent string
|
||||||
// Dialer specifies how to dial a network address.
|
// Dialer specifies how to dial a network address.
|
||||||
Dialer func(string, time.Duration) (net.Conn, error)
|
Dialer func(string, time.Duration) (net.Conn, error)
|
||||||
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
|
// Credentials stores the credentials required to issue RPCs.
|
||||||
AuthOptions []credentials.Credentials
|
Credentials []credentials.Credentials
|
||||||
|
// Authenticators stores the Authenticators required to setup a client connection.
|
||||||
|
Authenticators []credentials.TransportAuthenticator
|
||||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
// Timeout specifies the timeout for dialing a ClientTransport.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
@ -473,7 +475,7 @@ func (e ConnectionError) Error() string {
|
|||||||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define some common ConnectionErrors.
|
// ErrConnClosing indicates that the transport is closing.
|
||||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
||||||
|
|
||||||
// StreamError is an error that only affects one stream within a connection.
|
// StreamError is an error that only affects one stream within a connection.
|
||||||
|
Reference in New Issue
Block a user