Insecure ClientConn made explicit

This commit is contained in:
iamqizhao
2015-08-27 17:21:52 -07:00
parent 6ab9a9c9d7
commit 996538ab4b
2 changed files with 30 additions and 7 deletions

View File

@ -50,6 +50,10 @@ import (
var ( var (
// ErrUnspecTarget indicates that the target address is unspecified. // ErrUnspecTarget indicates that the target address is unspecified.
ErrUnspecTarget = errors.New("grpc: target is unspecified") ErrUnspecTarget = errors.New("grpc: target is unspecified")
// ErrNoTransportSecurity indicates that there is no transport security
// being set for ClientConn. Users should either set one or explicityly
// call WithInsecure DialOption to disable security.
ErrNoTransportSecurity = errors.New("grpc: no transport security set")
// ErrClientConnClosing indicates that the operation is illegal because // ErrClientConnClosing indicates that the operation is illegal because
// the session is closing. // the session is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing") ErrClientConnClosing = errors.New("grpc: the client connection is closing")
@ -63,9 +67,10 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
codec Codec codec Codec
block bool block bool
copts transport.ConnectOptions insecure bool
copts transport.ConnectOptions
} }
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
@ -87,6 +92,12 @@ func WithBlock() DialOption {
} }
} }
func WithInsecure() DialOption {
return func(o *dialOptions) {
o.insecure = true
}
}
// WithTransportCredentials returns a DialOption which configures a // WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL). // connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
@ -136,6 +147,18 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
if !cc.dopts.insecure {
var ok bool
for _, c := range cc.dopts.copts.AuthOptions {
if _, ok := c.(credentials.TransportAuthenticator); !ok {
continue
}
ok = true
}
if !ok {
return nil, ErrNoTransportSecurity
}
}
colonPos := strings.LastIndex(target, ":") colonPos := strings.LastIndex(target, ":")
if colonPos == -1 { if colonPos == -1 {
colonPos = len(target) colonPos = len(target)

View File

@ -237,7 +237,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/" const tlsDir = "testdata/"
func TestDialTimeout(t *testing.T) { func TestDialTimeout(t *testing.T) {
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure())
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
@ -251,7 +251,7 @@ func TestTLSDialTimeout(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure())
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
@ -270,7 +270,7 @@ func TestReconnectTimeout(t *testing.T) {
t.Fatalf("Failed to parse listener address: %v", err) t.Fatalf("Failed to parse listener address: %v", err)
} }
addr := "localhost:" + port addr := "localhost:" + port
conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock()) conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock(), grpc.WithInsecure())
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the server %q: %v", addr, err) t.Fatalf("Failed to dial to the server %q: %v", addr, err)
} }
@ -357,7 +357,7 @@ func setUp(hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc
} }
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua))
} else { } else {
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua))
} }
if err != nil { if err != nil {
grpclog.Fatalf("Dial(%q) = %v", addr, err) grpclog.Fatalf("Dial(%q) = %v", addr, err)