Make TransportAuthenticator not embed Credentials

This commit is contained in:
Menghan Li
2016-06-06 16:24:46 -07:00
parent b60d3e9ed8
commit 6404c49192
6 changed files with 40 additions and 39 deletions

View File

@ -170,9 +170,9 @@ func WithInsecure() DialOption {
// WithTransportCredentials returns a DialOption which configures a // WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL). // connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { func WithTransportCredentials(auth credentials.TransportAuthenticator) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds) o.copts.Authenticators = append(o.copts.Authenticators, auth)
} }
} }
@ -180,7 +180,7 @@ func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOpti
// credentials which will place auth state on each outbound RPC. // credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption { func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds) o.copts.Credentials = append(o.copts.Credentials, creds)
} }
} }
@ -369,17 +369,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
if !ac.dopts.insecure { if !ac.dopts.insecure {
var ok bool if len(ac.dopts.copts.Authenticators) == 0 {
for _, cd := range ac.dopts.copts.AuthOptions {
if _, ok = cd.(credentials.TransportAuthenticator); ok {
break
}
}
if !ok {
return errNoTransportSecurity return errNoTransportSecurity
} }
} else { } else {
for _, cd := range ac.dopts.copts.AuthOptions { if len(ac.dopts.copts.Authenticators) > 0 {
return errCredentialsMisuse
}
for _, cd := range ac.dopts.copts.Credentials {
if cd.RequireTransportSecurity() { if cd.RequireTransportSecurity() {
return errCredentialsMisuse return errCredentialsMisuse
} }

View File

@ -38,6 +38,7 @@ import (
"time" "time"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth"
) )
const tlsDir = "testdata/" const tlsDir = "testdata/"
@ -67,14 +68,18 @@ func TestTLSDialTimeout(t *testing.T) {
} }
func TestCredentialsMisuse(t *testing.T) { func TestCredentialsMisuse(t *testing.T) {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") auth, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create authenticator %v", err)
} }
// Two conflicting credential configurations // Two conflicting credential configurations
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(auth), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
} }
creds, err := oauth.NewJWTAccessFromKey(nil)
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
// security info on insecure connection // security info on insecure connection
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)

View File

@ -100,7 +100,6 @@ type TransportAuthenticator interface {
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportAuthenticator. // Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo Info() ProtocolInfo
Credentials
} }
// TLSInfo contains the auth information for a TLS authenticated connection. // TLSInfo contains the auth information for a TLS authenticated connection.
@ -109,6 +108,7 @@ type TLSInfo struct {
State tls.ConnectionState State tls.ConnectionState
} }
// AuthType returns the type of TLSInfo as a string.
func (t TLSInfo) AuthType() string { func (t TLSInfo) AuthType() string {
return "tls" return "tls"
} }

View File

@ -95,7 +95,7 @@ type Server struct {
} }
type options struct { type options struct {
creds credentials.Credentials auth credentials.TransportAuthenticator
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
@ -138,9 +138,9 @@ func MaxConcurrentStreams(n uint32) ServerOption {
} }
// Creds returns a ServerOption that sets credentials for server connections. // Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.Credentials) ServerOption { func Creds(c credentials.TransportAuthenticator) ServerOption {
return func(o *options) { return func(o *options) {
o.creds = c o.auth = c
} }
} }
@ -249,11 +249,10 @@ var (
) )
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
creds, ok := s.opts.creds.(credentials.TransportAuthenticator) if s.opts.auth == nil {
if !ok {
return rawConn, nil, nil return rawConn, nil, nil
} }
return creds.ServerHandshake(rawConn) return s.opts.auth.ServerHandshake(rawConn)
} }
// Serve accepts incoming connections on the listener lis, creating a new // Serve accepts incoming connections on the listener lis, creating a new

View File

@ -88,7 +88,7 @@ type http2Client struct {
// The scheme used: https if TLS is on, http otherwise. // The scheme used: https if TLS is on, http otherwise.
scheme string scheme string
authCreds []credentials.Credentials creds []credentials.Credentials
mu sync.Mutex // guard the following variables mu sync.Mutex // guard the following variables
state transportState // the state of underlying connection state transportState // the state of underlying connection
@ -117,8 +117,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
var authInfo credentials.AuthInfo var authInfo credentials.AuthInfo
for _, c := range opts.AuthOptions { for _, auth := range opts.Authenticators {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https" scheme = "https"
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are // TODO(zhaoq): Now the first TransportAuthenticator is used if there are
// multiple ones provided. Revisit this if it is not appropriate. Probably // multiple ones provided. Revisit this if it is not appropriate. Probably
@ -127,10 +126,9 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 { if timeout > 0 {
timeout -= time.Since(startT) timeout -= time.Since(startT)
} }
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) conn, authInfo, connErr = auth.ClientHandshake(addr, conn, timeout)
break break
} }
}
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
@ -163,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
scheme: scheme, scheme: scheme,
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
authCreds: opts.AuthOptions, creds: opts.Credentials,
maxStreams: math.MaxInt32, maxStreams: math.MaxInt32,
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
} }
@ -248,7 +246,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
ctx = peer.NewContext(ctx, pr) ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string) authData := make(map[string]string)
for _, c := range t.authCreds { for _, c := range t.creds {
// Construct URI required to get auth request metadata. // Construct URI required to get auth request metadata.
var port string var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 { if pos := strings.LastIndex(t.target, ":"); pos != -1 {

View File

@ -336,8 +336,10 @@ type ConnectOptions struct {
UserAgent string UserAgent string
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(string, time.Duration) (net.Conn, error)
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. // Credentials stores the credentials required to issue RPCs.
AuthOptions []credentials.Credentials Credentials []credentials.Credentials
// Authenticators stores the Authenticators required to setup a client connection.
Authenticators []credentials.TransportAuthenticator
// Timeout specifies the timeout for dialing a ClientTransport. // Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration Timeout time.Duration
} }
@ -473,7 +475,7 @@ func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc) return fmt.Sprintf("connection error: desc = %q", e.Desc)
} }
// Define some common ConnectionErrors. // ErrConnClosing indicates that the transport is closing.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"} var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.