Change TestDialWithBlockErrorOnBadCertificates error check

This commit is contained in:
Menghan Li
2016-07-15 17:22:46 -07:00
parent 1d0bea7943
commit 779083c633

View File

@ -439,12 +439,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
if err != nil {
te.t.Fatalf("Failed to listen: %v", err)
}
if te.e.security == "tls" {
switch te.e.security {
case "tls":
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
te.t.Fatalf("Failed to generate credentials %v", err)
}
sopts = append(sopts, grpc.Creds(creds))
case "clientAlwaysFailCred":
sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{}))
}
s := grpc.NewServer(sopts...)
te.srv = s
@ -2267,8 +2270,22 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
})
}
const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
type clientAlwaysFailCred struct{}
func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) {
return nil, nil, errors.New(clientAlwaysFailCredErrorMsg)
}
func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"})
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"})
te.startServer()
defer te.tearDown()
@ -2276,14 +2293,10 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
err error
opts []grpc.DialOption
)
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com")
if err != nil {
te.t.Fatalf("Failed to load credentials: %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock())
opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock())
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err == nil {
te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err)
if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
}
}