From a2a5ae5d4dce1d62d4f6f0fcd42e91de766eda46 Mon Sep 17 00:00:00 2001 From: Cesar Ghali Date: Fri, 20 Dec 2019 10:01:56 -0800 Subject: [PATCH] credentials/alts: Add Client Authorization Utility API (#3271) Add client authorization util API --- credentials/alts/utils.go | 19 +++++++++++ credentials/alts/utils_test.go | 59 ++++++++++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index f13aeef1..e46280ad 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -31,7 +31,9 @@ import ( "runtime" "strings" + "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) const ( @@ -142,3 +144,20 @@ func AuthInfoFromPeer(p *peer.Peer) (AuthInfo, error) { } return altsAuthInfo, nil } + +// ClientAuthorizationCheck checks whether the client is authorized to access +// the requested resources based on the given expected client service accounts. +// This API should be used by gRPC server RPC handlers. This API should not be +// used by clients. +func ClientAuthorizationCheck(ctx context.Context, expectedServiceAccounts []string) error { + authInfo, err := AuthInfoFromContext(ctx) + if err != nil { + return status.Newf(codes.PermissionDenied, "The context is not an ALTS-compatible context: %v", err).Err() + } + for _, sa := range expectedServiceAccounts { + if authInfo.PeerServiceAccount() == sa { + return nil + } + } + return status.Newf(codes.PermissionDenied, "Client %v is not authorized", authInfo.PeerServiceAccount()).Err() +} diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 8935c5fb..6b8adfb0 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -25,8 +25,16 @@ import ( "strings" "testing" + "google.golang.org/grpc/codes" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const ( + testServiceAccount1 = "service_account1" + testServiceAccount2 = "service_account2" + testServiceAccount3 = "service_account3" ) func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { @@ -147,7 +155,54 @@ func TestAuthInfoFromPeer(t *testing.T) { } } -type fakeALTSAuthInfo struct{} +func TestClientAuthorizationCheck(t *testing.T) { + ctx := context.Background() + altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1} + p := &peer.Peer{ + AuthInfo: altsAuthInfo, + } + for _, tc := range []struct { + desc string + ctx context.Context + expectedServiceAccounts []string + success bool + code codes.Code + }{ + { + "working case", + peer.NewContext(ctx, p), + []string{testServiceAccount1, testServiceAccount2}, + true, + codes.OK, // err is nil, code is OK. + }, + { + "context does not have AuthInfo", + ctx, + []string{testServiceAccount1, testServiceAccount2}, + false, + codes.PermissionDenied, + }, + { + "unauthorized client", + peer.NewContext(ctx, p), + []string{testServiceAccount2, testServiceAccount3}, + false, + codes.PermissionDenied, + }, + } { + err := ClientAuthorizationCheck(tc.ctx, tc.expectedServiceAccounts) + if got, want := (err == nil), tc.success; got != want { + t.Errorf("%v: ClientAuthorizationCheck(_, %v)=(err=nil)=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want) + } + if got, want := status.Code(err), tc.code; got != want { + t.Errorf("%v: ClientAuthorizationCheck(_, %v).Code=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want) + } + } +} + +type fakeALTSAuthInfo struct { + peerServiceAccount string +} func (*fakeALTSAuthInfo) AuthType() string { return "" } func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" } @@ -155,6 +210,6 @@ func (*fakeALTSAuthInfo) RecordProtocol() string { return "" } func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel { return altspb.SecurityLevel_SECURITY_NONE } -func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" } +func (f *fakeALTSAuthInfo) PeerServiceAccount() string { return f.peerServiceAccount } func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" } func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }