Alerting: Provisioning API respects global rule quota (#52180)

* Inject interface for quota service and create mock

* Check quota and return 403 if limit exceeded

* Implement tests for quota being exceeded
This commit is contained in:
Alexander Weaver
2022-07-13 17:36:17 -05:00
committed by GitHub
parent eb5a96eae9
commit 2d7389c34d
10 changed files with 209 additions and 22 deletions

View File

@ -52,7 +52,7 @@ type MuteTimingService interface {
type AlertRuleService interface { type AlertRuleService interface {
GetAlertRule(ctx context.Context, orgID int64, ruleUID string) (alerting_models.AlertRule, alerting_models.Provenance, error) GetAlertRule(ctx context.Context, orgID int64, ruleUID string) (alerting_models.AlertRule, alerting_models.Provenance, error)
CreateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance) (alerting_models.AlertRule, error) CreateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance, userID int64) (alerting_models.AlertRule, error)
UpdateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance) (alerting_models.AlertRule, error) UpdateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance) (alerting_models.AlertRule, error)
DeleteAlertRule(ctx context.Context, orgID int64, ruleUID string, provenance alerting_models.Provenance) error DeleteAlertRule(ctx context.Context, orgID int64, ruleUID string, provenance alerting_models.Provenance) error
GetRuleGroup(ctx context.Context, orgID int64, folder, group string) (definitions.AlertRuleGroup, error) GetRuleGroup(ctx context.Context, orgID int64, folder, group string) (definitions.AlertRuleGroup, error)
@ -254,7 +254,7 @@ func (srv *ProvisioningSrv) RouteRouteGetAlertRule(c *models.ReqContext, UID str
} }
func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definitions.ProvisionedAlertRule) response.Response { func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definitions.ProvisionedAlertRule) response.Response {
createdAlertRule, err := srv.alertRules.CreateAlertRule(c.Req.Context(), ar.UpstreamModel(), alerting_models.ProvenanceAPI) createdAlertRule, err := srv.alertRules.CreateAlertRule(c.Req.Context(), ar.UpstreamModel(), alerting_models.ProvenanceAPI, c.UserId)
if errors.Is(err, alerting_models.ErrAlertRuleFailedValidation) { if errors.Is(err, alerting_models.ErrAlertRuleFailedValidation) {
return ErrResp(http.StatusBadRequest, err, "") return ErrResp(http.StatusBadRequest, err, "")
} }
@ -262,6 +262,9 @@ func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definiti
if errors.Is(err, store.ErrOptimisticLock) { if errors.Is(err, store.ErrOptimisticLock) {
return ErrResp(http.StatusConflict, err, "") return ErrResp(http.StatusConflict, err, "")
} }
if errors.Is(err, alerting_models.ErrQuotaReached) {
return ErrResp(http.StatusForbidden, err, "")
}
return ErrResp(http.StatusInternalServerError, err, "") return ErrResp(http.StatusInternalServerError, err, "")
} }
ar.ID = createdAlertRule.ID ar.ID = createdAlertRule.ID

View File

