diff --git a/clientconn.go b/clientconn.go index 52d47a57..fc13c38d 100644 --- a/clientconn.go +++ b/clientconn.go @@ -50,6 +50,10 @@ import ( var ( // ErrUnspecTarget indicates that the target address is unspecified. ErrUnspecTarget = errors.New("grpc: target is unspecified") + // ErrNoTransportSecurity indicates that there is no transport security + // being set for ClientConn. Users should either set one or explicityly + // call WithInsecure DialOption to disable security. + ErrNoTransportSecurity = errors.New("grpc: no transport security set") // ErrClientConnClosing indicates that the operation is illegal because // the session is closing. ErrClientConnClosing = errors.New("grpc: the client connection is closing") @@ -63,9 +67,10 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - codec Codec - block bool - copts transport.ConnectOptions + codec Codec + block bool + insecure bool + copts transport.ConnectOptions } // DialOption configures how we set up the connection. @@ -87,6 +92,12 @@ func WithBlock() DialOption { } } +func WithInsecure() DialOption { + return func(o *dialOptions) { + o.insecure = true + } +} + // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { @@ -136,6 +147,18 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { for _, opt := range opts { opt(&cc.dopts) } + if !cc.dopts.insecure { + var ok bool + for _, c := range cc.dopts.copts.AuthOptions { + if _, ok := c.(credentials.TransportAuthenticator); !ok { + continue + } + ok = true + } + if !ok { + return nil, ErrNoTransportSecurity + } + } colonPos := strings.LastIndex(target, ":") if colonPos == -1 { colonPos = len(target) diff --git a/test/end2end_test.go b/test/end2end_test.go index 35b40ec0..a14770c3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -237,7 +237,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" func TestDialTimeout(t *testing.T) { - conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure()) if err == nil { conn.Close() } @@ -251,7 +251,7 @@ func TestTLSDialTimeout(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure()) if err == nil { conn.Close() } @@ -270,7 +270,7 @@ func TestReconnectTimeout(t *testing.T) { t.Fatalf("Failed to parse listener address: %v", err) } addr := "localhost:" + port - conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock()) + conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock(), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial to the server %q: %v", addr, err) } @@ -357,7 +357,7 @@ func setUp(hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc } cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) } else { - cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) + cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua)) } if err != nil { grpclog.Fatalf("Dial(%q) = %v", addr, err)