authz: create file watcher interceptor for gRPC SDK API (#4760)

* authz: create file watcher interceptor for gRPC SDK API
This commit is contained in:
Ashitha Santhosh
2021-10-08 17:09:55 -07:00
committed by GitHub
parent 03ca7b7d00
commit b99d1040b7
3 changed files with 455 additions and 52 deletions

View File

@ -21,13 +21,16 @@ package authz_test
import (
"context"
"io"
"io/ioutil"
"net"
"os"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/authz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
pb "google.golang.org/grpc/test/grpc_testing"
@ -53,15 +56,21 @@ func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCall
}
}
func TestSDKEnd2End(t *testing.T) {
tests := map[string]struct {
authzPolicy string
md metadata.MD
wantStatusCode codes.Code
wantErr string
}{
"DeniesRpcRequestMatchInDenyNoMatchInAllow": {
authzPolicy: `{
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
var sdkTests = map[string]struct {
authzPolicy string
md metadata.MD
wantStatus *status.Status
}{
"DeniesRpcMatchInDenyNoMatchInAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -100,12 +109,11 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
md: metadata.Pairs("key-abc", "val-abc"),
wantStatusCode: codes.PermissionDenied,
wantErr: "unauthorized RPC request rejected",
},
"DeniesRpcRequestMatchInDenyAndAllow": {
authzPolicy: `{
md: metadata.Pairs("key-abc", "val-abc"),
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
},
"DeniesRpcMatchInDenyAndAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -132,11 +140,10 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
wantStatusCode: codes.PermissionDenied,
wantErr: "unauthorized RPC request rejected",
},
"AllowsRpcRequestNoMatchInDenyMatchInAllow": {
authzPolicy: `{
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
},
"AllowsRpcNoMatchInDenyMatchInAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -169,11 +176,11 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
md: metadata.Pairs("key-xyz", "val-xyz"),
wantStatusCode: codes.OK,
},
"AllowsRpcRequestNoMatchInDenyAndAllow": {
authzPolicy: `{
md: metadata.Pairs("key-xyz", "val-xyz"),
wantStatus: status.New(codes.OK, ""),
},
"DeniesRpcNoMatchInDenyAndAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -200,11 +207,10 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
wantStatusCode: codes.PermissionDenied,
wantErr: "unauthorized RPC request rejected",
},
"AllowsRpcRequestEmptyDenyMatchInAllow": {
authzPolicy: `{
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
},
"AllowsRpcEmptyDenyMatchInAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -230,10 +236,10 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
wantStatusCode: codes.OK,
},
"DeniesRpcRequestEmptyDenyNoMatchInAllow": {
authzPolicy: `{
wantStatus: status.New(codes.OK, ""),
},
"DeniesRpcEmptyDenyNoMatchInAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
@ -249,22 +255,25 @@ func TestSDKEnd2End(t *testing.T) {
}
]
}`,
wantStatusCode: codes.PermissionDenied,
wantErr: "unauthorized RPC request rejected",
},
}
for name, test := range tests {
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
},
}
func (s) TestSDKStaticPolicyEnd2End(t *testing.T) {
for name, test := range sdkTests {
t.Run(name, func(t *testing.T) {
// Start a gRPC server with SDK unary and stream server interceptors.
i, _ := authz.NewStatic(test.authzPolicy)
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
defer s.Stop()
pb.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
pb.RegisterTestServiceServer(s, &testServer{})
go s.Serve(lis)
// Establish a connection to the server.
@ -281,8 +290,8 @@ func TestSDKEnd2End(t *testing.T) {
// Verifying authorization decision for Unary RPC.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr {
t.Fatalf("[UnaryCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message())
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
}
// Verifying authorization decision for Streaming RPC.
@ -299,9 +308,241 @@ func TestSDKEnd2End(t *testing.T) {
t.Fatalf("failed stream.Send err: %v", err)
}
_, err = stream.CloseAndRecv()
if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr {
t.Fatalf("[StreamingCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message())
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
}
})
}
}
func (s) TestSDKFileWatcherEnd2End(t *testing.T) {
for name, test := range sdkTests {
t.Run(name, func(t *testing.T) {
file := createTmpPolicyFile(t, name, []byte(test.authzPolicy))
i, _ := authz.NewFileWatcher(file, 1*time.Second)
defer i.Close()
// Start a gRPC server with SDK unary and stream server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
defer s.Stop()
pb.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer clientConn.Close()
client := pb.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, test.md)
// Verifying authorization decision for Unary RPC.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
}
// Verifying authorization decision for Streaming RPC.
stream, err := client.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("failed StreamingInputCall err: %v", err)
}
req := &pb.StreamingInputCallRequest{
Payload: &pb.Payload{
Body: []byte("hi"),
},
}
if err := stream.Send(req); err != nil && err != io.EOF {
t.Fatalf("failed stream.Send err: %v", err)
}
_, err = stream.CloseAndRecv()
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
}
})
}
}
func retryUntil(ctx context.Context, tsc pb.TestServiceClient, want *status.Status) (lastErr error) {
for ctx.Err() == nil {
_, lastErr = tsc.UnaryCall(ctx, &pb.SimpleRequest{})
if s := status.Convert(lastErr); s.Code() == want.Code() && s.Message() == want.Message() {
return nil
}
time.Sleep(20 * time.Millisecond)
}
return lastErr
}
func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) {
valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"]
file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy))
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
defer i.Close()
// Start a gRPC server with SDK unary server interceptor.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
pb.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer clientConn.Close()
client := pb.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Verifying authorization decision.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
}
// Rewrite the file with a different valid authorization policy.
valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"]
if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil {
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
}
// Verifying authorization decision.
if got := retryUntil(ctx, client, valid2.wantStatus); got != nil {
t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got)
}
}
func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) {
valid := sdkTests["DeniesRpcMatchInDenyAndAllow"]
file := createTmpPolicyFile(t, "invalid_policy_skip_reload", []byte(valid.authzPolicy))
i, _ := authz.NewFileWatcher(file, 20*time.Millisecond)
defer i.Close()
// Start a gRPC server with SDK unary server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
pb.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer clientConn.Close()
client := pb.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Verifying authorization decision.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() {
t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err())
}
// Skips the invalid policy update, and continues to use the valid policy.
if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil {
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
}
// Wait 40 ms for background go routine to read updated files.
time.Sleep(40 * time.Millisecond)
// Verifying authorization decision.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() {
t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err())
}
}
func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) {
valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"]
file := createTmpPolicyFile(t, "recovers_from_reload_failure", []byte(valid1.authzPolicy))
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
defer i.Close()
// Start a gRPC server with SDK unary server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
pb.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer clientConn.Close()
client := pb.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Verifying authorization decision.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
}
// Skips the invalid policy update, and continues to use the valid policy.
if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil {
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
}
// Wait 120 ms for background go routine to read updated files.
time.Sleep(120 * time.Millisecond)
// Verifying authorization decision.
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
}
// Rewrite the file with a different valid authorization policy.
valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"]
if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil {
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
}
// Verifying authorization decision.
if got := retryUntil(ctx, client, valid2.wantStatus); got != nil {
t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got)
}
}

View File

@ -17,14 +17,23 @@
package authz
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"sync/atomic"
"time"
"unsafe"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/xds/rbac"
"google.golang.org/grpc/status"
)
var logger = grpclog.Component("authz")
// StaticInterceptor contains engines used to make authorization decisions. It
// either contains two engines deny engine followed by an allow engine or only
// one allow engine.
@ -73,3 +82,91 @@ func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStr
}
return handler(srv, ss)
}
// FileWatcherInterceptor contains details used to make authorization decisions
// by watching a file path that contains authorization policy in JSON format.
type FileWatcherInterceptor struct {
internalInterceptor unsafe.Pointer // *StaticInterceptor
policyFile string
policyContents []byte
refreshDuration time.Duration
cancel context.CancelFunc
}
// NewFileWatcher returns a new FileWatcherInterceptor from a policy file
// that contains JSON string of authorization policy and a refresh duration to
// specify the amount of time between policy refreshes.
func NewFileWatcher(file string, duration time.Duration) (*FileWatcherInterceptor, error) {
if file == "" {
return nil, fmt.Errorf("authorization policy file path is empty")
}
if duration <= time.Duration(0) {
return nil, fmt.Errorf("requires refresh interval(%v) greater than 0s", duration)
}
i := &FileWatcherInterceptor{policyFile: file, refreshDuration: duration}
if err := i.updateInternalInterceptor(); err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
i.cancel = cancel
// Create a background go routine for policy refresh.
go i.run(ctx)
return i, nil
}
func (i *FileWatcherInterceptor) run(ctx context.Context) {
ticker := time.NewTicker(i.refreshDuration)
for {
if err := i.updateInternalInterceptor(); err != nil {
logger.Warningf("authorization policy reload status err: %v", err)
}
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
}
}
}
// updateInternalInterceptor checks if the policy file that is watching has changed,
// and if so, updates the internalInterceptor with the policy. Unlike the
// constructor, if there is an error in reading the file or parsing the policy, the
// previous internalInterceptors will not be replaced.
func (i *FileWatcherInterceptor) updateInternalInterceptor() error {
policyContents, err := ioutil.ReadFile(i.policyFile)
if err != nil {
return fmt.Errorf("policyFile(%s) read failed: %v", i.policyFile, err)
}
if bytes.Equal(i.policyContents, policyContents) {
return nil
}
i.policyContents = policyContents
policyContentsString := string(policyContents)
interceptor, err := NewStatic(policyContentsString)
if err != nil {
return err
}
atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor))
logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString)
return nil
}
// Close cleans up resources allocated by the interceptor.
func (i *FileWatcherInterceptor) Close() {
i.cancel()
}
// UnaryInterceptor intercepts incoming Unary RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler)
}
// StreamInterceptor intercepts incoming Stream RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler)
}

View File

@ -19,19 +19,43 @@
package authz_test
import (
"fmt"
"io/ioutil"
"os"
"path"
"testing"
"time"
"google.golang.org/grpc/authz"
)
func TestNewStatic(t *testing.T) {
func createTmpPolicyFile(t *testing.T, dirSuffix string, policy []byte) string {
t.Helper()
// Create a temp directory. Passing an empty string for the first argument
// uses the system temp directory.
dir, err := ioutil.TempDir("", dirSuffix)
if err != nil {
t.Fatalf("ioutil.TempDir() failed: %v", err)
}
t.Logf("Using tmpdir: %s", dir)
// Write policy into file.
filename := path.Join(dir, "policy.json")
if err := ioutil.WriteFile(filename, policy, os.ModePerm); err != nil {
t.Fatalf("ioutil.WriteFile(%q) failed: %v", filename, err)
}
t.Logf("Wrote policy %s to file at %s", string(policy), filename)
return filename
}
func (s) TestNewStatic(t *testing.T) {
tests := map[string]struct {
authzPolicy string
wantErr bool
wantErr error
}{
"InvalidPolicyFailsToCreateInterceptor": {
authzPolicy: `{}`,
wantErr: true,
wantErr: fmt.Errorf(`"name" is not present`),
},
"ValidPolicyCreatesInterceptor": {
authzPolicy: `{
@ -43,14 +67,55 @@ func TestNewStatic(t *testing.T) {
}
]
}`,
wantErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
if _, err := authz.NewStatic(test.authzPolicy); (err != nil) != test.wantErr {
if _, err := authz.NewStatic(test.authzPolicy); fmt.Sprint(err) != fmt.Sprint(test.wantErr) {
t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr)
}
})
}
}
func (s) TestNewFileWatcher(t *testing.T) {
tests := map[string]struct {
authzPolicy string
refreshDuration time.Duration
wantErr error
}{
"InvalidRefreshDurationFailsToCreateInterceptor": {
refreshDuration: time.Duration(0),
wantErr: fmt.Errorf("requires refresh interval(0s) greater than 0s"),
},
"InvalidPolicyFailsToCreateInterceptor": {
authzPolicy: `{}`,
refreshDuration: time.Duration(1),
wantErr: fmt.Errorf(`"name" is not present`),
},
"ValidPolicyCreatesInterceptor": {
authzPolicy: `{
"name": "authz",
"allow_rules":
[
{
"name": "allow_all"
}
]
}`,
refreshDuration: time.Duration(1),
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
file := createTmpPolicyFile(t, name, []byte(test.authzPolicy))
i, err := authz.NewFileWatcher(file, test.refreshDuration)
if fmt.Sprint(err) != fmt.Sprint(test.wantErr) {
t.Fatalf("NewFileWatcher(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr)
}
if i != nil {
i.Close()
}
})
}
}