test: Move creds related to tests to creds_test.go (#3542)
This commit is contained in:

committed by
GitHub

parent
f9ac13d469
commit
d15f1a4aa1
@ -18,15 +18,24 @@
|
||||
|
||||
package test
|
||||
|
||||
// TODO(https://github.com/grpc/grpc-go/issues/2330): move all creds related
|
||||
// tests to this file.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/resolver/manual"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/tap"
|
||||
testpb "google.golang.org/grpc/test/grpc_testing"
|
||||
"google.golang.org/grpc/testdata"
|
||||
)
|
||||
@ -125,3 +134,398 @@ func (s) TestCredsBundlePerRPCCredentials(t *testing.T) {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type clientTimeoutCreds struct {
|
||||
credentials.TransportCredentials
|
||||
timeoutReturned bool
|
||||
}
|
||||
|
||||
func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
if !c.timeoutReturned {
|
||||
c.timeoutReturned = true
|
||||
return nil, nil, context.DeadlineExceeded
|
||||
}
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
|
||||
func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
|
||||
func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
|
||||
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty", balancer: "v1"})
|
||||
te.userAgent = testAppUA
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn(grpc.WithTransportCredentials(&clientTimeoutCreds{}))
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
// This unary call should succeed, because ClientHandshake will succeed for the second time.
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
type methodTestCreds struct{}
|
||||
|
||||
func (m *methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
ri, _ := credentials.RequestInfoFromContext(ctx)
|
||||
return nil, status.Errorf(codes.Unknown, ri.Method)
|
||||
}
|
||||
|
||||
func (m *methodTestCreds) RequireTransportSecurity() bool { return false }
|
||||
|
||||
func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) {
|
||||
const wantMethod = "/grpc.testing.TestService/EmptyCall"
|
||||
te := newTest(t, env{name: "context-request-info", network: "tcp", balancer: "v1"})
|
||||
te.userAgent = testAppUA
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn(grpc.WithPerRPCCredentials(&methodTestCreds{}))
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod {
|
||||
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod)
|
||||
}
|
||||
}
|
||||
|
||||
const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
|
||||
|
||||
type clientAlwaysFailCred struct {
|
||||
credentials.TransportCredentials
|
||||
}
|
||||
|
||||
func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return nil, nil, errors.New(clientAlwaysFailCredErrorMsg)
|
||||
}
|
||||
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) {
|
||||
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"})
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
cc, err := grpc.DialContext(ctx, te.srvAddr, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(_) = %v, want %v", err, nil)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
for i := 0; i < 1000; i++ {
|
||||
// This loop runs for at most 1 second. The first several RPCs will fail
|
||||
// with Unavailable because the connection hasn't started. When the
|
||||
// first connection failed with creds error, the next RPC should also
|
||||
// fail with the expected error.
|
||||
if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
|
||||
}
|
||||
|
||||
func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) {
|
||||
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"})
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
|
||||
dctx, dcancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer dcancel()
|
||||
cc, err := grpc.DialContext(dctx, te.srvAddr, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(_) = %v, want %v", err, nil)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
|
||||
return
|
||||
}
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
|
||||
}
|
||||
|
||||
var (
|
||||
// test authdata
|
||||
authdata = map[string]string{
|
||||
"test-key": "test-value",
|
||||
"test-key2-bin": string([]byte{1, 2, 3}),
|
||||
}
|
||||
)
|
||||
|
||||
type testPerRPCCredentials struct{}
|
||||
|
||||
func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return authdata, nil
|
||||
}
|
||||
|
||||
func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
|
||||
}
|
||||
if vgot[0] != vwant {
|
||||
return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
// When credentials are provided via both dial options and call options,
|
||||
// we apply both sets.
|
||||
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
|
||||
}
|
||||
if len(vgot) != 2 {
|
||||
return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
|
||||
}
|
||||
if vgot[0] != vwant || vgot[1] != vwant {
|
||||
return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
const testAuthority = "test.auth.ori.ty"
|
||||
|
||||
type authorityCheckCreds struct {
|
||||
credentials.TransportCredentials
|
||||
got string
|
||||
}
|
||||
|
||||
func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
c.got = authority
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
|
||||
return c
|
||||
}
|
||||
|
||||
// This test makes sure that the authority client handshake gets is the endpoint
|
||||
// in dial target, not the resolved ip address.
|
||||
func (s) TestCredsHandshakeAuthority(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cred := &authorityCheckCreds{}
|
||||
s := grpc.NewServer()
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
r, rcleanup := manual.GenerateAndRegisterManualResolver()
|
||||
defer rcleanup()
|
||||
|
||||
cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
for {
|
||||
s := cc.GetState()
|
||||
if s == connectivity.Ready {
|
||||
break
|
||||
}
|
||||
if !cc.WaitForStateChange(ctx, s) {
|
||||
t.Fatalf("ClientConn is not ready after 100 ms")
|
||||
}
|
||||
}
|
||||
|
||||
if cred.got != testAuthority {
|
||||
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
|
||||
}
|
||||
}
|
||||
|
||||
// This test makes sure that the authority client handshake gets is the endpoint
|
||||
// of the ServerName of the address when it is set.
|
||||
func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) {
|
||||
const testServerName = "test.server.name"
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cred := &authorityCheckCreds{}
|
||||
s := grpc.NewServer()
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
r, rcleanup := manual.GenerateAndRegisterManualResolver()
|
||||
defer rcleanup()
|
||||
|
||||
cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
for {
|
||||
s := cc.GetState()
|
||||
if s == connectivity.Ready {
|
||||
break
|
||||
}
|
||||
if !cc.WaitForStateChange(ctx, s) {
|
||||
t.Fatalf("ClientConn is not ready after 100 ms")
|
||||
}
|
||||
}
|
||||
|
||||
if cred.got != testServerName {
|
||||
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
|
||||
}
|
||||
}
|
||||
|
||||
type serverDispatchCred struct {
|
||||
rawConnCh chan net.Conn
|
||||
}
|
||||
|
||||
func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
select {
|
||||
case c.rawConnCh <- rawConn:
|
||||
default:
|
||||
}
|
||||
return nil, nil, credentials.ErrConnDispatched
|
||||
}
|
||||
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *serverDispatchCred) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
func (c *serverDispatchCred) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *serverDispatchCred) getRawConn() net.Conn {
|
||||
return <-c.rawConnCh
|
||||
}
|
||||
|
||||
func (s) TestServerCredsDispatch(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cred := &serverDispatchCred{
|
||||
rawConnCh: make(chan net.Conn, 1),
|
||||
}
|
||||
s := grpc.NewServer(grpc.Creds(cred))
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
rawConn := cred.getRawConn()
|
||||
// Give grpc a chance to see the error and potentially close the connection.
|
||||
// And check that connection is not closed after that.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Check rawConn is not closed.
|
||||
if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil {
|
||||
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err)
|
||||
}
|
||||
}
|
||||
|
@ -619,15 +619,12 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network,
|
||||
if err != nil {
|
||||
te.t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
switch te.e.security {
|
||||
case "tls":
|
||||
if te.e.security == "tls" {
|
||||
creds, err := credentials.NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key"))
|
||||
if err != nil {
|
||||
te.t.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
sopts = append(sopts, grpc.Creds(creds))
|
||||
case "clientTimeoutCreds":
|
||||
sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{}))
|
||||
}
|
||||
sopts = append(sopts, te.customServerOptions...)
|
||||
s := grpc.NewServer(sopts...)
|
||||
@ -803,8 +800,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string)
|
||||
te.t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
opts = append(opts, grpc.WithTransportCredentials(creds))
|
||||
case "clientTimeoutCreds":
|
||||
opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{}))
|
||||
case "empty":
|
||||
// Don't add any transport creds option.
|
||||
default:
|
||||
@ -4720,250 +4715,6 @@ func testClientResourceExhaustedCancelFullDuplex(t *testing.T, e env) {
|
||||
}
|
||||
}
|
||||
|
||||
type clientTimeoutCreds struct {
|
||||
timeoutReturned bool
|
||||
}
|
||||
|
||||
func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
if !c.timeoutReturned {
|
||||
c.timeoutReturned = true
|
||||
return nil, nil, context.DeadlineExceeded
|
||||
}
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
func (c *clientTimeoutCreds) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
|
||||
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: "v1"})
|
||||
te.userAgent = testAppUA
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
// This unary call should succeed, because ClientHandshake will succeed for the second time.
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <nil>", err)
|
||||
}
|
||||
}
|
||||
|
||||
type serverDispatchCred struct {
|
||||
rawConnCh chan net.Conn
|
||||
}
|
||||
|
||||
func newServerDispatchCred() *serverDispatchCred {
|
||||
return &serverDispatchCred{
|
||||
rawConnCh: make(chan net.Conn, 1),
|
||||
}
|
||||
}
|
||||
func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
select {
|
||||
case c.rawConnCh <- rawConn:
|
||||
default:
|
||||
}
|
||||
return nil, nil, credentials.ErrConnDispatched
|
||||
}
|
||||
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *serverDispatchCred) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
func (c *serverDispatchCred) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *serverDispatchCred) getRawConn() net.Conn {
|
||||
return <-c.rawConnCh
|
||||
}
|
||||
|
||||
func (s) TestServerCredsDispatch(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
cred := newServerDispatchCred()
|
||||
s := grpc.NewServer(grpc.Creds(cred))
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
rawConn := cred.getRawConn()
|
||||
// Give grpc a chance to see the error and potentially close the connection.
|
||||
// And check that connection is not closed after that.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Check rawConn is not closed.
|
||||
if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil {
|
||||
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type authorityCheckCreds struct {
|
||||
got string
|
||||
}
|
||||
|
||||
func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
c.got = authority
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
|
||||
return c
|
||||
}
|
||||
func (c *authorityCheckCreds) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This test makes sure that the authority client handshake gets is the endpoint
|
||||
// in dial target, not the resolved ip address.
|
||||
func (s) TestCredsHandshakeAuthority(t *testing.T) {
|
||||
const testAuthority = "test.auth.ori.ty"
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
cred := &authorityCheckCreds{}
|
||||
s := grpc.NewServer()
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
r, rcleanup := manual.GenerateAndRegisterManualResolver()
|
||||
defer rcleanup()
|
||||
|
||||
cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
for {
|
||||
s := cc.GetState()
|
||||
if s == connectivity.Ready {
|
||||
break
|
||||
}
|
||||
if !cc.WaitForStateChange(ctx, s) {
|
||||
// ctx got timeout or canceled.
|
||||
t.Fatalf("ClientConn is not ready after 100 ms")
|
||||
}
|
||||
}
|
||||
|
||||
if cred.got != testAuthority {
|
||||
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
|
||||
}
|
||||
}
|
||||
|
||||
// This test makes sure that the authority client handshake gets is the endpoint
|
||||
// of the ServerName of the address when it is set.
|
||||
func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) {
|
||||
const testAuthority = "test.auth.ori.ty"
|
||||
const testServerName = "test.server.name"
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
cred := &authorityCheckCreds{}
|
||||
s := grpc.NewServer()
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
|
||||
r, rcleanup := manual.GenerateAndRegisterManualResolver()
|
||||
defer rcleanup()
|
||||
|
||||
cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer cc.Close()
|
||||
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
for {
|
||||
s := cc.GetState()
|
||||
if s == connectivity.Ready {
|
||||
break
|
||||
}
|
||||
if !cc.WaitForStateChange(ctx, s) {
|
||||
// ctx got timeout or canceled.
|
||||
t.Fatalf("ClientConn is not ready after 100 ms")
|
||||
}
|
||||
}
|
||||
|
||||
if cred.got != testServerName {
|
||||
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
|
||||
}
|
||||
}
|
||||
|
||||
type clientFailCreds struct{}
|
||||
|
||||
func (c *clientFailCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c *clientFailCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return nil, nil, fmt.Errorf("client handshake fails with fatal error")
|
||||
}
|
||||
func (c *clientFailCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c *clientFailCreds) Clone() credentials.TransportCredentials {
|
||||
return c
|
||||
}
|
||||
func (c *clientFailCreds) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This test makes sure that failfast RPCs fail if client handshake fails with
|
||||
// fatal errors.
|
||||
func (s) TestFailfastRPCFailOnFatalHandshakeError(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
|
||||
cc, err := grpc.Dial("passthrough:///"+lis.Addr().String(), grpc.WithTransportCredentials(&clientFailCreds{}))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(_) = %v", err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
// This unary call should fail, but not timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(false)); status.Code(err) != codes.Unavailable {
|
||||
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <Unavailable>", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestFlowControlLogicalRace(t *testing.T) {
|
||||
// Test for a regression of https://github.com/grpc/grpc-go/issues/632,
|
||||
// and other flow control bugs.
|
||||
@ -5633,120 +5384,6 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// test authdata
|
||||
authdata = map[string]string{
|
||||
"test-key": "test-value",
|
||||
"test-key2-bin": string([]byte{1, 2, 3}),
|
||||
}
|
||||
)
|
||||
|
||||
type testPerRPCCredentials struct{}
|
||||
|
||||
func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return authdata, nil
|
||||
}
|
||||
|
||||
func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
|
||||
}
|
||||
if vgot[0] != vwant {
|
||||
return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
// When credentials are provided via both dial options and call options,
|
||||
// we apply both sets.
|
||||
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
|
||||
}
|
||||
if len(vgot) != 2 {
|
||||
return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
|
||||
}
|
||||
if vgot[0] != vwant || vgot[1] != vwant {
|
||||
return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestWaitForReadyConnection(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testWaitForReadyConnection(t, e)
|
||||
@ -6637,79 +6274,6 @@ func testClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T, e
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
|
||||
|
||||
var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg)
|
||||
|
||||
type clientAlwaysFailCred struct{}
|
||||
|
||||
func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return nil, nil, errClientAlwaysFailCred
|
||||
}
|
||||
func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{}
|
||||
}
|
||||
func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
func (c clientAlwaysFailCred) OverrideServerName(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) {
|
||||
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"})
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
cc, err := grpc.DialContext(ctx, te.srvAddr, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(_) = %v, want %v", err, nil)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
for i := 0; i < 1000; i++ {
|
||||
// This loop runs for at most 1 second. The first several RPCs will fail
|
||||
// with Unavailable because the connection hasn't started. When the
|
||||
// first connection failed with creds error, the next RPC should also
|
||||
// fail with the expected error.
|
||||
if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
|
||||
}
|
||||
|
||||
func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) {
|
||||
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"})
|
||||
te.startServer(&testServer{security: te.e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
|
||||
dctx, dcancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer dcancel()
|
||||
cc, err := grpc.DialContext(dctx, te.srvAddr, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(_) = %v, want %v", err, nil)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
|
||||
return
|
||||
}
|
||||
te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
|
||||
}
|
||||
|
||||
func (s) TestRPCTimeout(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testRPCTimeout(t, e)
|
||||
@ -6747,7 +6311,6 @@ func testRPCTimeout(t *testing.T, e env) {
|
||||
}
|
||||
|
||||
func (s) TestDisabledIOBuffers(t *testing.T) {
|
||||
|
||||
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(60000))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create payload: %v", err)
|
||||
@ -7439,33 +7002,6 @@ func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult {
|
||||
return g
|
||||
}
|
||||
|
||||
type methodTestCreds struct{}
|
||||
|
||||
func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
ri, _ := credentials.RequestInfoFromContext(ctx)
|
||||
return nil, status.Errorf(codes.Unknown, ri.Method)
|
||||
}
|
||||
|
||||
func (m methodTestCreds) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) {
|
||||
const wantMethod = "/grpc.testing.TestService/EmptyCall"
|
||||
ss := &stubServer{}
|
||||
if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{})); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod {
|
||||
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientCancellationPropagatesUnary(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
called, done := make(chan struct{}), make(chan struct{})
|
||||
|
Reference in New Issue
Block a user