Change TestDialWithBlockErrorOnBadCertificates error check

This commit is contained in:
Menghan Li
2016-07-15 17:22:46 -07:00
parent 1d0bea7943
commit 779083c633

View File

@ -439,12 +439,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
if err != nil { if err != nil {
te.t.Fatalf("Failed to listen: %v", err) te.t.Fatalf("Failed to listen: %v", err)
} }
if te.e.security == "tls" { switch te.e.security {
case "tls":
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil { if err != nil {
te.t.Fatalf("Failed to generate credentials %v", err) te.t.Fatalf("Failed to generate credentials %v", err)
} }
sopts = append(sopts, grpc.Creds(creds)) sopts = append(sopts, grpc.Creds(creds))
case "clientAlwaysFailCred":
sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{}))
} }
s := grpc.NewServer(sopts...) s := grpc.NewServer(sopts...)
te.srv = s te.srv = s
@ -2267,8 +2270,22 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
}) })
} }
const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
type clientAlwaysFailCred struct{}
func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) {
return nil, nil, errors.New(clientAlwaysFailCredErrorMsg)
}
func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"}) te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"})
te.startServer() te.startServer()
defer te.tearDown() defer te.tearDown()
@ -2276,14 +2293,10 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
err error err error
opts []grpc.DialOption opts []grpc.DialOption
) )
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com") opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock())
if err != nil {
te.t.Fatalf("Failed to load credentials: %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock())
te.cc, err = grpc.Dial(te.srvAddr, opts...) te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err == nil { if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err) te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
} }
} }