@ -15,7 +15,8 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/provisioning" "github.com/grafana/grafana/pkg/services/ngalert/provisioning"
"github.com/grafana/grafana/pkg/services/ngalert/store" "github.com/grafana/grafana/pkg/services/ngalert/store"
secrets "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/services/secrets"
secrets_fakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
prometheus "github.com/prometheus/alertmanager/config" prometheus "github.com/prometheus/alertmanager/config"
@ -259,6 +260,20 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 404, response.Status()) require.Equal(t, 404, response.Status())
}) })
t.Run("have reached the rule quota, POST returns 403", func(t *testing.T) {
env := createTestEnv(t)
quotas := provisioning.MockQuotaChecker{}
quotas.EXPECT().LimitExceeded()
env.quotas = &quotas
sut := createProvisioningSrvSutFromEnv(t, &env)
rule := createTestAlertRule("rule", 1)
rc := createTestRequestCtx()
response := sut.RoutePostAlertRule(&rc, rule)
require.Equal(t, 403, response.Status())
})
}) })
t.Run("alert rule groups", func(t *testing.T) { t.Run("alert rule groups", func(t *testing.T) {
@ -284,9 +299,21 @@ func TestProvisioningApi(t *testing.T) {
}) })
} }
func createProvisioningSrvSut(t *testing.T) ProvisioningSrv { // testEnvironment binds together common dependencies for testing alerting APIs.
type testEnvironment struct {
secrets secrets.Service
log log.Logger
store store.DBstore
configs provisioning.AMConfigStore
xact provisioning.TransactionManager
quotas provisioning.QuotaChecker
prov provisioning.ProvisioningStore
}
func createTestEnv(t *testing.T) testEnvironment {
t.Helper() t.Helper()
secrets := secrets.NewFakeSecretsService()
secrets := secrets_fakes.NewFakeSecretsService()
log := log.NewNopLogger() log := log.NewNopLogger()
configs := &provisioning.MockAMConfigStore{} configs := &provisioning.MockAMConfigStore{}
configs.EXPECT(). configs.EXPECT().
@ -298,18 +325,41 @@ func createProvisioningSrvSut(t *testing.T) ProvisioningSrv {
SQLStore: sqlStore, SQLStore: sqlStore,
BaseInterval: time.Second * 10, BaseInterval: time.Second * 10,
} }
quotas := &provisioning.MockQuotaChecker{}
quotas.EXPECT().LimitOK()
xact := &provisioning.NopTransactionManager{} xact := &provisioning.NopTransactionManager{}
prov := &provisioning.MockProvisioningStore{} prov := &provisioning.MockProvisioningStore{}
prov.EXPECT().SaveSucceeds() prov.EXPECT().SaveSucceeds()
prov.EXPECT().GetReturns(models.ProvenanceNone) prov.EXPECT().GetReturns(models.ProvenanceNone)
return testEnvironment{
secrets: secrets,
log: log,
configs: configs,
store: store,
xact: xact,
prov: prov,
quotas: quotas,
}
}
func createProvisioningSrvSut(t *testing.T) ProvisioningSrv {
t.Helper()
env := createTestEnv(t)
return createProvisioningSrvSutFromEnv(t, &env)
}
func createProvisioningSrvSutFromEnv(t *testing.T, env *testEnvironment) ProvisioningSrv {
t.Helper()
return ProvisioningSrv{ return ProvisioningSrv{
log: log, log: env.log,
policies: newFakeNotificationPolicyService(), policies: newFakeNotificationPolicyService(),
contactPointService: provisioning.NewContactPointService(configs, secrets, prov, xact, log), contactPointService: provisioning.NewContactPointService(env.configs, env.secrets, env.prov, env.xact, env.log),
templates: provisioning.NewTemplateService(configs, prov, xact, log), templates: provisioning.NewTemplateService(env.configs, env.prov, env.xact, env.log),
muteTimings: provisioning.NewMuteTimingService(configs, prov, xact, log), muteTimings: provisioning.NewMuteTimingService(env.configs, env.prov, env.xact, env.log),
alertRules: provisioning.NewAlertRuleService(store, prov, xact, 60, 10, log), alertRules: provisioning.NewAlertRuleService(env.store, env.prov, env.quotas, env.xact, 60, 10, env.log),
} }
} }

View File

