From 779083c6337c1cc7b5dec96332630c4067ab6ec2 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 15 Jul 2016 17:22:46 -0700 Subject: [PATCH] Change TestDialWithBlockErrorOnBadCertificates error check --- test/end2end_test.go | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index ffa4fa35..a2aa4143 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -439,12 +439,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) { if err != nil { te.t.Fatalf("Failed to listen: %v", err) } - if te.e.security == "tls" { + switch te.e.security { + case "tls": creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) + case "clientAlwaysFailCred": + sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) } s := grpc.NewServer(sopts...) te.srv = s @@ -2267,8 +2270,22 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { }) } +const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" + +type clientAlwaysFailCred struct{} + +func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) { + return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) +} +func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"}) + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) te.startServer() defer te.tearDown() @@ -2276,14 +2293,10 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { err error opts []grpc.DialOption ) - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com") - if err != nil { - te.t.Fatalf("Failed to load credentials: %v", err) - } - opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock()) + opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) te.cc, err = grpc.Dial(te.srvAddr, opts...) - if err == nil { - te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err) + if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) } }