diff --git a/pkg/services/org/orgtest/fake.go b/pkg/services/org/orgtest/fake.go index 6c0c08e579e..5d3abd82e94 100644 --- a/pkg/services/org/orgtest/fake.go +++ b/pkg/services/org/orgtest/fake.go @@ -21,6 +21,7 @@ type FakeOrgService struct { ExpectedSearchOrgUsersResult *org.SearchOrgUsersQueryResult ExpectedOrgListResponse OrgListResponse SearchOrgUsersFn func(context.Context, *org.SearchOrgUsersQuery) (*org.SearchOrgUsersQueryResult, error) + InsertOrgUserFn func(context.Context, *org.OrgUser) (int64, error) } func NewOrgServiceFake() *FakeOrgService { @@ -36,6 +37,9 @@ func (f *FakeOrgService) Insert(ctx context.Context, cmd *org.OrgUser) (int64, e } func (f *FakeOrgService) InsertOrgUser(ctx context.Context, cmd *org.OrgUser) (int64, error) { + if f.InsertOrgUserFn != nil { + return f.InsertOrgUserFn(ctx, cmd) + } return f.ExpectedOrgUserID, f.ExpectedError } diff --git a/pkg/services/user/userimpl/user.go b/pkg/services/user/userimpl/user.go index 9a19eca19e3..b0642e08ad7 100644 --- a/pkg/services/user/userimpl/user.go +++ b/pkg/services/user/userimpl/user.go @@ -32,6 +32,7 @@ type Service struct { cacheService *localcache.CacheService cfg *setting.Cfg tracer tracing.Tracer + db db.DB } func ProvideService( @@ -50,6 +51,7 @@ func ProvideService( teamService: teamService, cacheService: cacheService, tracer: tracer, + db: db, } defaultLimits, err := readQuotaConfig(cfg) @@ -172,35 +174,35 @@ func (s *Service) Create(ctx context.Context, cmd *user.CreateUserCommand) (*use } } - _, err = s.store.Insert(ctx, usr) - if err != nil { - return nil, err - } - - // create org user link - if !cmd.SkipOrgSetup && !usr.IsProvisioned { - orgUser := org.OrgUser{ - OrgID: orgID, - UserID: usr.ID, - Role: org.RoleAdmin, - Created: time.Now(), - Updated: time.Now(), - } - - if s.cfg.AutoAssignOrg && !usr.IsAdmin { - if len(cmd.DefaultOrgRole) > 0 { - orgUser.Role = org.RoleType(cmd.DefaultOrgRole) - } else { - orgUser.Role = org.RoleType(s.cfg.AutoAssignOrgRole) - } - } - _, err = s.orgService.InsertOrgUser(ctx, &orgUser) + err = s.db.InTransaction(ctx, func(ctx context.Context) error { + _, err = s.store.Insert(ctx, usr) if err != nil { - err := s.store.Delete(ctx, usr.ID) - return usr, err + return err } - } - return usr, nil + + // create org user link + if !cmd.SkipOrgSetup && !usr.IsProvisioned { + orgUser := org.OrgUser{ + OrgID: orgID, + UserID: usr.ID, + Role: org.RoleAdmin, + Created: time.Now(), + Updated: time.Now(), + } + + if s.cfg.AutoAssignOrg && !usr.IsAdmin { + if len(cmd.DefaultOrgRole) > 0 { + orgUser.Role = org.RoleType(cmd.DefaultOrgRole) + } else { + orgUser.Role = org.RoleType(s.cfg.AutoAssignOrgRole) + } + } + _, err = s.orgService.InsertOrgUser(ctx, &orgUser) + return err + } + return nil + }) + return usr, err } func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error { diff --git a/pkg/services/user/userimpl/user_test.go b/pkg/services/user/userimpl/user_test.go index c71431718de..536555fe5ff 100644 --- a/pkg/services/user/userimpl/user_test.go +++ b/pkg/services/user/userimpl/user_test.go @@ -9,7 +9,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/localcache" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/org/orgtest" @@ -27,6 +29,7 @@ func TestUserService(t *testing.T) { cacheService: localcache.ProvideService(), teamService: &teamtest.FakeService{}, tracer: tracing.InitializeTracerForTest(), + db: db.InitTestDB(t), } userService.cfg = setting.NewCfg() @@ -266,6 +269,46 @@ func TestMetrics(t *testing.T) { }) } +func TestIntegrationCreateUser(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + cfg := setting.NewCfg() + ss := db.InitTestDB(t) + userStore := &sqlStore{ + db: ss, + dialect: ss.GetDialect(), + logger: log.NewNopLogger(), + cfg: cfg, + } + + t.Run("create user should roll back created user if OrgUser cannot be created", func(t *testing.T) { + userService := Service{ + store: userStore, + orgService: &orgtest.FakeOrgService{InsertOrgUserFn: func(ctx context.Context, orgUser *org.OrgUser) (int64, error) { + return 0, errors.New("some error") + }}, + cacheService: localcache.ProvideService(), + teamService: &teamtest.FakeService{}, + tracer: tracing.InitializeTracerForTest(), + cfg: setting.NewCfg(), + db: ss, + } + _, err := userService.Create(context.Background(), &user.CreateUserCommand{ + Email: "email", + Login: "login", + Name: "name", + }) + require.Error(t, err) + + usr, err := userService.GetByLogin(context.Background(), &user.GetUserByLoginQuery{LoginOrEmail: "login"}) + require.Nil(t, usr) + require.Error(t, err) + require.ErrorIs(t, err, user.ErrUserNotFound) + }) +} + type FakeUserStore struct { ExpectedUser *user.User ExpectedSignedInUser *user.SignedInUser