@ -42,7 +42,6 @@ type RulerSrv struct {
} }
var ( var (
errQuotaReached = errors.New("quota has been exceeded")
errProvisionedResource = errors.New("request affects resources created via provisioning API") errProvisionedResource = errors.New("request affects resources created via provisioning API")
) )
@ -401,7 +400,7 @@ func (srv RulerSrv) updateAlertRulesInGroup(c *models.ReqContext, groupKey ngmod
return fmt.Errorf("failed to get alert rules quota: %w", err) return fmt.Errorf("failed to get alert rules quota: %w", err)
} }
if limitReached { if limitReached {
return errQuotaReached return ngmodels.ErrQuotaReached
} }
} }
return nil return nil
@ -412,7 +411,7 @@ func (srv RulerSrv) updateAlertRulesInGroup(c *models.ReqContext, groupKey ngmod
return ErrResp(http.StatusNotFound, err, "failed to update rule group") return ErrResp(http.StatusNotFound, err, "failed to update rule group")
} else if errors.Is(err, ngmodels.ErrAlertRuleFailedValidation) || errors.Is(err, errProvisionedResource) { } else if errors.Is(err, ngmodels.ErrAlertRuleFailedValidation) || errors.Is(err, errProvisionedResource) {
return ErrResp(http.StatusBadRequest, err, "failed to update rule group") return ErrResp(http.StatusBadRequest, err, "failed to update rule group")
} else if errors.Is(err, errQuotaReached) { } else if errors.Is(err, ngmodels.ErrQuotaReached) {
return ErrResp(http.StatusForbidden, err, "") return ErrResp(http.StatusForbidden, err, "")
} else if errors.Is(err, ErrAuthorization) { } else if errors.Is(err, ErrAuthorization) {
return ErrResp(http.StatusUnauthorized, err, "") return ErrResp(http.StatusUnauthorized, err, "")

View File

@ -23,6 +23,7 @@ var (
ErrRuleGroupNamespaceNotFound = errors.New("rule group not found under this namespace") ErrRuleGroupNamespaceNotFound = errors.New("rule group not found under this namespace")
ErrAlertRuleFailedValidation = errors.New("invalid alert rule") ErrAlertRuleFailedValidation = errors.New("invalid alert rule")
ErrAlertRuleUniqueConstraintViolation = errors.New("a conflicting alert rule is found: rule title under the same organisation and folder should be unique") ErrAlertRuleUniqueConstraintViolation = errors.New("a conflicting alert rule is found: rule title under the same organisation and folder should be unique")
ErrQuotaReached = errors.New("quota has been exceeded")
) )
// swagger:enum NoDataState // swagger:enum NoDataState

View File

