Make TransportAuthenticator not embed Credentials
This commit is contained in:
@ -170,9 +170,9 @@ func WithInsecure() DialOption {
|
||||
|
||||
// WithTransportCredentials returns a DialOption which configures a
|
||||
// connection level security credentials (e.g., TLS/SSL).
|
||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
||||
func WithTransportCredentials(auth credentials.TransportAuthenticator) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
o.copts.Authenticators = append(o.copts.Authenticators, auth)
|
||||
}
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOpti
|
||||
// credentials which will place auth state on each outbound RPC.
|
||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
o.copts.Credentials = append(o.copts.Credentials, creds)
|
||||
}
|
||||
}
|
||||
|
||||
@ -369,17 +369,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||
}
|
||||
if !ac.dopts.insecure {
|
||||
var ok bool
|
||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
||||
if _, ok = cd.(credentials.TransportAuthenticator); ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
if len(ac.dopts.copts.Authenticators) == 0 {
|
||||
return errNoTransportSecurity
|
||||
}
|
||||
} else {
|
||||
for _, cd := range ac.dopts.copts.AuthOptions {
|
||||
if len(ac.dopts.copts.Authenticators) > 0 {
|
||||
return errCredentialsMisuse
|
||||
}
|
||||
for _, cd := range ac.dopts.copts.Credentials {
|
||||
if cd.RequireTransportSecurity() {
|
||||
return errCredentialsMisuse
|
||||
}
|
||||
|
@ -38,6 +38,7 @@ import (
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/oauth"
|
||||
)
|
||||
|
||||
const tlsDir = "testdata/"
|
||||
@ -67,14 +68,18 @@ func TestTLSDialTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCredentialsMisuse(t *testing.T) {
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||
auth, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
t.Fatalf("Failed to create authenticator %v", err)
|
||||
}
|
||||
// Two conflicting credential configurations
|
||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(auth), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
||||
}
|
||||
creds, err := oauth.NewJWTAccessFromKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
// security info on insecure connection
|
||||
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
|
||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
|
||||
|
@ -100,7 +100,6 @@ type TransportAuthenticator interface {
|
||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
||||
Info() ProtocolInfo
|
||||
Credentials
|
||||
}
|
||||
|
||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||
@ -109,6 +108,7 @@ type TLSInfo struct {
|
||||
State tls.ConnectionState
|
||||
}
|
||||
|
||||
// AuthType returns the type of TLSInfo as a string.
|
||||
func (t TLSInfo) AuthType() string {
|
||||
return "tls"
|
||||
}
|
||||
|
11
server.go
11
server.go
@ -95,7 +95,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
creds credentials.Credentials
|
||||
auth credentials.TransportAuthenticator
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
@ -138,9 +138,9 @@ func MaxConcurrentStreams(n uint32) ServerOption {
|
||||
}
|
||||
|
||||
// Creds returns a ServerOption that sets credentials for server connections.
|
||||
func Creds(c credentials.Credentials) ServerOption {
|
||||
func Creds(c credentials.TransportAuthenticator) ServerOption {
|
||||
return func(o *options) {
|
||||
o.creds = c
|
||||
o.auth = c
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,11 +249,10 @@ var (
|
||||
)
|
||||
|
||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
|
||||
if !ok {
|
||||
if s.opts.auth == nil {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
return creds.ServerHandshake(rawConn)
|
||||
return s.opts.auth.ServerHandshake(rawConn)
|
||||
}
|
||||
|
||||
// Serve accepts incoming connections on the listener lis, creating a new
|
||||
|
@ -88,7 +88,7 @@ type http2Client struct {
|
||||
// The scheme used: https if TLS is on, http otherwise.
|
||||
scheme string
|
||||
|
||||
authCreds []credentials.Credentials
|
||||
creds []credentials.Credentials
|
||||
|
||||
mu sync.Mutex // guard the following variables
|
||||
state transportState // the state of underlying connection
|
||||
@ -117,19 +117,17 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
}
|
||||
var authInfo credentials.AuthInfo
|
||||
for _, c := range opts.AuthOptions {
|
||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||
scheme = "https"
|
||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||
// place the ClientTransport construction into a separate function to make
|
||||
// things clear.
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||
break
|
||||
for _, auth := range opts.Authenticators {
|
||||
scheme = "https"
|
||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||
// place the ClientTransport construction into a separate function to make
|
||||
// things clear.
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, authInfo, connErr = auth.ClientHandshake(addr, conn, timeout)
|
||||
break
|
||||
}
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
@ -163,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
scheme: scheme,
|
||||
state: reachable,
|
||||
activeStreams: make(map[uint32]*Stream),
|
||||
authCreds: opts.AuthOptions,
|
||||
creds: opts.Credentials,
|
||||
maxStreams: math.MaxInt32,
|
||||
streamSendQuota: defaultWindowSize,
|
||||
}
|
||||
@ -248,7 +246,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
}
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
authData := make(map[string]string)
|
||||
for _, c := range t.authCreds {
|
||||
for _, c := range t.creds {
|
||||
// Construct URI required to get auth request metadata.
|
||||
var port string
|
||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||
|
@ -336,8 +336,10 @@ type ConnectOptions struct {
|
||||
UserAgent string
|
||||
// Dialer specifies how to dial a network address.
|
||||
Dialer func(string, time.Duration) (net.Conn, error)
|
||||
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
|
||||
AuthOptions []credentials.Credentials
|
||||
// Credentials stores the credentials required to issue RPCs.
|
||||
Credentials []credentials.Credentials
|
||||
// Authenticators stores the Authenticators required to setup a client connection.
|
||||
Authenticators []credentials.TransportAuthenticator
|
||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
||||
Timeout time.Duration
|
||||
}
|
||||
@ -473,7 +475,7 @@ func (e ConnectionError) Error() string {
|
||||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||
}
|
||||
|
||||
// Define some common ConnectionErrors.
|
||||
// ErrConnClosing indicates that the transport is closing.
|
||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
||||
|
||||
// StreamError is an error that only affects one stream within a connection.
|
||||
|
Reference in New Issue
Block a user