From 77a4869fcadf13827d76d5767d4de74812d6dd6d Mon Sep 17 00:00:00 2001 From: Kristin Laemmert Date: Mon, 8 Jul 2024 10:00:13 -0400 Subject: [PATCH] accesscontrol service read replica (#89963) * accesscontrol service read replica * now using the ReplDB interface * ReadReplica for GetUser --- pkg/api/folder_bench_test.go | 20 +++++++++---------- .../commands/conflict_user_command.go | 14 +++++++++---- pkg/server/wire.go | 2 ++ pkg/services/accesscontrol/acimpl/service.go | 4 ++-- .../acimpl/service_bench_test.go | 2 +- .../accesscontrol/acimpl/service_test.go | 4 ++-- .../accesscontrol/database/database.go | 20 +++++++++---------- .../accesscontrol/database/database_test.go | 4 ++-- .../database/externalservices.go | 4 ++-- .../database/externalservices_test.go | 17 ++++++++-------- .../accesscontrol/migrator/migrator.go | 6 +++--- .../migrator/migrator_bench_test.go | 2 +- .../accesscontrol/migrator/migrator_test.go | 2 +- pkg/services/sqlstore/replstore.go | 12 +++++++++++ 14 files changed, 67 insertions(+), 46 deletions(-) diff --git a/pkg/api/folder_bench_test.go b/pkg/api/folder_bench_test.go index d1ae9f030ba..9a31a50fc23 100644 --- a/pkg/api/folder_bench_test.go +++ b/pkg/api/folder_bench_test.go @@ -69,7 +69,7 @@ const ( ) type benchScenario struct { - db db.DB + db db.ReplDB // signedInUser is the user that is signed in to the server cfg *setting.Cfg signedInUser *user.SignedInUser @@ -202,7 +202,7 @@ func BenchmarkFolderListAndSearch(b *testing.B) { func setupDB(b testing.TB) benchScenario { b.Helper() - db, cfg := sqlstore.InitTestDB(b) + db, cfg := sqlstore.InitTestReplDB(b) IDs := map[int64]struct{}{} opts := sqlstore.NativeSettingsForDialect(db.GetDialect()) @@ -451,26 +451,26 @@ func setupServer(b testing.TB, sc benchScenario, features featuremgmt.FeatureTog quotaSrv := quotatest.New(false, nil) - dashStore, err := database.ProvideDashboardStore(sc.db, sc.cfg, features, tagimpl.ProvideService(sc.db), quotaSrv) + dashStore, err := database.ProvideDashboardStore(sc.db.DB(), sc.cfg, features, tagimpl.ProvideService(sc.db.DB()), quotaSrv) require.NoError(b, err) - folderStore := folderimpl.ProvideDashboardFolderStore(sc.db) + folderStore := folderimpl.ProvideDashboardFolderStore(sc.db.DB()) ac := acimpl.ProvideAccessControl(featuremgmt.WithFeatures(), zanzana.NewNoopClient()) - folderServiceWithFlagOn := folderimpl.ProvideService(ac, bus.ProvideBus(tracing.InitializeTracerForTest()), dashStore, folderStore, sc.db, features, supportbundlestest.NewFakeBundleService(), nil) + folderServiceWithFlagOn := folderimpl.ProvideService(ac, bus.ProvideBus(tracing.InitializeTracerForTest()), dashStore, folderStore, sc.db.DB(), features, supportbundlestest.NewFakeBundleService(), nil) cfg := setting.NewCfg() actionSets := resourcepermissions.NewActionSetService() acSvc := acimpl.ProvideOSSService( sc.cfg, acdb.ProvideService(sc.db), actionSets, localcache.ProvideService(), - features, tracing.InitializeTracerForTest(), zanzana.NewNoopClient(), sc.db, + features, tracing.InitializeTracerForTest(), zanzana.NewNoopClient(), sc.db.DB(), ) folderPermissions, err := ossaccesscontrol.ProvideFolderPermissions( - cfg, features, routing.NewRouteRegister(), sc.db, ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) + cfg, features, routing.NewRouteRegister(), sc.db.DB(), ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) require.NoError(b, err) dashboardPermissions, err := ossaccesscontrol.ProvideDashboardPermissions( - cfg, features, routing.NewRouteRegister(), sc.db, ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) + cfg, features, routing.NewRouteRegister(), sc.db.DB(), ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) require.NoError(b, err) dashboardSvc, err := dashboardservice.ProvideDashboardServiceImpl( @@ -486,10 +486,10 @@ func setupServer(b testing.TB, sc benchScenario, features featuremgmt.FeatureTog hs := &HTTPServer{ CacheService: localcache.New(5*time.Minute, 10*time.Minute), Cfg: sc.cfg, - SQLStore: sc.db, + SQLStore: sc.db.DB(), Features: features, QuotaService: quotaSrv, - SearchService: search.ProvideService(sc.cfg, sc.db, starSvc, dashboardSvc), + SearchService: search.ProvideService(sc.cfg, sc.db.DB(), starSvc, dashboardSvc), folderService: folderServiceWithFlagOn, DashboardService: dashboardSvc, } diff --git a/pkg/cmd/grafana-cli/commands/conflict_user_command.go b/pkg/cmd/grafana-cli/commands/conflict_user_command.go index 0aad8adfec7..9ad5fe87611 100644 --- a/pkg/cmd/grafana-cli/commands/conflict_user_command.go +++ b/pkg/cmd/grafana-cli/commands/conflict_user_command.go @@ -70,7 +70,7 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx if err != nil { return nil, fmt.Errorf("%v: %w", "failed to load configuration", err) } - s, err := getSqlStore(cfg, tracer, features) + s, replstore, err := getSqlStore(cfg, tracer, features) if err != nil { return nil, fmt.Errorf("%v: %w", "failed to get to sql", err) } @@ -90,7 +90,7 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx if err != nil { return nil, fmt.Errorf("%v: %w", "failed to initialize tracer service", err) } - acService, err := acimpl.ProvideService(cfg, s, routing, nil, nil, nil, features, tracer, zanzana.NewNoopClient()) + acService, err := acimpl.ProvideService(cfg, replstore, routing, nil, nil, nil, features, tracer, zanzana.NewNoopClient()) if err != nil { return nil, fmt.Errorf("%v: %w", "failed to get access control", err) } @@ -99,9 +99,15 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx return &resolver, nil } -func getSqlStore(cfg *setting.Cfg, tracer tracing.Tracer, features featuremgmt.FeatureToggles) (*sqlstore.SQLStore, error) { +func getSqlStore(cfg *setting.Cfg, tracer tracing.Tracer, features featuremgmt.FeatureToggles) (*sqlstore.SQLStore, *sqlstore.ReplStore, error) { bus := bus.ProvideBus(tracer) - return sqlstore.ProvideService(cfg, features, &migrations.OSSMigrations{}, bus, tracer) + ss, err := sqlstore.ProvideService(cfg, features, &migrations.OSSMigrations{}, bus, tracer) + if err != nil { + return nil, nil, err + } + + replStore, err := sqlstore.ProvideServiceWithReadReplica(ss, cfg, features, &migrations.OSSMigrations{}, bus, tracer) + return ss, replStore, err } func runListConflictUsers() func(context *cli.Context) error { diff --git a/pkg/server/wire.go b/pkg/server/wire.go index a8bf87483eb..430180993c3 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -396,6 +396,7 @@ var wireSet = wire.NewSet( wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), + wire.Bind(new(db.ReplDB), new(*sqlstore.ReplStore)), prefimpl.ProvideService, oauthtoken.ProvideService, wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), @@ -412,6 +413,7 @@ var wireCLISet = wire.NewSet( wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), + wire.Bind(new(db.ReplDB), new(*sqlstore.ReplStore)), prefimpl.ProvideService, oauthtoken.ProvideService, wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), diff --git a/pkg/services/accesscontrol/acimpl/service.go b/pkg/services/accesscontrol/acimpl/service.go index 69f392bb40f..ec22b8a81f2 100644 --- a/pkg/services/accesscontrol/acimpl/service.go +++ b/pkg/services/accesscontrol/acimpl/service.go @@ -48,11 +48,11 @@ var SharedWithMeFolderPermission = accesscontrol.Permission{ var OSSRolesPrefixes = []string{accesscontrol.ManagedRolePrefix, accesscontrol.ExternalServiceRolePrefix} func ProvideService( - cfg *setting.Cfg, db db.DB, routeRegister routing.RouteRegister, cache *localcache.CacheService, + cfg *setting.Cfg, db db.ReplDB, routeRegister routing.RouteRegister, cache *localcache.CacheService, accessControl accesscontrol.AccessControl, actionResolver accesscontrol.ActionResolver, features featuremgmt.FeatureToggles, tracer tracing.Tracer, zclient zanzana.Client, ) (*Service, error) { - service := ProvideOSSService(cfg, database.ProvideService(db), actionResolver, cache, features, tracer, zclient, db) + service := ProvideOSSService(cfg, database.ProvideService(db), actionResolver, cache, features, tracer, zclient, db.DB()) api.NewAccessControlAPI(routeRegister, accessControl, service, features).RegisterAPIEndpoints() if err := accesscontrol.DeclareFixedRoles(service, cfg); err != nil { diff --git a/pkg/services/accesscontrol/acimpl/service_bench_test.go b/pkg/services/accesscontrol/acimpl/service_bench_test.go index c3b1a103b84..78e153936c7 100644 --- a/pkg/services/accesscontrol/acimpl/service_bench_test.go +++ b/pkg/services/accesscontrol/acimpl/service_bench_test.go @@ -25,7 +25,7 @@ import ( // - each managed role will have 3 permissions {"resources:action2", "resources:id:x"} where x belongs to [1, 3] func setupBenchEnv(b *testing.B, usersCount, resourceCount int) (accesscontrol.Service, *user.SignedInUser) { now := time.Now() - sqlStore := db.InitTestDB(b) + sqlStore := db.InitTestReplDB(b) store := database.ProvideService(sqlStore) acService := &Service{ cfg: setting.NewCfg(), diff --git a/pkg/services/accesscontrol/acimpl/service_test.go b/pkg/services/accesscontrol/acimpl/service_test.go index 7e7a6cdd43c..74382c44a28 100644 --- a/pkg/services/accesscontrol/acimpl/service_test.go +++ b/pkg/services/accesscontrol/acimpl/service_test.go @@ -41,8 +41,8 @@ func setupTestEnv(t testing.TB) *Service { log: log.New("accesscontrol"), registrations: accesscontrol.RegistrationList{}, roles: accesscontrol.BuildBasicRoleDefinitions(), - store: database.ProvideService(db.InitTestDB(t)), tracer: tracing.InitializeTracerForTest(), + store: database.ProvideService(db.InitTestReplDB(t)), } require.NoError(t, ac.RegisterFixedRoles(context.Background())) return ac @@ -65,7 +65,7 @@ func TestUsageMetrics(t *testing.T) { s := ProvideOSSService( cfg, - database.ProvideService(db.InitTestDB(t)), + database.ProvideService(db.InitTestReplDB(t)), &resourcepermissions.FakeActionSetSvc{}, localcache.ProvideService(), featuremgmt.WithFeatures(), diff --git a/pkg/services/accesscontrol/database/database.go b/pkg/services/accesscontrol/database/database.go index 43e8564505f..4b3c6479de6 100644 --- a/pkg/services/accesscontrol/database/database.go +++ b/pkg/services/accesscontrol/database/database.go @@ -36,17 +36,17 @@ const ( WHERE br.role = ?` ) -func ProvideService(sql db.DB) *AccessControlStore { +func ProvideService(sql db.ReplDB) *AccessControlStore { return &AccessControlStore{sql} } type AccessControlStore struct { - sql db.DB + sql db.ReplDB } func (s *AccessControlStore) GetUserPermissions(ctx context.Context, query accesscontrol.GetUserPermissionsQuery) ([]accesscontrol.Permission, error) { result := make([]accesscontrol.Permission, 0) - err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error { if query.UserID == 0 && len(query.TeamIDs) == 0 && len(query.Roles) == 0 { // no permission to fetch return nil @@ -104,7 +104,7 @@ func (s *AccessControlStore) GetTeamsPermissions(ctx context.Context, query acce orgID := query.OrgID rolePrefixes := query.RolePrefixes result := make([]teamPermission, 0) - err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error { if len(teams) == 0 { // no permission to fetch return nil @@ -172,7 +172,7 @@ func (s *AccessControlStore) SearchUsersPermissions(ctx context.Context, orgID i } } - if err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + if err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error { roleNameFilterJoin := "" if len(options.RolePrefixes) > 0 { roleNameFilterJoin = "INNER JOIN role AS r ON up.role_id = r.id" @@ -198,7 +198,7 @@ func (s *AccessControlStore) SearchUsersPermissions(ctx context.Context, orgID i params = append(params, userID) } - grafanaAdmin := fmt.Sprintf(grafanaAdminAssignsSQL, s.sql.Quote("user")) + grafanaAdmin := fmt.Sprintf(grafanaAdminAssignsSQL, s.sql.ReadReplica().Quote("user")) params = append(params, accesscontrol.RoleGrafanaAdmin) if options.NamespacedID != "" { grafanaAdmin += " AND sa.user_id = ?" @@ -284,11 +284,11 @@ func (s *AccessControlStore) GetUsersBasicRoles(ctx context.Context, userFilter IsAdmin bool `xorm:"is_admin"` } dbRoles := make([]UserOrgRole, 0) - if err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + if err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error { // Find roles q := ` SELECT u.id, ou.role, u.is_admin - FROM ` + s.sql.GetDialect().Quote("user") + ` AS u + FROM ` + s.sql.ReadReplica().GetDialect().Quote("user") + ` AS u LEFT JOIN org_user AS ou ON u.id = ou.user_id WHERE (u.is_admin OR ou.org_id = ?) ` @@ -318,7 +318,7 @@ func (s *AccessControlStore) GetUsersBasicRoles(ctx context.Context, userFilter } func (s *AccessControlStore) DeleteUserPermissions(ctx context.Context, orgID, userID int64) error { - err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { roleDeleteQuery := "DELETE FROM user_role WHERE user_id = ?" roleDeleteParams := []any{roleDeleteQuery, userID} if orgID != accesscontrol.GlobalOrgID { @@ -383,7 +383,7 @@ func (s *AccessControlStore) DeleteUserPermissions(ctx context.Context, orgID, u } func (s *AccessControlStore) DeleteTeamPermissions(ctx context.Context, orgID, teamID int64) error { - err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { roleDeleteQuery := "DELETE FROM team_role WHERE team_id = ? AND org_id = ?" roleDeleteParams := []any{roleDeleteQuery, teamID, orgID} diff --git a/pkg/services/accesscontrol/database/database_test.go b/pkg/services/accesscontrol/database/database_test.go index 8f8b237cd35..39d25fa3029 100644 --- a/pkg/services/accesscontrol/database/database_test.go +++ b/pkg/services/accesscontrol/database/database_test.go @@ -470,8 +470,8 @@ func createUsersAndTeams(t *testing.T, store db.DB, svcs helperServices, orgID i return res } -func setupTestEnv(t testing.TB) (*database.AccessControlStore, rs.Store, user.Service, team.Service, org.Service, *sqlstore.SQLStore) { - sql, cfg := db.InitTestDBWithCfg(t) +func setupTestEnv(t testing.TB) (*database.AccessControlStore, rs.Store, user.Service, team.Service, org.Service, *sqlstore.ReplStore) { + sql, cfg := db.InitTestReplDBWithCfg(t) cfg.AutoAssignOrg = true cfg.AutoAssignOrgRole = "Viewer" cfg.AutoAssignOrgId = 1 diff --git a/pkg/services/accesscontrol/database/externalservices.go b/pkg/services/accesscontrol/database/externalservices.go index ea69ff79fc5..622db51684e 100644 --- a/pkg/services/accesscontrol/database/externalservices.go +++ b/pkg/services/accesscontrol/database/externalservices.go @@ -18,7 +18,7 @@ func extServiceRoleName(externalServiceID string) string { func (s *AccessControlStore) DeleteExternalServiceRole(ctx context.Context, externalServiceID string) error { uid := accesscontrol.PrefixedRoleUID(extServiceRoleName(externalServiceID)) - return s.sql.WithDbSession(ctx, func(sess *db.Session) error { + return s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { stored, errGet := getRoleByUID(ctx, sess, uid) if errGet != nil { // Role not found, nothing to do @@ -55,7 +55,7 @@ func (s *AccessControlStore) SaveExternalServiceRole(ctx context.Context, cmd ac role := genExternalServiceRole(cmd) assignment := genExternalServiceAssignment(cmd) - return s.sql.WithDbSession(ctx, func(sess *db.Session) error { + return s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { // Create or update the role existingRole, errSaveRole := s.saveRole(ctx, sess, &role) if errSaveRole != nil { diff --git a/pkg/services/accesscontrol/database/externalservices_test.go b/pkg/services/accesscontrol/database/externalservices_test.go index 21bbbda171a..0df0860d21b 100644 --- a/pkg/services/accesscontrol/database/externalservices_test.go +++ b/pkg/services/accesscontrol/database/externalservices_test.go @@ -7,9 +7,10 @@ import ( "errors" "testing" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/services/accesscontrol" - "github.com/stretchr/testify/require" ) func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) { @@ -114,7 +115,7 @@ func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() s := &AccessControlStore{ - sql: db.InitTestDB(t), + sql: db.InitTestReplDB(t), } for i := range tt.runs { @@ -125,7 +126,7 @@ func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) { } require.NoError(t, err) - errDBSession := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + errDBSession := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.runs[i].cmd.ExternalServiceID))) require.NoError(t, err) require.NotNil(t, storedRole) @@ -187,13 +188,13 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() s := &AccessControlStore{ - sql: db.InitTestDB(t), + sql: db.InitTestReplDB(t), } if tt.init != nil { tt.init(t, ctx, s) } roleID := int64(-1) - err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { + err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { role, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id))) if err != nil && !errors.Is(err, accesscontrol.ErrRoleNotFound) { return err @@ -217,7 +218,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) { } // Assignments should be deleted - _ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { + _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { var assignment accesscontrol.UserRole count, err := sess.Where("role_id = ?", roleID).Count(&assignment) require.NoError(t, err) @@ -226,7 +227,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) { }) // Permissions should be deleted - _ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { + _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { var permission accesscontrol.Permission count, err := sess.Where("role_id = ?", roleID).Count(&permission) require.NoError(t, err) @@ -235,7 +236,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) { }) // Role should be deleted - _ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { + _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error { storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id))) require.ErrorIs(t, err, accesscontrol.ErrRoleNotFound) require.Nil(t, storedRole) diff --git a/pkg/services/accesscontrol/migrator/migrator.go b/pkg/services/accesscontrol/migrator/migrator.go index 8ecc20ad553..33317b4b5be 100644 --- a/pkg/services/accesscontrol/migrator/migrator.go +++ b/pkg/services/accesscontrol/migrator/migrator.go @@ -19,14 +19,14 @@ const ( maxLen = 40 ) -func MigrateScopeSplit(db db.DB, log log.Logger) error { +func MigrateScopeSplit(db db.ReplDB, log log.Logger) error { t := time.Now() ctx := context.Background() cnt := 0 // Search for the permissions to update var permissions []ac.Permission - if errFind := db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error { + if errFind := db.DB().WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error { return sess.SQL("SELECT * FROM permission WHERE NOT scope = '' AND identifier = ''").Find(&permissions) }); errFind != nil { log.Error("Could not search for permissions to update", "migration", "scopeSplit", "error", errFind) @@ -76,7 +76,7 @@ func MigrateScopeSplit(db db.DB, log log.Logger) error { delQuery = delQuery[:len(delQuery)-1] + ")" // Batch update the permissions - if errBatchUpdate := db.GetSqlxSession().WithTransaction(ctx, func(tx *session.SessionTx) error { + if errBatchUpdate := db.DB().GetSqlxSession().WithTransaction(ctx, func(tx *session.SessionTx) error { if _, errDel := tx.Exec(ctx, delQuery, delArgs...); errDel != nil { log.Error("Error deleting permissions", "migration", "scopeSplit", "error", errDel) return errDel diff --git a/pkg/services/accesscontrol/migrator/migrator_bench_test.go b/pkg/services/accesscontrol/migrator/migrator_bench_test.go index 5257a2fb2dd..7246b0812e3 100644 --- a/pkg/services/accesscontrol/migrator/migrator_bench_test.go +++ b/pkg/services/accesscontrol/migrator/migrator_bench_test.go @@ -10,7 +10,7 @@ import ( ) func benchScopeSplitConcurrent(b *testing.B, count int) { - store := db.InitTestDB(b) + store := db.InitTestReplDB(b) // Populate permissions require.NoError(b, batchInsertPermissions(count, store), "could not insert permissions") logger := log.New("migrator.test") diff --git a/pkg/services/accesscontrol/migrator/migrator_test.go b/pkg/services/accesscontrol/migrator/migrator_test.go index bfa35f17efc..0c98aa24a26 100644 --- a/pkg/services/accesscontrol/migrator/migrator_test.go +++ b/pkg/services/accesscontrol/migrator/migrator_test.go @@ -46,7 +46,7 @@ func batchInsertPermissions(cnt int, sqlStore db.DB) error { // TestIntegrationMigrateScopeSplit tests the scope split migration // also tests the scope split truncation logic func TestIntegrationMigrateScopeSplitTruncation(t *testing.T) { - sqlStore := db.InitTestDB(t) + sqlStore := db.InitTestReplDB(t) logger := log.New("accesscontrol.migrator.test") batchSize = 20 diff --git a/pkg/services/sqlstore/replstore.go b/pkg/services/sqlstore/replstore.go index 050341a3606..5cc1bda9ccf 100644 --- a/pkg/services/sqlstore/replstore.go +++ b/pkg/services/sqlstore/replstore.go @@ -192,3 +192,15 @@ func InitTestReplDB(t sqlutil.ITestDB, opts ...InitTestDBOpt) (*ReplStore, *sett } return &ReplStore{ss, ss}, cfg } + +// InitTestReplDBWithMigration initializes the test DB given custom migrations. +func InitTestReplDBWithMigration(t sqlutil.ITestDB, migration registry.DatabaseMigrator, opts ...InitTestDBOpt) *ReplStore { + t.Helper() + features := getFeaturesForTesting(opts...) + cfg := getCfgForTesting(opts...) + ss, err := initTestDB(t, cfg, features, migration, opts...) + if err != nil { + t.Fatalf("failed to initialize sql store: %s", err) + } + return &ReplStore{ss, ss} +}