@ -170,7 +170,7 @@ func (ng *AlertNG) init() error {
contactPointService := provisioning.NewContactPointService(store, ng.SecretsService, store, store, ng.Log) contactPointService := provisioning.NewContactPointService(store, ng.SecretsService, store, store, ng.Log)
templateService := provisioning.NewTemplateService(store, store, store, ng.Log) templateService := provisioning.NewTemplateService(store, store, store, ng.Log)
muteTimingService := provisioning.NewMuteTimingService(store, store, store, ng.Log) muteTimingService := provisioning.NewMuteTimingService(store, store, store, ng.Log)
alertRuleService := provisioning.NewAlertRuleService(store, store, store, alertRuleService := provisioning.NewAlertRuleService(store, store, ng.QuotaService, store,
int64(ng.Cfg.UnifiedAlerting.DefaultRuleEvaluationInterval.Seconds()), int64(ng.Cfg.UnifiedAlerting.DefaultRuleEvaluationInterval.Seconds()),
int64(ng.Cfg.UnifiedAlerting.BaseInterval.Seconds()), ng.Log) int64(ng.Cfg.UnifiedAlerting.BaseInterval.Seconds()), ng.Log)

View File

@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions" "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
"github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/store" "github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
@ -18,12 +19,14 @@ type AlertRuleService struct {
baseIntervalSeconds int64 baseIntervalSeconds int64
ruleStore RuleStore ruleStore RuleStore
provenanceStore ProvisioningStore provenanceStore ProvisioningStore
quotas QuotaChecker
xact TransactionManager xact TransactionManager
log log.Logger log log.Logger
} }
func NewAlertRuleService(ruleStore RuleStore, func NewAlertRuleService(ruleStore RuleStore,
provenanceStore ProvisioningStore, provenanceStore ProvisioningStore,
quotas QuotaChecker,
xact TransactionManager, xact TransactionManager,
defaultIntervalSeconds int64, defaultIntervalSeconds int64,
baseIntervalSeconds int64, baseIntervalSeconds int64,
@ -33,6 +36,7 @@ func NewAlertRuleService(ruleStore RuleStore,
baseIntervalSeconds: baseIntervalSeconds, baseIntervalSeconds: baseIntervalSeconds,
ruleStore: ruleStore, ruleStore: ruleStore,
provenanceStore: provenanceStore, provenanceStore: provenanceStore,
quotas: quotas,
xact: xact, xact: xact,
log: log, log: log,
} }
@ -57,7 +61,7 @@ func (service *AlertRuleService) GetAlertRule(ctx context.Context, orgID int64,
// CreateAlertRule creates a new alert rule. This function will ignore any // CreateAlertRule creates a new alert rule. This function will ignore any
// interval that is set in the rule struct and use the already existing group // interval that is set in the rule struct and use the already existing group
// interval or the default one. // interval or the default one.
func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule models.AlertRule, provenance models.Provenance) (models.AlertRule, error) { func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule models.AlertRule, provenance models.Provenance, userID int64) (models.AlertRule, error) {
if rule.UID == "" { if rule.UID == "" {
rule.UID = util.GenerateShortUID() rule.UID = util.GenerateShortUID()
} }
@ -82,6 +86,18 @@ func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule model
} else { } else {
return errors.New("couldn't find newly created id") return errors.New("couldn't find newly created id")
} }
limitReached, err := service.quotas.CheckQuotaReached(ctx, "alert_rule", &quota.ScopeParameters{
OrgId: rule.OrgID,
UserId: userID,
})
if err != nil {
return fmt.Errorf("failed to check alert rule quota: %w", err)
}
if limitReached {
return models.ErrQuotaReached
}
return service.provenanceStore.SetProvenance(ctx, &rule, rule.OrgID, provenance) return service.provenanceStore.SetProvenance(ctx, &rule, rule.OrgID, provenance)
}) })
if err != nil { if err != nil {

View File

@ -15,26 +15,29 @@ import (
func TestAlertRuleService(t *testing.T) { func TestAlertRuleService(t *testing.T) {
ruleService := createAlertRuleService(t) ruleService := createAlertRuleService(t)
t.Run("alert rule creation should return the created id", func(t *testing.T) { t.Run("alert rule creation should return the created id", func(t *testing.T) {
var orgID int64 = 1 var orgID int64 = 1
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", orgID), models.ProvenanceNone) rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", orgID), models.ProvenanceNone, 0)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, 0, rule.ID, "expected to get the created id and not the zero value") require.NotEqual(t, 0, rule.ID, "expected to get the created id and not the zero value")
}) })
t.Run("alert rule creation should set the right provenance", func(t *testing.T) { t.Run("alert rule creation should set the right provenance", func(t *testing.T) {
var orgID int64 = 1 var orgID int64 = 1
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#2", orgID), models.ProvenanceAPI) rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#2", orgID), models.ProvenanceAPI, 0)
require.NoError(t, err) require.NoError(t, err)
_, provenance, err := ruleService.GetAlertRule(context.Background(), orgID, rule.UID) _, provenance, err := ruleService.GetAlertRule(context.Background(), orgID, rule.UID)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, models.ProvenanceAPI, provenance) require.Equal(t, models.ProvenanceAPI, provenance)
}) })
t.Run("alert rule group should be updated correctly", func(t *testing.T) { t.Run("alert rule group should be updated correctly", func(t *testing.T) {
var orgID int64 = 1 var orgID int64 = 1
rule := dummyRule("test#3", orgID) rule := dummyRule("test#3", orgID)
rule.RuleGroup = "a" rule.RuleGroup = "a"
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone) rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(60), rule.IntervalSeconds) require.Equal(t, int64(60), rule.IntervalSeconds)
@ -46,11 +49,12 @@ func TestAlertRuleService(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, interval, rule.IntervalSeconds) require.Equal(t, interval, rule.IntervalSeconds)
}) })
t.Run("alert rule should get interval from existing rule group", func(t *testing.T) { t.Run("alert rule should get interval from existing rule group", func(t *testing.T) {
var orgID int64 = 1 var orgID int64 = 1
rule := dummyRule("test#4", orgID) rule := dummyRule("test#4", orgID)
rule.RuleGroup = "b" rule.RuleGroup = "b"
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone) rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err) require.NoError(t, err)
var interval int64 = 120 var interval int64 = 120
@ -59,10 +63,11 @@ func TestAlertRuleService(t *testing.T) {
rule = dummyRule("test#4-1", orgID) rule = dummyRule("test#4-1", orgID)
rule.RuleGroup = "b" rule.RuleGroup = "b"
rule, err = ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone) rule, err = ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, interval, rule.IntervalSeconds) require.Equal(t, interval, rule.IntervalSeconds)
}) })
t.Run("updating a rule group should bump the version number", func(t *testing.T) { t.Run("updating a rule group should bump the version number", func(t *testing.T) {
const ( const (
orgID = 123 orgID = 123
@ -75,7 +80,7 @@ func TestAlertRuleService(t *testing.T) {
rule.UID = ruleUID rule.UID = ruleUID
rule.RuleGroup = ruleGroup rule.RuleGroup = ruleGroup
rule.NamespaceUID = namespaceUID rule.NamespaceUID = namespaceUID
_, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone) _, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err) require.NoError(t, err)
rule, _, err = ruleService.GetAlertRule(context.Background(), orgID, ruleUID) rule, _, err = ruleService.GetAlertRule(context.Background(), orgID, ruleUID)
@ -91,6 +96,7 @@ func TestAlertRuleService(t *testing.T) {
require.Equal(t, int64(2), rule.Version) require.Equal(t, int64(2), rule.Version)
require.Equal(t, newInterval, rule.IntervalSeconds) require.Equal(t, newInterval, rule.IntervalSeconds)
}) })
t.Run("alert rule provenace should be correctly checked", func(t *testing.T) { t.Run("alert rule provenace should be correctly checked", func(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -139,7 +145,7 @@ func TestAlertRuleService(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
var orgID int64 = 1 var orgID int64 = 1
rule := dummyRule(t.Name(), orgID) rule := dummyRule(t.Name(), orgID)
rule, err := ruleService.CreateAlertRule(context.Background(), rule, test.from) rule, err := ruleService.CreateAlertRule(context.Background(), rule, test.from, 0)
require.NoError(t, err) require.NoError(t, err)
_, err = ruleService.UpdateAlertRule(context.Background(), rule, test.to) _, err = ruleService.UpdateAlertRule(context.Background(), rule, test.to)
@ -151,6 +157,17 @@ func TestAlertRuleService(t *testing.T) {
}) })
} }
}) })
t.Run("quota met causes create to be rejected", func(t *testing.T) {
ruleService := createAlertRuleService(t)
checker := &MockQuotaChecker{}
checker.EXPECT().LimitExceeded()
ruleService.quotas = checker
_, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", 1), models.ProvenanceNone, 0)
require.ErrorIs(t, err, models.ErrQuotaReached)
})
} }
func createAlertRuleService(t *testing.T) AlertRuleService { func createAlertRuleService(t *testing.T) AlertRuleService {
@ -160,9 +177,12 @@ func createAlertRuleService(t *testing.T) AlertRuleService {
SQLStore: sqlStore, SQLStore: sqlStore,
BaseInterval: time.Second * 10, BaseInterval: time.Second * 10,
} }
quotas := MockQuotaChecker{}
quotas.EXPECT().LimitOK()
return AlertRuleService{ return AlertRuleService{
ruleStore: store, ruleStore: store,
provenanceStore: store, provenanceStore: store,
quotas: &quotas,
xact: sqlStore, xact: sqlStore,
log: log.New("testing"), log: log.New("testing"),
baseIntervalSeconds: 10, baseIntervalSeconds: 10,

View File

@ -5,6 +5,7 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/store" "github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/services/quota"
) )
// AMStore is a store of Alertmanager configurations. // AMStore is a store of Alertmanager configurations.
@ -37,3 +38,9 @@ type RuleStore interface {
UpdateAlertRules(ctx context.Context, rule []store.UpdateRule) error UpdateAlertRules(ctx context.Context, rule []store.UpdateRule) error
DeleteAlertRulesByUID(ctx context.Context, orgID int64, ruleUID ...string) error DeleteAlertRulesByUID(ctx context.Context, orgID int64, ruleUID ...string) error
} }
// QuotaChecker represents the ability to evaluate whether quotas are met.
//go:generate mockery --name QuotaChecker --structname MockQuotaChecker --inpackage --filename quota_checker_mock.go --with-expecter
type QuotaChecker interface {
CheckQuotaReached(ctx context.Context, target string, scopeParams *quota.ScopeParameters) (bool, error)
}

