rls: use a regex for the expected error string (#5827)

This commit is contained in:
Easwar Swaminathan
2022-12-01 11:59:34 -08:00
committed by GitHub
parent 617d6c8a6c
commit 736197138d

View File

@ -25,7 +25,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"strings" "regexp"
"testing" "testing"
"time" "time"
@ -350,7 +350,7 @@ func (s) TestControlChannelCredsSuccess(t *testing.T) {
} }
} }
func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErr string) { func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErrRegex *regexp.Regexp) {
// StartFakeRouteLookupServer a fake server. // StartFakeRouteLookupServer a fake server.
// //
// Start an RLS server and set the throttler to never throttle requests. The // Start an RLS server and set the throttler to never throttle requests. The
@ -369,8 +369,8 @@ func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bop
// Perform the lookup and expect the callback to be invoked with an error. // Perform the lookup and expect the callback to be invoked with an error.
errCh := make(chan error) errCh := make(chan error)
ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !strings.Contains(st.String(), wantErr) { if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !wantErrRegex.MatchString(st.String()) {
errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErr) errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErrRegex.String())
return return
} }
errCh <- nil errCh <- nil
@ -393,11 +393,11 @@ func (s) TestControlChannelCredsFailure(t *testing.T) {
clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
tests := []struct { tests := []struct {
name string name string
sopts []grpc.ServerOption sopts []grpc.ServerOption
bopts balancer.BuildOptions bopts balancer.BuildOptions
wantCode codes.Code wantCode codes.Code
wantErr string wantErrRegex *regexp.Regexp
}{ }{
{ {
name: "transport creds authority mismatch", name: "transport creds authority mismatch",
@ -406,8 +406,8 @@ func (s) TestControlChannelCredsFailure(t *testing.T) {
DialCreds: clientCreds, DialCreds: clientCreds,
Authority: "authority-mismatch", Authority: "authority-mismatch",
}, },
wantCode: codes.Unavailable, wantCode: codes.Unavailable,
wantErr: "transport: authentication handshake failed: x509: certificate is valid for *.test.example.com, not authority-mismatch", wantErrRegex: regexp.MustCompile(`transport: authentication handshake failed: .* \*.test.example.com.*authority-mismatch`),
}, },
{ {
name: "transport creds handshake failure", name: "transport creds handshake failure",
@ -416,8 +416,8 @@ func (s) TestControlChannelCredsFailure(t *testing.T) {
DialCreds: clientCreds, DialCreds: clientCreds,
Authority: "x.test.example.com", Authority: "x.test.example.com",
}, },
wantCode: codes.Unavailable, wantCode: codes.Unavailable,
wantErr: "transport: authentication handshake failed: tls: first record does not look like a TLS handshake", wantErrRegex: regexp.MustCompile("transport: authentication handshake failed: .*"),
}, },
{ {
name: "call creds mismatch", name: "call creds mismatch",
@ -432,13 +432,13 @@ func (s) TestControlChannelCredsFailure(t *testing.T) {
}, },
Authority: "x.test.example.com", Authority: "x.test.example.com",
}, },
wantCode: codes.PermissionDenied, wantCode: codes.PermissionDenied,
wantErr: "didn't find call creds", wantErrRegex: regexp.MustCompile("didn't find call creds"),
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErr) testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErrRegex)
}) })
} }
} }