Change TestDialWithBlockErrorOnBadCertificates error check
This commit is contained in:
@ -439,12 +439,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
te.t.Fatalf("Failed to listen: %v", err)
|
te.t.Fatalf("Failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
if te.e.security == "tls" {
|
switch te.e.security {
|
||||||
|
case "tls":
|
||||||
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
|
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
te.t.Fatalf("Failed to generate credentials %v", err)
|
te.t.Fatalf("Failed to generate credentials %v", err)
|
||||||
}
|
}
|
||||||
sopts = append(sopts, grpc.Creds(creds))
|
sopts = append(sopts, grpc.Creds(creds))
|
||||||
|
case "clientAlwaysFailCred":
|
||||||
|
sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{}))
|
||||||
}
|
}
|
||||||
s := grpc.NewServer(sopts...)
|
s := grpc.NewServer(sopts...)
|
||||||
te.srv = s
|
te.srv = s
|
||||||
@ -2267,8 +2270,22 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
|
||||||
|
|
||||||
|
type clientAlwaysFailCred struct{}
|
||||||
|
|
||||||
|
func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) {
|
||||||
|
return nil, nil, errors.New(clientAlwaysFailCredErrorMsg)
|
||||||
|
}
|
||||||
|
func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
return rawConn, nil, nil
|
||||||
|
}
|
||||||
|
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
|
||||||
|
return credentials.ProtocolInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
||||||
te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"})
|
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"})
|
||||||
te.startServer()
|
te.startServer()
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
|
|
||||||
@ -2276,14 +2293,10 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
opts []grpc.DialOption
|
opts []grpc.DialOption
|
||||||
)
|
)
|
||||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com")
|
opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock())
|
||||||
if err != nil {
|
|
||||||
te.t.Fatalf("Failed to load credentials: %v", err)
|
|
||||||
}
|
|
||||||
opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock())
|
|
||||||
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
||||||
if err == nil {
|
if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
|
||||||
te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err)
|
te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user