Make TransportAuthenticator not embed Credentials

This commit is contained in:
Menghan Li
2016-06-06 16:24:46 -07:00
parent b60d3e9ed8
commit 6404c49192
6 changed files with 40 additions and 39 deletions

View File

@ -170,9 +170,9 @@ func WithInsecure() DialOption {
// WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
func WithTransportCredentials(auth credentials.TransportAuthenticator) DialOption {
return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
o.copts.Authenticators = append(o.copts.Authenticators, auth)
}
}
@ -180,7 +180,7 @@ func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOpti
// credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
o.copts.Credentials = append(o.copts.Credentials, creds)
}
}
@ -369,17 +369,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
if !ac.dopts.insecure {
var ok bool
for _, cd := range ac.dopts.copts.AuthOptions {
if _, ok = cd.(credentials.TransportAuthenticator); ok {
break
}
}
if !ok {
if len(ac.dopts.copts.Authenticators) == 0 {
return errNoTransportSecurity
}
} else {
for _, cd := range ac.dopts.copts.AuthOptions {
if len(ac.dopts.copts.Authenticators) > 0 {
return errCredentialsMisuse
}
for _, cd := range ac.dopts.copts.Credentials {
if cd.RequireTransportSecurity() {
return errCredentialsMisuse
}

View File

@ -38,6 +38,7 @@ import (
"time"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth"
)
const tlsDir = "testdata/"
@ -67,14 +68,18 @@ func TestTLSDialTimeout(t *testing.T) {
}
func TestCredentialsMisuse(t *testing.T) {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
auth, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
t.Fatalf("Failed to create authenticator %v", err)
}
// Two conflicting credential configurations
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(auth), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
}
creds, err := oauth.NewJWTAccessFromKey(nil)
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
// security info on insecure connection
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)

View File

@ -100,7 +100,6 @@ type TransportAuthenticator interface {
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo
Credentials
}
// TLSInfo contains the auth information for a TLS authenticated connection.
@ -109,6 +108,7 @@ type TLSInfo struct {
State tls.ConnectionState
}
// AuthType returns the type of TLSInfo as a string.
func (t TLSInfo) AuthType() string {
return "tls"
}

View File

@ -95,7 +95,7 @@ type Server struct {
}
type options struct {
creds credentials.Credentials
auth credentials.TransportAuthenticator
codec Codec
cp Compressor
dc Decompressor
@ -138,9 +138,9 @@ func MaxConcurrentStreams(n uint32) ServerOption {
}
// Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.Credentials) ServerOption {
func Creds(c credentials.TransportAuthenticator) ServerOption {
return func(o *options) {
o.creds = c
o.auth = c
}
}
@ -249,11 +249,10 @@ var (
)
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
if !ok {
if s.opts.auth == nil {
return rawConn, nil, nil
}
return creds.ServerHandshake(rawConn)
return s.opts.auth.ServerHandshake(rawConn)
}
// Serve accepts incoming connections on the listener lis, creating a new

View File

@ -88,7 +88,7 @@ type http2Client struct {
// The scheme used: https if TLS is on, http otherwise.
scheme string
authCreds []credentials.Credentials
creds []credentials.Credentials
mu sync.Mutex // guard the following variables
state transportState // the state of underlying connection
@ -117,19 +117,17 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
return nil, ConnectionErrorf("transport: %v", connErr)
}
var authInfo credentials.AuthInfo
for _, c := range opts.AuthOptions {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https"
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
// multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make
// things clear.
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break
for _, auth := range opts.Authenticators {
scheme = "https"
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
// multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make
// things clear.
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = auth.ClientHandshake(addr, conn, timeout)
break
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
@ -163,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
scheme: scheme,
state: reachable,
activeStreams: make(map[uint32]*Stream),
authCreds: opts.AuthOptions,
creds: opts.Credentials,
maxStreams: math.MaxInt32,
streamSendQuota: defaultWindowSize,
}
@ -248,7 +246,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string)
for _, c := range t.authCreds {
for _, c := range t.creds {
// Construct URI required to get auth request metadata.
var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 {

View File

@ -336,8 +336,10 @@ type ConnectOptions struct {
UserAgent string
// Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error)
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
AuthOptions []credentials.Credentials
// Credentials stores the credentials required to issue RPCs.
Credentials []credentials.Credentials
// Authenticators stores the Authenticators required to setup a client connection.
Authenticators []credentials.TransportAuthenticator
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration
}
@ -473,7 +475,7 @@ func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc)
}
// Define some common ConnectionErrors.
// ErrConnClosing indicates that the transport is closing.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
// StreamError is an error that only affects one stream within a connection.