Chore: Remove bus from quota (#45143)

* Remove bus from quota

* workaround

* Change ExpectedOrg ot *models.Org
This commit is contained in:
idafurjes
2022-02-10 12:42:06 +01:00
committed by GitHub
parent f2795981c6
commit 923b62ecab
8 changed files with 52 additions and 116 deletions

View File

@ -1,12 +1,11 @@
package middleware
import (
"context"
"fmt"
"testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -40,11 +39,9 @@ func TestMiddlewareAuth(t *testing.T) {
middlewareScenario(t, "ReqSignIn true and NoAnonynmous true", func(
t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sqlStore := mockstore.NewSQLStoreMock()
sqlStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"}
sc.sqlStore = sqlStore
sc.m.Get("/api/secure", ReqSignedInNoAnonymous, sc.defaultHandler)
sc.fakeReq("GET", "/api/secure").exec()
@ -53,11 +50,9 @@ func TestMiddlewareAuth(t *testing.T) {
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sqlStore := mockstore.NewSQLStoreMock()
sqlStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"}
sc.sqlStore = sqlStore
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?forceLogin=true").exec()
@ -82,11 +77,6 @@ func TestMiddlewareAuth(t *testing.T) {
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?orgId=2").exec()

View File

@ -333,7 +333,6 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
org, err := sc.sqlStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1)
require.NoError(t, err)
sc.fakeReq("GET", "/").exec()
assert.Equal(t, int64(0), sc.context.UserId)
@ -674,7 +673,6 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
}
sc := &scenarioContext{t: t, cfg: cfg}
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
exists, err := fs.Exists(viewsPath)
@ -726,8 +724,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
remoteCacheSvc, err := remotecache.ProvideService(cfg, sqlStore)
require.NoError(t, err)
remoteCacheSvc := remotecache.NewFakeStore(t)
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
authJWTSvc := models.NewFakeJWTService()

View File

@ -9,7 +9,7 @@ import (
)
// Quota returns a function that returns a function used to call quotaservice based on target name
func Quota(quotaService *quota.QuotaService) func(string) web.Handler {
func Quota(quotaService quota.Service) func(string) web.Handler {
if quotaService == nil {
panic("quotaService is nil")
}

View File

@ -4,10 +4,7 @@ import (
"context"
"testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/assert"
@ -16,16 +13,7 @@ import (
func TestMiddlewareQuota(t *testing.T) {
t.Run("With user not logged in", func(t *testing.T) {
middlewareScenario(t, "and global quota not reached", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("globalQuota", func(_ context.Context, query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: 4,
}
return nil
})
quotaHandler := getQuotaHandler(sc, "user")
quotaHandler := getQuotaHandler(false, "user")
sc.m.Get("/user", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
@ -33,16 +21,7 @@ func TestMiddlewareQuota(t *testing.T) {
}, configure)
middlewareScenario(t, "and global quota reached", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("globalQuota", func(_ context.Context, query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: 4,
}
return nil
})
quotaHandler := getQuotaHandler(sc, "user")
quotaHandler := getQuotaHandler(true, "user")
sc.m.Get("/user", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -53,16 +32,7 @@ func TestMiddlewareQuota(t *testing.T) {
})
middlewareScenario(t, "and global session quota not reached", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("globalQuota", func(_ context.Context, query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: 4,
}
return nil
})
quotaHandler := getQuotaHandler(sc, "session")
quotaHandler := getQuotaHandler(false, "session")
sc.m.Get("/user", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
assert.Equal(t, 200, sc.resp.Code)
@ -73,7 +43,7 @@ func TestMiddlewareQuota(t *testing.T) {
})
middlewareScenario(t, "and global session quota reached", func(t *testing.T, sc *scenarioContext) {
quotaHandler := getQuotaHandler(sc, "session")
quotaHandler := getQuotaHandler(true, "session")
sc.m.Get("/user", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -86,53 +56,20 @@ func TestMiddlewareQuota(t *testing.T) {
t.Run("with user logged in", func(t *testing.T) {
const quotaUsed = 4
setUp := func(sc *scenarioContext) {
sc.withTokenSessionCookie("token")
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
return nil
})
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: 12,
UnhashedToken: "",
}, nil
}
bus.AddHandler("globalQuota", func(_ context.Context, query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: quotaUsed,
}
return nil
})
bus.AddHandler("userQuota", func(_ context.Context, query *models.GetUserQuotaByTargetQuery) error {
query.Result = &models.UserQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: quotaUsed,
}
return nil
})
bus.AddHandler("orgQuota", func(_ context.Context, query *models.GetOrgQuotaByTargetQuery) error {
query.Result = &models.OrgQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: quotaUsed,
}
return nil
})
}
middlewareScenario(t, "global datasource quota reached", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "data_source")
quotaHandler := getQuotaHandler(true, "data_source")
sc.m.Get("/ds", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/ds").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -145,7 +82,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "user Org quota not reached", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "org")
quotaHandler := getQuotaHandler(false, "org")
sc.m.Get("/org", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
@ -159,7 +96,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "user Org quota reached", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "org")
quotaHandler := getQuotaHandler(true, "org")
sc.m.Get("/org", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -172,7 +109,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org dashboard quota not reached", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "dashboard")
quotaHandler := getQuotaHandler(false, "dashboard")
sc.m.Get("/dashboard", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
assert.Equal(t, 200, sc.resp.Code)
@ -185,7 +122,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org dashboard quota reached", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "dashboard")
quotaHandler := getQuotaHandler(true, "dashboard")
sc.m.Get("/dashboard", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -198,7 +135,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org dashboard quota reached, but quotas disabled", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "dashboard")
quotaHandler := getQuotaHandler(false, "dashboard")
sc.m.Get("/dashboard", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
assert.Equal(t, 200, sc.resp.Code)
@ -212,7 +149,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org alert quota reached and unified alerting is enabled", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "alert_rule")
quotaHandler := getQuotaHandler(true, "alert_rule")
sc.m.Get("/alert_rule", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/alert_rule").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -227,7 +164,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org alert quota not reached and unified alerting is enabled", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "alert_rule")
quotaHandler := getQuotaHandler(false, "alert_rule")
sc.m.Get("/alert_rule", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/alert_rule").exec()
assert.Equal(t, 200, sc.resp.Code)
@ -243,7 +180,7 @@ func TestMiddlewareQuota(t *testing.T) {
// this scenario can only happen if the feature was enabled and later disabled
setUp(sc)
quotaHandler := getQuotaHandler(sc, "alert_rule")
quotaHandler := getQuotaHandler(true, "alert_rule")
sc.m.Get("/alert_rule", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/alert_rule").exec()
assert.Equal(t, 403, sc.resp.Code)
@ -256,7 +193,7 @@ func TestMiddlewareQuota(t *testing.T) {
middlewareScenario(t, "org alert quota not reached but ngalert disabled", func(t *testing.T, sc *scenarioContext) {
setUp(sc)
quotaHandler := getQuotaHandler(sc, "alert_rule")
quotaHandler := getQuotaHandler(false, "alert_rule")
sc.m.Get("/alert_rule", quotaHandler, sc.defaultHandler)
sc.fakeReq("GET", "/alert_rule").exec()
assert.Equal(t, 200, sc.resp.Code)
@ -268,13 +205,10 @@ func TestMiddlewareQuota(t *testing.T) {
})
}
func getQuotaHandler(sc *scenarioContext, target string) web.Handler {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
qs := &quota.QuotaService{
AuthTokenService: fakeAuthTokenService,
Cfg: sc.cfg,
func getQuotaHandler(reached bool, target string) web.Handler {
qs := &mockQuotaService{
reached: reached,
}
return Quota(qs)(target)
}
@ -303,3 +237,12 @@ func configure(cfg *setting.Cfg) {
},
}
}
type mockQuotaService struct {
reached bool
err error
}
func (m *mockQuotaService) QuotaReached(c *models.ReqContext, target string) (bool, error) {
return m.reached, m.err
}

View File

@ -33,7 +33,7 @@ type scenarioContext struct {
jwtAuthService *models.FakeJWTService
remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
sqlStore *sqlstore.SQLStore
sqlStore sqlstore.Store
contextHandler *contexthandler.ContextHandler
req *http.Request

View File

@ -55,7 +55,7 @@ type ContextHandler struct {
JWTAuthService models.JWTService
RemoteCache *remotecache.RemoteCache
RenderService rendering.Service
SQLStore *sqlstore.SQLStore
SQLStore sqlstore.Store
tracer tracing.Tracer
// GetTime returns the current time.
// Stubbable by tests.

View File

@ -3,23 +3,29 @@ package quota
import (
"errors"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
)
var ErrInvalidQuotaTarget = errors.New("invalid quota target")
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService) *QuotaService {
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, sqlStore *sqlstore.SQLStore) *QuotaService {
return &QuotaService{
Cfg: cfg,
AuthTokenService: tokenService,
SQLStore: sqlStore,
}
}
type QuotaService struct {
AuthTokenService models.UserTokenService
Cfg *setting.Cfg
SQLStore sqlstore.Store
}
type Service interface {
QuotaReached(c *models.ReqContext, target string) (bool, error)
}
func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool, error) {
@ -38,7 +44,6 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
if err != nil {
return false, err
}
for _, scope := range scopes {
c.Logger.Debug("Checking quota", "target", target, "scope", scope)
@ -63,7 +68,7 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
continue
}
query := models.GetGlobalQuotaByTargetQuery{Target: scope.Target, UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled()}
if err := bus.Dispatch(c.Req.Context(), &query); err != nil {
if err := qs.SQLStore.GetGlobalQuotaByTarget(c.Req.Context(), &query); err != nil {
return true, err
}
if query.Result.Used >= scope.DefaultLimit {
@ -79,7 +84,7 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
Default: scope.DefaultLimit,
UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled(),
}
if err := bus.Dispatch(c.Req.Context(), &query); err != nil {
if err := qs.SQLStore.GetOrgQuotaByTarget(c.Req.Context(), &query); err != nil {
return true, err
}
if query.Result.Limit < 0 {
@ -97,7 +102,7 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
continue
}
query := models.GetUserQuotaByTargetQuery{UserId: c.UserId, Target: scope.Target, Default: scope.DefaultLimit, UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled()}
if err := bus.Dispatch(c.Req.Context(), &query); err != nil {
if err := qs.SQLStore.GetUserQuotaByTarget(c.Req.Context(), &query); err != nil {
return true, err
}
if query.Result.Limit < 0 {
@ -112,7 +117,6 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
}
}
}
return false, nil
}

View File

@ -29,6 +29,7 @@ type SQLStoreMock struct {
ExpectedTeamsByUser []*models.TeamDTO
ExpectedSearchOrgList []*models.OrgDTO
ExpectedDatasources []*models.DataSource
ExpectedOrg *models.Org
ExpectedError error
}
@ -67,11 +68,11 @@ func (m *SQLStoreMock) SearchDashboardSnapshots(query *models.GetDashboardSnapsh
}
func (m *SQLStoreMock) GetOrgByName(name string) (*models.Org, error) {
return nil, m.ExpectedError
return m.ExpectedOrg, m.ExpectedError
}
func (m *SQLStoreMock) CreateOrgWithMember(name string, userID int64) (models.Org, error) {
return models.Org{}, nil
return *m.ExpectedOrg, nil
}
func (m *SQLStoreMock) UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error {