View File

@ -0,0 +1,81 @@
// Code generated by mockery v2.12.0. DO NOT EDIT.
package provisioning
import (
context "context"
quota "github.com/grafana/grafana/pkg/services/quota"
mock "github.com/stretchr/testify/mock"
testing "testing"
)
// MockQuotaChecker is an autogenerated mock type for the QuotaChecker type
type MockQuotaChecker struct {
mock.Mock
}
type MockQuotaChecker_Expecter struct {
mock *mock.Mock
}
func (_m *MockQuotaChecker) EXPECT() *MockQuotaChecker_Expecter {
return &MockQuotaChecker_Expecter{mock: &_m.Mock}
}
// CheckQuotaReached provides a mock function with given fields: ctx, target, scopeParams
func (_m *MockQuotaChecker) CheckQuotaReached(ctx context.Context, target string, scopeParams *quota.ScopeParameters) (bool, error) {
ret := _m.Called(ctx, target, scopeParams)
var r0 bool
if rf, ok := ret.Get(0).(func(context.Context, string, *quota.ScopeParameters) bool); ok {
r0 = rf(ctx, target, scopeParams)
} else {
r0 = ret.Get(0).(bool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, *quota.ScopeParameters) error); ok {
r1 = rf(ctx, target, scopeParams)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQuotaChecker_CheckQuotaReached_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQuotaReached'
type MockQuotaChecker_CheckQuotaReached_Call struct {
*mock.Call
}
// CheckQuotaReached is a helper method to define mock.On call
// - ctx context.Context
// - target string
// - scopeParams *quota.ScopeParameters
func (_e *MockQuotaChecker_Expecter) CheckQuotaReached(ctx interface{}, target interface{}, scopeParams interface{}) *MockQuotaChecker_CheckQuotaReached_Call {
return &MockQuotaChecker_CheckQuotaReached_Call{Call: _e.mock.On("CheckQuotaReached", ctx, target, scopeParams)}
}
func (_c *MockQuotaChecker_CheckQuotaReached_Call) Run(run func(ctx context.Context, target string, scopeParams *quota.ScopeParameters)) *MockQuotaChecker_CheckQuotaReached_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*quota.ScopeParameters))
})
return _c
}
func (_c *MockQuotaChecker_CheckQuotaReached_Call) Return(_a0 bool, _a1 error) *MockQuotaChecker_CheckQuotaReached_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// NewMockQuotaChecker creates a new instance of MockQuotaChecker. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockQuotaChecker(t testing.TB) *MockQuotaChecker {
mock := &MockQuotaChecker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -170,3 +170,13 @@ func (m *MockProvisioningStore_Expecter) SaveSucceeds() *MockProvisioningStore_E
m.DeleteProvenance(mock.Anything, mock.Anything, mock.Anything).Return(nil) m.DeleteProvenance(mock.Anything, mock.Anything, mock.Anything).Return(nil)
return m return m
} }
func (m *MockQuotaChecker_Expecter) LimitOK() *MockQuotaChecker_Expecter {
m.CheckQuotaReached(mock.Anything, mock.Anything, mock.Anything).Return(false, nil)
return m
}
func (m *MockQuotaChecker_Expecter) LimitExceeded() *MockQuotaChecker_Expecter {
m.CheckQuotaReached(mock.Anything, mock.Anything, mock.Anything).Return(true, nil)
return m
}