authz: create file watcher interceptor for gRPC SDK API (#4760)
* authz: create file watcher interceptor for gRPC SDK API
This commit is contained in:
@ -21,13 +21,16 @@ package authz_test
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/authz"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
pb "google.golang.org/grpc/test/grpc_testing"
|
||||
@ -53,15 +56,21 @@ func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCall
|
||||
}
|
||||
}
|
||||
|
||||
func TestSDKEnd2End(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
authzPolicy string
|
||||
md metadata.MD
|
||||
wantStatusCode codes.Code
|
||||
wantErr string
|
||||
}{
|
||||
"DeniesRpcRequestMatchInDenyNoMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
type s struct {
|
||||
grpctest.Tester
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
var sdkTests = map[string]struct {
|
||||
authzPolicy string
|
||||
md metadata.MD
|
||||
wantStatus *status.Status
|
||||
}{
|
||||
"DeniesRpcMatchInDenyNoMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -100,12 +109,11 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
md: metadata.Pairs("key-abc", "val-abc"),
|
||||
wantStatusCode: codes.PermissionDenied,
|
||||
wantErr: "unauthorized RPC request rejected",
|
||||
},
|
||||
"DeniesRpcRequestMatchInDenyAndAllow": {
|
||||
authzPolicy: `{
|
||||
md: metadata.Pairs("key-abc", "val-abc"),
|
||||
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
|
||||
},
|
||||
"DeniesRpcMatchInDenyAndAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -132,11 +140,10 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
wantStatusCode: codes.PermissionDenied,
|
||||
wantErr: "unauthorized RPC request rejected",
|
||||
},
|
||||
"AllowsRpcRequestNoMatchInDenyMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
|
||||
},
|
||||
"AllowsRpcNoMatchInDenyMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -169,11 +176,11 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
md: metadata.Pairs("key-xyz", "val-xyz"),
|
||||
wantStatusCode: codes.OK,
|
||||
},
|
||||
"AllowsRpcRequestNoMatchInDenyAndAllow": {
|
||||
authzPolicy: `{
|
||||
md: metadata.Pairs("key-xyz", "val-xyz"),
|
||||
wantStatus: status.New(codes.OK, ""),
|
||||
},
|
||||
"DeniesRpcNoMatchInDenyAndAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -200,11 +207,10 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
wantStatusCode: codes.PermissionDenied,
|
||||
wantErr: "unauthorized RPC request rejected",
|
||||
},
|
||||
"AllowsRpcRequestEmptyDenyMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
|
||||
},
|
||||
"AllowsRpcEmptyDenyMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -230,10 +236,10 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
wantStatusCode: codes.OK,
|
||||
},
|
||||
"DeniesRpcRequestEmptyDenyNoMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
wantStatus: status.New(codes.OK, ""),
|
||||
},
|
||||
"DeniesRpcEmptyDenyNoMatchInAllow": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
@ -249,22 +255,25 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
wantStatusCode: codes.PermissionDenied,
|
||||
wantErr: "unauthorized RPC request rejected",
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"),
|
||||
},
|
||||
}
|
||||
|
||||
func (s) TestSDKStaticPolicyEnd2End(t *testing.T) {
|
||||
for name, test := range sdkTests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Start a gRPC server with SDK unary and stream server interceptors.
|
||||
i, _ := authz.NewStatic(test.authzPolicy)
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(i.StreamInterceptor))
|
||||
defer s.Stop()
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("error listening: %v", err)
|
||||
}
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(i.StreamInterceptor))
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
go s.Serve(lis)
|
||||
|
||||
// Establish a connection to the server.
|
||||
@ -281,8 +290,8 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
|
||||
// Verifying authorization decision for Unary RPC.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr {
|
||||
t.Fatalf("[UnaryCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message())
|
||||
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
|
||||
t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Verifying authorization decision for Streaming RPC.
|
||||
@ -299,9 +308,241 @@ func TestSDKEnd2End(t *testing.T) {
|
||||
t.Fatalf("failed stream.Send err: %v", err)
|
||||
}
|
||||
_, err = stream.CloseAndRecv()
|
||||
if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr {
|
||||
t.Fatalf("[StreamingCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message())
|
||||
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
|
||||
t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestSDKFileWatcherEnd2End(t *testing.T) {
|
||||
for name, test := range sdkTests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
file := createTmpPolicyFile(t, name, []byte(test.authzPolicy))
|
||||
i, _ := authz.NewFileWatcher(file, 1*time.Second)
|
||||
defer i.Close()
|
||||
|
||||
// Start a gRPC server with SDK unary and stream server interceptors.
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(i.StreamInterceptor))
|
||||
defer s.Stop()
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("error listening: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
go s.Serve(lis)
|
||||
|
||||
// Establish a connection to the server.
|
||||
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
client := pb.NewTestServiceClient(clientConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
ctx = metadata.NewOutgoingContext(ctx, test.md)
|
||||
|
||||
// Verifying authorization decision for Unary RPC.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
|
||||
t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Verifying authorization decision for Streaming RPC.
|
||||
stream, err := client.StreamingInputCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed StreamingInputCall err: %v", err)
|
||||
}
|
||||
req := &pb.StreamingInputCallRequest{
|
||||
Payload: &pb.Payload{
|
||||
Body: []byte("hi"),
|
||||
},
|
||||
}
|
||||
if err := stream.Send(req); err != nil && err != io.EOF {
|
||||
t.Fatalf("failed stream.Send err: %v", err)
|
||||
}
|
||||
_, err = stream.CloseAndRecv()
|
||||
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
|
||||
t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func retryUntil(ctx context.Context, tsc pb.TestServiceClient, want *status.Status) (lastErr error) {
|
||||
for ctx.Err() == nil {
|
||||
_, lastErr = tsc.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if s := status.Convert(lastErr); s.Code() == want.Code() && s.Message() == want.Message() {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) {
|
||||
valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"]
|
||||
file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy))
|
||||
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
|
||||
defer i.Close()
|
||||
|
||||
// Start a gRPC server with SDK unary server interceptor.
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
|
||||
defer s.Stop()
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("error listening: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
go s.Serve(lis)
|
||||
|
||||
// Establish a connection to the server.
|
||||
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
client := pb.NewTestServiceClient(clientConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Verifying authorization decision.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Rewrite the file with a different valid authorization policy.
|
||||
valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"]
|
||||
if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil {
|
||||
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
|
||||
}
|
||||
|
||||
// Verifying authorization decision.
|
||||
if got := retryUntil(ctx, client, valid2.wantStatus); got != nil {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) {
|
||||
valid := sdkTests["DeniesRpcMatchInDenyAndAllow"]
|
||||
file := createTmpPolicyFile(t, "invalid_policy_skip_reload", []byte(valid.authzPolicy))
|
||||
i, _ := authz.NewFileWatcher(file, 20*time.Millisecond)
|
||||
defer i.Close()
|
||||
|
||||
// Start a gRPC server with SDK unary server interceptors.
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
|
||||
defer s.Stop()
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("error listening: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
go s.Serve(lis)
|
||||
|
||||
// Establish a connection to the server.
|
||||
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
client := pb.NewTestServiceClient(clientConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Verifying authorization decision.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Skips the invalid policy update, and continues to use the valid policy.
|
||||
if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil {
|
||||
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
|
||||
}
|
||||
|
||||
// Wait 40 ms for background go routine to read updated files.
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
|
||||
// Verifying authorization decision.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) {
|
||||
valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"]
|
||||
file := createTmpPolicyFile(t, "recovers_from_reload_failure", []byte(valid1.authzPolicy))
|
||||
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
|
||||
defer i.Close()
|
||||
|
||||
// Start a gRPC server with SDK unary server interceptors.
|
||||
s := grpc.NewServer(
|
||||
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
|
||||
defer s.Stop()
|
||||
pb.RegisterTestServiceServer(s, &testServer{})
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("error listening: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
go s.Serve(lis)
|
||||
|
||||
// Establish a connection to the server.
|
||||
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
client := pb.NewTestServiceClient(clientConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Verifying authorization decision.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Skips the invalid policy update, and continues to use the valid policy.
|
||||
if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil {
|
||||
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
|
||||
}
|
||||
|
||||
// Wait 120 ms for background go routine to read updated files.
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
// Verifying authorization decision.
|
||||
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
|
||||
if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err())
|
||||
}
|
||||
|
||||
// Rewrite the file with a different valid authorization policy.
|
||||
valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"]
|
||||
if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil {
|
||||
t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err)
|
||||
}
|
||||
|
||||
// Verifying authorization decision.
|
||||
if got := retryUntil(ctx, client, valid2.wantStatus); got != nil {
|
||||
t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got)
|
||||
}
|
||||
}
|
||||
|
@ -17,14 +17,23 @@
|
||||
package authz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/internal/xds/rbac"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var logger = grpclog.Component("authz")
|
||||
|
||||
// StaticInterceptor contains engines used to make authorization decisions. It
|
||||
// either contains two engines deny engine followed by an allow engine or only
|
||||
// one allow engine.
|
||||
@ -73,3 +82,91 @@ func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStr
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
// FileWatcherInterceptor contains details used to make authorization decisions
|
||||
// by watching a file path that contains authorization policy in JSON format.
|
||||
type FileWatcherInterceptor struct {
|
||||
internalInterceptor unsafe.Pointer // *StaticInterceptor
|
||||
policyFile string
|
||||
policyContents []byte
|
||||
refreshDuration time.Duration
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewFileWatcher returns a new FileWatcherInterceptor from a policy file
|
||||
// that contains JSON string of authorization policy and a refresh duration to
|
||||
// specify the amount of time between policy refreshes.
|
||||
func NewFileWatcher(file string, duration time.Duration) (*FileWatcherInterceptor, error) {
|
||||
if file == "" {
|
||||
return nil, fmt.Errorf("authorization policy file path is empty")
|
||||
}
|
||||
if duration <= time.Duration(0) {
|
||||
return nil, fmt.Errorf("requires refresh interval(%v) greater than 0s", duration)
|
||||
}
|
||||
i := &FileWatcherInterceptor{policyFile: file, refreshDuration: duration}
|
||||
if err := i.updateInternalInterceptor(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
i.cancel = cancel
|
||||
// Create a background go routine for policy refresh.
|
||||
go i.run(ctx)
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *FileWatcherInterceptor) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(i.refreshDuration)
|
||||
for {
|
||||
if err := i.updateInternalInterceptor(); err != nil {
|
||||
logger.Warningf("authorization policy reload status err: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateInternalInterceptor checks if the policy file that is watching has changed,
|
||||
// and if so, updates the internalInterceptor with the policy. Unlike the
|
||||
// constructor, if there is an error in reading the file or parsing the policy, the
|
||||
// previous internalInterceptors will not be replaced.
|
||||
func (i *FileWatcherInterceptor) updateInternalInterceptor() error {
|
||||
policyContents, err := ioutil.ReadFile(i.policyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("policyFile(%s) read failed: %v", i.policyFile, err)
|
||||
}
|
||||
if bytes.Equal(i.policyContents, policyContents) {
|
||||
return nil
|
||||
}
|
||||
i.policyContents = policyContents
|
||||
policyContentsString := string(policyContents)
|
||||
interceptor, err := NewStatic(policyContentsString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor))
|
||||
logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close cleans up resources allocated by the interceptor.
|
||||
func (i *FileWatcherInterceptor) Close() {
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
// UnaryInterceptor intercepts incoming Unary RPC requests.
|
||||
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
|
||||
// error is returned to the client.
|
||||
func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler)
|
||||
}
|
||||
|
||||
// StreamInterceptor intercepts incoming Stream RPC requests.
|
||||
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
|
||||
// error is returned to the client.
|
||||
func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler)
|
||||
}
|
||||
|
@ -19,19 +19,43 @@
|
||||
package authz_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/authz"
|
||||
)
|
||||
|
||||
func TestNewStatic(t *testing.T) {
|
||||
func createTmpPolicyFile(t *testing.T, dirSuffix string, policy []byte) string {
|
||||
t.Helper()
|
||||
|
||||
// Create a temp directory. Passing an empty string for the first argument
|
||||
// uses the system temp directory.
|
||||
dir, err := ioutil.TempDir("", dirSuffix)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.TempDir() failed: %v", err)
|
||||
}
|
||||
t.Logf("Using tmpdir: %s", dir)
|
||||
// Write policy into file.
|
||||
filename := path.Join(dir, "policy.json")
|
||||
if err := ioutil.WriteFile(filename, policy, os.ModePerm); err != nil {
|
||||
t.Fatalf("ioutil.WriteFile(%q) failed: %v", filename, err)
|
||||
}
|
||||
t.Logf("Wrote policy %s to file at %s", string(policy), filename)
|
||||
return filename
|
||||
}
|
||||
|
||||
func (s) TestNewStatic(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
authzPolicy string
|
||||
wantErr bool
|
||||
wantErr error
|
||||
}{
|
||||
"InvalidPolicyFailsToCreateInterceptor": {
|
||||
authzPolicy: `{}`,
|
||||
wantErr: true,
|
||||
wantErr: fmt.Errorf(`"name" is not present`),
|
||||
},
|
||||
"ValidPolicyCreatesInterceptor": {
|
||||
authzPolicy: `{
|
||||
@ -43,14 +67,55 @@ func TestNewStatic(t *testing.T) {
|
||||
}
|
||||
]
|
||||
}`,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if _, err := authz.NewStatic(test.authzPolicy); (err != nil) != test.wantErr {
|
||||
if _, err := authz.NewStatic(test.authzPolicy); fmt.Sprint(err) != fmt.Sprint(test.wantErr) {
|
||||
t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestNewFileWatcher(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
authzPolicy string
|
||||
refreshDuration time.Duration
|
||||
wantErr error
|
||||
}{
|
||||
"InvalidRefreshDurationFailsToCreateInterceptor": {
|
||||
refreshDuration: time.Duration(0),
|
||||
wantErr: fmt.Errorf("requires refresh interval(0s) greater than 0s"),
|
||||
},
|
||||
"InvalidPolicyFailsToCreateInterceptor": {
|
||||
authzPolicy: `{}`,
|
||||
refreshDuration: time.Duration(1),
|
||||
wantErr: fmt.Errorf(`"name" is not present`),
|
||||
},
|
||||
"ValidPolicyCreatesInterceptor": {
|
||||
authzPolicy: `{
|
||||
"name": "authz",
|
||||
"allow_rules":
|
||||
[
|
||||
{
|
||||
"name": "allow_all"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
refreshDuration: time.Duration(1),
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
file := createTmpPolicyFile(t, name, []byte(test.authzPolicy))
|
||||
i, err := authz.NewFileWatcher(file, test.refreshDuration)
|
||||
if fmt.Sprint(err) != fmt.Sprint(test.wantErr) {
|
||||
t.Fatalf("NewFileWatcher(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr)
|
||||
}
|
||||
if i != nil {
|
||||
i.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user