diff --git a/pkg/api/admin_users.go b/pkg/api/admin_users.go index fd77a184804..120bb17f416 100644 --- a/pkg/api/admin_users.go +++ b/pkg/api/admin_users.go @@ -6,10 +6,13 @@ import ( "net/http" "strconv" + "golang.org/x/sync/errgroup" + "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" @@ -156,7 +159,7 @@ func (hs *HTTPServer) AdminUpdateUserPermissions(c *models.ReqContext) response. return response.Error(http.StatusBadRequest, "id is invalid", err) } - err = hs.SQLStore.UpdateUserPermissions(userID, form.IsGrafanaAdmin) + err = hs.userService.UpdatePermissions(userID, form.IsGrafanaAdmin) if err != nil { if errors.Is(err, user.ErrLastGrafanaAdmin) { return response.Error(400, user.ErrLastGrafanaAdmin.Error(), nil) @@ -198,6 +201,65 @@ func (hs *HTTPServer) AdminDeleteUser(c *models.ReqContext) response.Response { return response.Error(500, "Failed to delete user", err) } + g, ctx := errgroup.WithContext(c.Req.Context()) + g.Go(func() error { + if err := hs.starService.DeleteByUser(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.orgService.DeleteUserFromAll(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.DashboardService.DeleteACLByUser(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.preferenceService.DeleteByUser(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.teamGuardian.DeleteByUser(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.userAuthService.Delete(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.userAuthService.DeleteToken(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.QuotaService.DeleteByUser(ctx, cmd.UserID); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := hs.accesscontrolService.DeleteUserPermissions(ctx, accesscontrol.GlobalOrgID, cmd.UserID); err != nil { + return err + } + return nil + }) + if err := g.Wait(); err != nil { + return response.Error(500, "Failed to delete user", err) + } + return response.Success("User deleted") } diff --git a/pkg/api/admin_users_test.go b/pkg/api/admin_users_test.go index c1e394d72e8..41f79831f74 100644 --- a/pkg/api/admin_users_test.go +++ b/pkg/api/admin_users_test.go @@ -37,14 +37,12 @@ func TestAdminAPIEndpoint(t *testing.T) { updateCmd := dtos.AdminUpdateUserPermissionsForm{ IsGrafanaAdmin: false, } - mock := &mockstore.SQLStoreMock{ - ExpectedError: user.ErrLastGrafanaAdmin, - } + userService := usertest.FakeUserService{ExpectedError: user.ErrLastGrafanaAdmin} putAdminScenario(t, "When calling PUT on", "/api/admin/users/1/permissions", "/api/admin/users/:id/permissions", role, updateCmd, func(sc *scenarioContext) { sc.fakeReqWithParams("PUT", sc.url, map[string]string{}).exec() assert.Equal(t, 400, sc.resp.Code) - }, mock) + }, nil, &userService) }) t.Run("When a server admin attempts to logout himself from all devices", func(t *testing.T) { @@ -235,12 +233,13 @@ func TestAdminAPIEndpoint(t *testing.T) { } func putAdminScenario(t *testing.T, desc string, url string, routePattern string, role org.RoleType, - cmd dtos.AdminUpdateUserPermissionsForm, fn scenarioFunc, sqlStore sqlstore.Store) { + cmd dtos.AdminUpdateUserPermissionsForm, fn scenarioFunc, sqlStore sqlstore.Store, userSvc user.Service) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { hs := &HTTPServer{ Cfg: setting.NewCfg(), SQLStore: sqlStore, authInfoService: &logintest.AuthInfoServiceFake{}, + userService: userSvc, } sc := setupScenarioContext(t, url) diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index 36e982e0a5b..7d96dd721fc 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -56,6 +56,7 @@ import ( "github.com/grafana/grafana/pkg/services/team/teamimpl" "github.com/grafana/grafana/pkg/services/team/teamtest" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/userimpl" "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" @@ -209,7 +210,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer := tracing.InitializeTracerForTest() - authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore) + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore) loginService := &logintest.LoginServiceFake{} authenticator := &logintest.AuthenticatorFake{} ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake()) @@ -292,8 +293,6 @@ type accessControlScenarioContext struct { // acmock is an accesscontrol mock used to fake users rights. acmock *accesscontrolmock.Mock - usermock *usertest.FakeUserService - // db is a test database initialized with InitTestDB db *sqlstore.SQLStore @@ -302,6 +301,7 @@ type accessControlScenarioContext struct { dashboardsStore dashboards.Store teamService team.Service + userService user.Service } func setAccessControlPermissions(acmock *accesscontrolmock.Mock, perms []accesscontrol.Permission, org int64) { @@ -378,6 +378,8 @@ func setupHTTPServerWithCfgDb( var ac accesscontrol.AccessControl var acService accesscontrol.Service + var userSvc user.Service + // Defining the accesscontrol service has to be done before registering routes if useFakeAccessControl { acmock = accesscontrolmock.New() @@ -386,19 +388,18 @@ func setupHTTPServerWithCfgDb( } ac = acmock acService = acmock + userSvc = &usertest.FakeUserService{} } else { var err error acService, err = acimpl.ProvideService(cfg, db, routeRegister, localcache.ProvideService()) require.NoError(t, err) ac = acimpl.ProvideAccessControl(cfg) + userSvc = userimpl.ProvideService(db, nil, cfg, db) } - - teamPermissionService, err := ossaccesscontrol.ProvideTeamPermissions(cfg, routeRegister, db, ac, license, acService, teamService) + teamPermissionService, err := ossaccesscontrol.ProvideTeamPermissions(cfg, routeRegister, db, ac, license, acService, teamService, userSvc) require.NoError(t, err) // Create minimal HTTP Server - userMock := usertest.NewUserServiceFake() - userMock.ExpectedUser = &user.User{ID: 1} hs := &HTTPServer{ Cfg: cfg, Features: features, @@ -416,7 +417,7 @@ func setupHTTPServerWithCfgDb( accesscontrolmock.NewMockedPermissionsService(), accesscontrolmock.NewMockedPermissionsService(), ac, ), preferenceService: preftest.NewPreferenceServiceFake(), - userService: userMock, + userService: userSvc, orgService: orgtest.NewOrgServiceFake(), teamService: teamService, annotationsRepo: annotationstest.NewFakeAnnotationsRepo(), @@ -455,7 +456,7 @@ func setupHTTPServerWithCfgDb( cfg: cfg, dashboardsStore: dashboardsStore, teamService: teamService, - usermock: userMock, + userService: userSvc, } } diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 3d515a2df56..cc7d81452ba 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/middleware/csrf" "github.com/grafana/grafana/pkg/services/searchV2" + "github.com/grafana/grafana/pkg/services/userauth" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -199,6 +200,7 @@ type HTTPServer struct { accesscontrolService accesscontrol.Service annotationsRepo annotations.Repository tagService tag.Service + userAuthService userauth.Service } type ServerOptions struct { @@ -240,6 +242,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi loginAttemptService loginAttempt.Service, orgService org.Service, teamService team.Service, accesscontrolService accesscontrol.Service, dashboardThumbsService dashboardThumbs.Service, navTreeService navtree.Service, annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService, + userAuthService userauth.Service, ) (*HTTPServer, error) { web.Env = cfg.Env m := web.New() @@ -340,6 +343,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi accesscontrolService: accesscontrolService, annotationsRepo: annotationRepo, tagService: tagService, + userAuthService: userAuthService, } if hs.Listener != nil { hs.log.Debug("Using provided listener") @@ -592,7 +596,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { m.Use(hs.frontendLogEndpoints()) m.UseMiddleware(hs.ContextHandler.Middleware) - m.Use(middleware.OrgRedirect(hs.Cfg, hs.SQLStore)) + m.Use(middleware.OrgRedirect(hs.Cfg, hs.userService)) m.Use(accesscontrol.LoadPermissionsMiddleware(hs.accesscontrolService)) // needs to be after context handler diff --git a/pkg/api/org_invite.go b/pkg/api/org_invite.go index 729ffc2e0ad..53ff2490e8d 100644 --- a/pkg/api/org_invite.go +++ b/pkg/api/org_invite.go @@ -302,7 +302,7 @@ func (hs *HTTPServer) applyUserInvite(ctx context.Context, usr *user.User, invit if setActive { // set org to active - if err := hs.SQLStore.SetUsingOrg(ctx, &models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: usr.ID}); err != nil { + if err := hs.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{OrgID: invite.OrgId, UserID: usr.ID}); err != nil { return false, response.Error(500, "Failed to set org as active", err) } } diff --git a/pkg/api/org_test.go b/pkg/api/org_test.go index dfc2aeadc91..028724686f1 100644 --- a/pkg/api/org_test.go +++ b/pkg/api/org_test.go @@ -12,6 +12,7 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/org/orgimpl" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" ) @@ -198,7 +199,7 @@ func setupOrgsDBForAccessControlTests(t *testing.T, db sqlstore.Store, c accessC setInitCtxSignedInViewer(c.initCtx) u := *c.initCtx.SignedInUser u.OrgID = orgID - c.usermock.ExpectedSignedInUser = &u + c.userService.(*usertest.FakeUserService).ExpectedSignedInUser = &u // Create `orgsCount` orgs for i := 1; i <= int(orgID); i++ { diff --git a/pkg/api/org_users_test.go b/pkg/api/org_users_test.go index b948ebd6f55..74820816843 100644 --- a/pkg/api/org_users_test.go +++ b/pkg/api/org_users_test.go @@ -388,8 +388,7 @@ func TestGetOrgUsersAPIEndpoint_AccessControlMetadata(t *testing.T) { cfg.RBACEnabled = tc.enableAccessControl sc := setupHTTPServerWithCfg(t, false, cfg, func(hs *HTTPServer) { hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, cfg, hs.SQLStore.(*sqlstore.SQLStore), ) hs.orgService = orgimpl.ProvideService(hs.SQLStore, cfg) }) @@ -493,8 +492,7 @@ func TestGetOrgUsersAPIEndpoint_AccessControl(t *testing.T) { cfg.RBACEnabled = tc.enableAccessControl sc := setupHTTPServerWithCfg(t, false, cfg, func(hs *HTTPServer) { hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, cfg, hs.SQLStore.(*sqlstore.SQLStore), ) hs.orgService = orgimpl.ProvideService(hs.SQLStore, cfg) }) @@ -599,10 +597,10 @@ func TestPostOrgUsersAPIEndpoint_AccessControl(t *testing.T) { cfg.RBACEnabled = tc.enableAccessControl sc := setupHTTPServerWithCfg(t, false, cfg, func(hs *HTTPServer) { hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, cfg, hs.SQLStore.(*sqlstore.SQLStore), ) }) + setupOrgUsersDBForAccessControlTests(t, sc.db) setInitCtxSignedInUser(sc.initCtx, tc.user) @@ -718,8 +716,7 @@ func TestOrgUsersAPIEndpointWithSetPerms_AccessControl(t *testing.T) { sc := setupHTTPServer(t, true, func(hs *HTTPServer) { hs.tempUserService = tempuserimpl.ProvideService(hs.SQLStore) hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), ) }) setInitCtxSignedInViewer(sc.initCtx) @@ -837,8 +834,7 @@ func TestPatchOrgUsersAPIEndpoint_AccessControl(t *testing.T) { cfg.RBACEnabled = tc.enableAccessControl sc := setupHTTPServerWithCfg(t, false, cfg, func(hs *HTTPServer) { hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, cfg, hs.SQLStore.(*sqlstore.SQLStore), ) hs.orgService = orgimpl.ProvideService(hs.SQLStore, cfg) }) @@ -858,13 +854,13 @@ func TestPatchOrgUsersAPIEndpoint_AccessControl(t *testing.T) { require.NoError(t, err) assert.Equal(t, tc.expectedMessage, message) - getUserQuery := models.GetSignedInUserQuery{ - UserId: tc.targetUserId, - OrgId: tc.targetOrg, + getUserQuery := user.GetSignedInUserQuery{ + UserID: tc.targetUserId, + OrgID: tc.targetOrg, } - err = sc.db.GetSignedInUser(context.Background(), &getUserQuery) + usr, err := sc.userService.GetSignedInUser(context.Background(), &getUserQuery) require.NoError(t, err) - assert.Equal(t, tc.expectedUserRole, getUserQuery.Result.OrgRole) + assert.Equal(t, tc.expectedUserRole, usr.OrgRole) } }) } @@ -965,8 +961,7 @@ func TestDeleteOrgUsersAPIEndpoint_AccessControl(t *testing.T) { cfg.RBACEnabled = tc.enableAccessControl sc := setupHTTPServerWithCfg(t, false, cfg, func(hs *HTTPServer) { hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, nil, nil, - nil, nil, nil, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, cfg, hs.SQLStore.(*sqlstore.SQLStore), ) hs.orgService = orgimpl.ProvideService(hs.SQLStore, cfg) }) diff --git a/pkg/api/user.go b/pkg/api/user.go index 8e870eb0d57..09e05ee0e13 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -49,9 +49,10 @@ func (hs *HTTPServer) GetUserByID(c *models.ReqContext) response.Response { } func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) response.Response { - query := models.GetUserProfileQuery{UserId: userID} + query := user.GetUserProfileQuery{UserID: userID} - if err := hs.SQLStore.GetUserProfile(c.Req.Context(), &query); err != nil { + userProfile, err := hs.userService.GetUserProfile(c.Req.Context(), &query) + if err != nil { if errors.Is(err, user.ErrUserNotFound) { return response.Error(404, user.ErrUserNotFound.Error(), nil) } @@ -59,17 +60,17 @@ func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) res } getAuthQuery := models.GetAuthInfoQuery{UserId: userID} - query.Result.AuthLabels = []string{} + userProfile.AuthLabels = []string{} if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil { authLabel := login.GetAuthProviderLabel(getAuthQuery.Result.AuthModule) - query.Result.AuthLabels = append(query.Result.AuthLabels, authLabel) - query.Result.IsExternal = true + userProfile.AuthLabels = append(userProfile.AuthLabels, authLabel) + userProfile.IsExternal = true } - query.Result.AccessControl = hs.getAccessControlMetadata(c, c.OrgID, "global.users:id:", strconv.FormatInt(userID, 10)) - query.Result.AvatarUrl = dtos.GetGravatarUrl(query.Result.Email) + userProfile.AccessControl = hs.getAccessControlMetadata(c, c.OrgID, "global.users:id:", strconv.FormatInt(userID, 10)) + userProfile.AvatarUrl = dtos.GetGravatarUrl(userProfile.Email) - return response.JSON(http.StatusOK, query.Result) + return response.JSON(http.StatusOK, userProfile) } // swagger:route GET /users/lookup users getUserByLoginOrEmail @@ -171,9 +172,9 @@ func (hs *HTTPServer) UpdateUserActiveOrg(c *models.ReqContext) response.Respons return response.Error(401, "Not a valid organization", nil) } - cmd := models.SetUsingOrgCommand{UserId: userID, OrgId: orgID} + cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: orgID} - if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { + if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } @@ -334,9 +335,9 @@ func (hs *HTTPServer) UserSetUsingOrg(c *models.ReqContext) response.Response { return response.Error(401, "Not a valid organization", nil) } - cmd := models.SetUsingOrgCommand{UserId: c.UserID, OrgId: orgID} + cmd := user.SetUsingOrgCommand{UserID: c.UserID, OrgID: orgID} - if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { + if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } @@ -355,9 +356,9 @@ func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *models.ReqContext) { hs.NotFoundHandler(c) } - cmd := models.SetUsingOrgCommand{UserId: c.UserID, OrgId: orgID} + cmd := user.SetUsingOrgCommand{UserID: c.UserID, OrgID: orgID} - if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { + if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil { hs.NotFoundHandler(c) } @@ -449,12 +450,12 @@ func (hs *HTTPServer) SetHelpFlag(c *models.ReqContext) response.Response { bitmask := &c.HelpFlags1 bitmask.AddFlag(user.HelpFlags1(flag)) - cmd := models.SetUserHelpFlagCommand{ - UserId: c.UserID, + cmd := user.SetUserHelpFlagCommand{ + UserID: c.UserID, HelpFlags1: *bitmask, } - if err := hs.SQLStore.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { + if err := hs.userService.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } @@ -471,12 +472,12 @@ func (hs *HTTPServer) SetHelpFlag(c *models.ReqContext) response.Response { // 403: forbiddenError // 500: internalServerError func (hs *HTTPServer) ClearHelpFlags(c *models.ReqContext) response.Response { - cmd := models.SetUserHelpFlagCommand{ - UserId: c.UserID, + cmd := user.SetUserHelpFlagCommand{ + UserID: c.UserID, HelpFlags1: user.HelpFlags1(0), } - if err := hs.SQLStore.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { + if err := hs.userService.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index d19cd4841be..2bf71390b10 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -26,6 +26,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/userimpl" "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" ) @@ -67,6 +68,7 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { } user, err := sqlStore.CreateUser(context.Background(), createUserCmd) require.Nil(t, err) + hs.userService = userimpl.ProvideService(sqlStore, nil, sc.cfg, sqlStore) sc.handlerFunc = hs.GetUserByID diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index d264a65f80a..0a68fce668a 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -13,6 +13,9 @@ import ( "time" "github.com/grafana/grafana-plugin-sdk-go/backend/gtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/infra/fs" "github.com/grafana/grafana/pkg/infra/log" @@ -36,8 +39,6 @@ import ( "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func fakeGetTime() func() time.Time { @@ -372,8 +373,7 @@ func TestMiddlewareContext(t *testing.T) { const group = "grafana-core-team" middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) { - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} - + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} h, err := authproxy.HashCacheKey(hdrName + "-" + group) require.NoError(t, err) key := fmt.Sprintf(authproxy.CachePrefix, h) @@ -412,9 +412,8 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) { - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} sc.loginService.ExpectedUser = &user.User{ID: userID} - + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() @@ -432,7 +431,7 @@ func TestMiddlewareContext(t *testing.T) { var storedRoleInfo map[int64]org.RoleType = nil sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *user.User { storedRoleInfo = cmd.ExternalUser.OrgRoles - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: defaultOrgId, UserID: userID, OrgRole: storedRoleInfo[defaultOrgId]} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: defaultOrgId, UserID: userID, OrgRole: storedRoleInfo[defaultOrgId]} return &user.User{ID: userID} } @@ -455,7 +454,7 @@ func TestMiddlewareContext(t *testing.T) { var storedRoleInfo map[int64]org.RoleType = nil sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *user.User { storedRoleInfo = cmd.ExternalUser.OrgRoles - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID, OrgRole: storedRoleInfo[orgID]} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID, OrgRole: storedRoleInfo[orgID]} return &user.User{ID: userID} } @@ -479,7 +478,7 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) { var targetOrgID int64 = 123 - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: targetOrgID, UserID: userID} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: targetOrgID, UserID: userID} sc.loginService.ExpectedUser = &user.User{ID: userID} sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID)) @@ -553,7 +552,7 @@ func TestMiddlewareContext(t *testing.T) { const userID int64 = 12 const orgID int64 = 2 - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} sc.loginService.ExpectedUser = &user.User{ID: userID} sc.fakeReq("GET", "/") @@ -569,7 +568,7 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID} sc.loginService.ExpectedUser = &user.User{ID: userID} sc.fakeReq("GET", "/") @@ -659,7 +658,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.sqlStore = ctxHdlr.SQLStore sc.contextHandler = ctxHdlr sc.m.Use(ctxHdlr.Middleware) - sc.m.Use(OrgRedirect(sc.cfg, sc.mockSQLStore)) + sc.m.Use(OrgRedirect(sc.cfg, sc.userService)) sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*auth.FakeUserAuthTokenService) sc.jwtAuthService = ctxHdlr.JWTAuthService.(*models.FakeJWTService) @@ -703,7 +702,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer := tracing.InitializeTracerForTest() - authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore) + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore) authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}} return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService) } diff --git a/pkg/middleware/org_redirect.go b/pkg/middleware/org_redirect.go index 798800bdfe1..1199f1852cc 100644 --- a/pkg/middleware/org_redirect.go +++ b/pkg/middleware/org_redirect.go @@ -6,16 +6,15 @@ import ( "strconv" "strings" - "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/contexthandler" - "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) // OrgRedirect changes org and redirects users if the // querystring `orgId` doesn't match the active org. -func OrgRedirect(cfg *setting.Cfg, store sqlstore.Store) web.Handler { +func OrgRedirect(cfg *setting.Cfg, userSvc user.Service) web.Handler { return func(res http.ResponseWriter, req *http.Request, c *web.Context) { orgIdValue := req.URL.Query().Get("orgId") orgId, err := strconv.ParseInt(orgIdValue, 10, 64) @@ -33,8 +32,8 @@ func OrgRedirect(cfg *setting.Cfg, store sqlstore.Store) web.Handler { return } - cmd := models.SetUsingOrgCommand{UserId: ctx.UserID, OrgId: orgId} - if err := store.SetUsingOrg(ctx.Req.Context(), &cmd); err != nil { + cmd := user.SetUsingOrgCommand{UserID: ctx.UserID, OrgID: orgId} + if err := userSvc.SetUsingOrg(ctx.Req.Context(), &cmd); err != nil { if ctx.IsApiRequest() { ctx.JsonApiErr(404, "Not found", nil) } else { diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index a8bb1559516..8d1460983df 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -5,9 +5,10 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/user" - "github.com/stretchr/testify/require" ) func TestOrgRedirectMiddleware(t *testing.T) { @@ -46,7 +47,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { for _, tc := range testCases { middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ UserId: 0, @@ -64,8 +65,8 @@ func TestOrgRedirectMiddleware(t *testing.T) { middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") - sc.mockSQLStore.ExpectedSetUsingOrgError = fmt.Errorf("") - sc.mockSQLStore.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} + sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("") + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index c66161cb889..0bbec1aca2d 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -5,13 +5,14 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestRecoveryMiddleware(t *testing.T) { @@ -70,7 +71,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { contextHandler := getContextHandler(t, nil, nil, nil, nil, nil) sc.m.Use(contextHandler.Middleware) // mock out gc goroutine - sc.m.Use(OrgRedirect(cfg, sc.mockSQLStore)) + sc.m.Use(OrgRedirect(cfg, sc.userService)) sc.defaultHandler = func(c *models.ReqContext) { sc.context = c diff --git a/pkg/plugins/manager/manager_integration_test.go b/pkg/plugins/manager/manager_integration_test.go index beffba0c4ba..e3bcde12281 100644 --- a/pkg/plugins/manager/manager_integration_test.go +++ b/pkg/plugins/manager/manager_integration_test.go @@ -9,13 +9,12 @@ import ( "testing" "time" - "gopkg.in/ini.v1" - "github.com/grafana/grafana-azure-sdk-go/azsettings" "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" + "gopkg.in/ini.v1" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/plugins" @@ -103,7 +102,7 @@ func TestIntegrationPluginManager(t *testing.T) { pg := postgres.ProvideService(cfg) my := mysql.ProvideService(cfg, hcp) ms := mssql.ProvideService(cfg) - sv2 := searchV2.ProvideService(cfg, sqlstore.InitTestDB(t), nil, nil, tracing.InitializeTracerForTest(), featuremgmt.WithFeatures(), nil) + sv2 := searchV2.ProvideService(cfg, sqlstore.InitTestDB(t), nil, nil, tracing.InitializeTracerForTest(), featuremgmt.WithFeatures(), nil, nil) graf := grafanads.ProvideService(cfg, sv2, nil) coreRegistry := coreplugin.ProvideCoreRegistry(am, cw, cm, es, grap, idb, lk, otsdb, pr, tmpo, td, pg, my, ms, graf) diff --git a/pkg/server/wire.go b/pkg/server/wire.go index a982bcaae81..8a56e474611 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -6,13 +6,6 @@ package server import ( "github.com/google/wire" sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/services/annotations" - "github.com/grafana/grafana/pkg/services/annotations/annotationsimpl" - - "github.com/grafana/grafana/pkg/services/auth" - "github.com/grafana/grafana/pkg/services/playlist/playlistimpl" - "github.com/grafana/grafana/pkg/services/store/sanitizer" - "github.com/grafana/grafana/pkg/api" "github.com/grafana/grafana/pkg/api/avatar" "github.com/grafana/grafana/pkg/api/routing" @@ -51,7 +44,10 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/alerting" + "github.com/grafana/grafana/pkg/services/annotations" + "github.com/grafana/grafana/pkg/services/annotations/annotationsimpl" "github.com/grafana/grafana/pkg/services/apikey/apikeyimpl" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/jwt" "github.com/grafana/grafana/pkg/services/cleanup" "github.com/grafana/grafana/pkg/services/comments" @@ -95,6 +91,7 @@ import ( "github.com/grafana/grafana/pkg/services/notifications" "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/org/orgimpl" + "github.com/grafana/grafana/pkg/services/playlist/playlistimpl" "github.com/grafana/grafana/pkg/services/plugindashboards" plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service" "github.com/grafana/grafana/pkg/services/pluginsettings" @@ -125,6 +122,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/services/star/starimpl" "github.com/grafana/grafana/pkg/services/store" + "github.com/grafana/grafana/pkg/services/store/sanitizer" "github.com/grafana/grafana/pkg/services/tag" "github.com/grafana/grafana/pkg/services/tag/tagimpl" "github.com/grafana/grafana/pkg/services/team/teamimpl" diff --git a/pkg/services/accesscontrol/ossaccesscontrol/permissions_services.go b/pkg/services/accesscontrol/ossaccesscontrol/permissions_services.go index 9aeefa77710..60f97ef326d 100644 --- a/pkg/services/accesscontrol/ossaccesscontrol/permissions_services.go +++ b/pkg/services/accesscontrol/ossaccesscontrol/permissions_services.go @@ -40,7 +40,7 @@ var ( func ProvideTeamPermissions( cfg *setting.Cfg, router routing.RouteRegister, sql *sqlstore.SQLStore, ac accesscontrol.AccessControl, license models.Licensing, service accesscontrol.Service, - teamService team.Service, + teamService team.Service, userService user.Service, ) (*TeamPermissionsService, error) { options := resourcepermissions.Options{ Resource: "teams", @@ -96,7 +96,7 @@ func ProvideTeamPermissions( }, } - srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService) + srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService, userService) if err != nil { return nil, err } @@ -114,7 +114,7 @@ var DashboardAdminActions = append(DashboardEditActions, []string{dashboards.Act func ProvideDashboardPermissions( cfg *setting.Cfg, router routing.RouteRegister, sql *sqlstore.SQLStore, ac accesscontrol.AccessControl, license models.Licensing, dashboardStore dashboards.Store, service accesscontrol.Service, - teamService team.Service, + teamService team.Service, userService user.Service, ) (*DashboardPermissionsService, error) { getDashboard := func(ctx context.Context, orgID int64, resourceID string) (*models.Dashboard, error) { query := &models.GetDashboardQuery{Uid: resourceID, OrgId: orgID} @@ -168,7 +168,7 @@ func ProvideDashboardPermissions( RoleGroup: "Dashboards", } - srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService) + srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService, userService) if err != nil { return nil, err } @@ -193,7 +193,7 @@ var FolderAdminActions = append(FolderEditActions, []string{dashboards.ActionFol func ProvideFolderPermissions( cfg *setting.Cfg, router routing.RouteRegister, sql *sqlstore.SQLStore, accesscontrol accesscontrol.AccessControl, license models.Licensing, dashboardStore dashboards.Store, service accesscontrol.Service, - teamService team.Service, + teamService team.Service, userService user.Service, ) (*FolderPermissionsService, error) { options := resourcepermissions.Options{ Resource: "folders", @@ -224,7 +224,7 @@ func ProvideFolderPermissions( WriterRoleName: "Folder permission writer", RoleGroup: "Folders", } - srv, err := resourcepermissions.New(options, cfg, router, license, accesscontrol, service, sql, teamService) + srv, err := resourcepermissions.New(options, cfg, router, license, accesscontrol, service, sql, teamService, userService) if err != nil { return nil, err } @@ -284,7 +284,7 @@ type ServiceAccountPermissionsService struct { func ProvideServiceAccountPermissions( cfg *setting.Cfg, router routing.RouteRegister, sql *sqlstore.SQLStore, ac accesscontrol.AccessControl, license models.Licensing, serviceAccountStore serviceaccounts.Store, service accesscontrol.Service, - teamService team.Service, + teamService team.Service, userService user.Service, ) (*ServiceAccountPermissionsService, error) { options := resourcepermissions.Options{ Resource: "serviceaccounts", @@ -311,7 +311,7 @@ func ProvideServiceAccountPermissions( RoleGroup: "Service accounts", } - srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService) + srv, err := resourcepermissions.New(options, cfg, router, license, ac, service, sql, teamService, userService) if err != nil { return nil, err } diff --git a/pkg/services/accesscontrol/resourcepermissions/service.go b/pkg/services/accesscontrol/resourcepermissions/service.go index b96ed0bf192..3f2717d2dd8 100644 --- a/pkg/services/accesscontrol/resourcepermissions/service.go +++ b/pkg/services/accesscontrol/resourcepermissions/service.go @@ -51,7 +51,7 @@ type Store interface { func New( options Options, cfg *setting.Cfg, router routing.RouteRegister, license models.Licensing, ac accesscontrol.AccessControl, service accesscontrol.Service, sqlStore *sqlstore.SQLStore, - teamService team.Service, + teamService team.Service, userService user.Service, ) (*Service, error) { var permissions []string actionSet := make(map[string]struct{}) @@ -83,6 +83,7 @@ func New( sqlStore: sqlStore, service: service, teamService: teamService, + userService: userService, } s.api = newApi(ac, router, s) @@ -110,6 +111,7 @@ type Service struct { actions []string sqlStore *sqlstore.SQLStore teamService team.Service + userService user.Service } func (s *Service) GetPermissions(ctx context.Context, user *user.SignedInUser, resourceID string) ([]accesscontrol.ResourcePermission, error) { @@ -286,10 +288,8 @@ func (s *Service) validateUser(ctx context.Context, orgID, userID int64) error { return ErrInvalidAssignment } - if err := s.sqlStore.GetSignedInUser(ctx, &models.GetSignedInUserQuery{OrgId: orgID, UserId: userID}); err != nil { - return err - } - return nil + _, err := s.userService.GetSignedInUser(ctx, &user.GetSignedInUserQuery{OrgID: orgID, UserID: userID}) + return err } func (s *Service) validateTeam(ctx context.Context, orgID, teamID int64) error { diff --git a/pkg/services/accesscontrol/resourcepermissions/service_test.go b/pkg/services/accesscontrol/resourcepermissions/service_test.go index 770a1f692c2..b069a4fac4c 100644 --- a/pkg/services/accesscontrol/resourcepermissions/service_test.go +++ b/pkg/services/accesscontrol/resourcepermissions/service_test.go @@ -15,6 +15,7 @@ import ( "github.com/grafana/grafana/pkg/services/team" "github.com/grafana/grafana/pkg/services/team/teamimpl" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/userimpl" "github.com/grafana/grafana/pkg/setting" ) @@ -223,12 +224,13 @@ func setupTestEnvironment(t *testing.T, permissions []accesscontrol.Permission, sql := sqlstore.InitTestDB(t) cfg := setting.NewCfg() teamSvc := teamimpl.ProvideService(sql, cfg) + userSvc := userimpl.ProvideService(sql, nil, cfg, sql) license := licensingtest.NewFakeLicensing() license.On("FeatureEnabled", "accesscontrol.enforcement").Return(true).Maybe() mock := accesscontrolmock.New().WithPermissions(permissions) service, err := New( ops, cfg, routing.NewRouteRegister(), license, - accesscontrolmock.New().WithPermissions(permissions), mock, sql, teamSvc, + accesscontrolmock.New().WithPermissions(permissions), mock, sql, teamSvc, userSvc, ) require.NoError(t, err) diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 6483b1e0883..1e968dc4bab 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -6,6 +6,8 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/tracing" @@ -19,7 +21,6 @@ import ( "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" - "github.com/stretchr/testify/require" ) const userID = int64(1) @@ -84,10 +85,23 @@ func getContextHandler(t *testing.T) *ContextHandler { tracer := tracing.InitializeTracerForTest() loginService := loginservice.LoginServiceMock{ExpectedUser: &user.User{ID: userID}} - authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{}) + userService := usertest.FakeUserService{ + GetSignedInUserFn: func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) { + if query.UserID != userID { + return &user.SignedInUser{}, user.ErrUserNotFound + } + return &user.SignedInUser{ + UserID: userID, + OrgID: orgID, + }, nil + }, + } + + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &userService, &FakeGetSignUserStore{}) authenticator := &fakeAuthenticator{} - return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, &usertest.FakeUserService{}) + return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, + renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, &userService) } type FakeGetSignUserStore struct { diff --git a/pkg/services/contexthandler/authproxy/authproxy.go b/pkg/services/contexthandler/authproxy/authproxy.go index 184c12ab234..7585c38133f 100644 --- a/pkg/services/contexthandler/authproxy/authproxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -56,16 +56,18 @@ type AuthProxy struct { remoteCache *remotecache.RemoteCache loginService login.Service sqlStore sqlstore.Store + userService user.Service logger log.Logger } -func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy { +func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, userService user.Service, sqlStore sqlstore.Store) *AuthProxy { return &AuthProxy{ cfg: cfg, remoteCache: remoteCache, loginService: loginService, sqlStore: sqlStore, + userService: userService, logger: log.New("auth.proxy"), } } @@ -347,16 +349,10 @@ func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field // GetSignedInUser gets full signed in user info. func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*user.SignedInUser, error) { - query := &models.GetSignedInUserQuery{ - OrgId: orgID, - UserId: userID, - } - - if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil { - return nil, err - } - - return query.Result, nil + return auth.userService.GetSignedInUser(context.Background(), &user.GetSignedInUserQuery{ + OrgID: orgID, + UserID: userID, + }) } // Remember user in cache diff --git a/pkg/services/contexthandler/authproxy/authproxy_test.go b/pkg/services/contexthandler/authproxy/authproxy_test.go index 8f49041aa88..736776cc94d 100644 --- a/pkg/services/contexthandler/authproxy/authproxy_test.go +++ b/pkg/services/contexthandler/authproxy/authproxy_test.go @@ -7,6 +7,9 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/ldap" @@ -15,9 +18,6 @@ import ( "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const hdrName = "markelog" @@ -48,7 +48,7 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi }, } - return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx + return ProvideAuthProxy(cfg, remoteCache, loginService, nil, nil), ctx } func TestMiddlewareContext(t *testing.T) { diff --git a/pkg/services/guardian/accesscontrol_guardian_test.go b/pkg/services/guardian/accesscontrol_guardian_test.go index 9fe80861032..2965dd54ce8 100644 --- a/pkg/services/guardian/accesscontrol_guardian_test.go +++ b/pkg/services/guardian/accesscontrol_guardian_test.go @@ -9,6 +9,7 @@ import ( "github.com/grafana/grafana/pkg/services/tag/tagimpl" "github.com/grafana/grafana/pkg/services/team/teamimpl" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/userimpl" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -603,12 +604,13 @@ func setupAccessControlGuardianTest(t *testing.T, uid string, permissions []acce license := licensingtest.NewFakeLicensing() license.On("FeatureEnabled", "accesscontrol.enforcement").Return(true).Maybe() teamSvc := teamimpl.ProvideService(store, store.Cfg) + userSvc := userimpl.ProvideService(store, nil, store.Cfg, store) folderPermissions, err := ossaccesscontrol.ProvideFolderPermissions( - setting.NewCfg(), routing.NewRouteRegister(), store, ac, license, &dashboards.FakeDashboardStore{}, ac, teamSvc) + setting.NewCfg(), routing.NewRouteRegister(), store, ac, license, &dashboards.FakeDashboardStore{}, ac, teamSvc, userSvc) require.NoError(t, err) dashboardPermissions, err := ossaccesscontrol.ProvideDashboardPermissions( - setting.NewCfg(), routing.NewRouteRegister(), store, ac, license, &dashboards.FakeDashboardStore{}, ac, teamSvc) + setting.NewCfg(), routing.NewRouteRegister(), store, ac, license, &dashboards.FakeDashboardStore{}, ac, teamSvc, userSvc) require.NoError(t, err) if dashboardSvc == nil { dashboardSvc = &dashboards.FakeDashboardService{} diff --git a/pkg/services/login/loginservice/loginservice.go b/pkg/services/login/loginservice/loginservice.go index 20b70931e22..cc7d05711dc 100644 --- a/pkg/services/login/loginservice/loginservice.go +++ b/pkg/services/login/loginservice/loginservice.go @@ -148,7 +148,7 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUser // Sync isGrafanaAdmin permission if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin { - if errPerms := ls.SQLStore.UpdateUserPermissions(cmd.Result.ID, *extUser.IsGrafanaAdmin); errPerms != nil { + if errPerms := ls.userService.UpdatePermissions(cmd.Result.ID, *extUser.IsGrafanaAdmin); errPerms != nil { return errPerms } } @@ -334,9 +334,9 @@ func (ls *Implementation) syncOrgRoles(ctx context.Context, usr *user.User, extU break } - return ls.SQLStore.SetUsingOrg(ctx, &models.SetUsingOrgCommand{ - UserId: usr.ID, - OrgId: usr.OrgID, + return ls.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{ + UserID: usr.ID, + OrgID: usr.OrgID, }) } diff --git a/pkg/services/searchV2/allowed_actions_test.go b/pkg/services/searchV2/allowed_actions_test.go index 9f69da83b29..7bff7320821 100644 --- a/pkg/services/searchV2/allowed_actions_test.go +++ b/pkg/services/searchV2/allowed_actions_test.go @@ -9,6 +9,8 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/experimental" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/tracing" ac "github.com/grafana/grafana/pkg/services/accesscontrol" accesscontrolmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" @@ -17,7 +19,6 @@ import ( "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" - "github.com/stretchr/testify/require" ) var ( @@ -83,7 +84,9 @@ var ( ) func service(t *testing.T) *StandardSearchService { - service, ok := ProvideService(&setting.Cfg{Search: setting.SearchSettings{}}, nil, nil, accesscontrolmock.New(), tracing.InitializeTracerForTest(), featuremgmt.WithFeatures(), nil).(*StandardSearchService) + service, ok := ProvideService(&setting.Cfg{Search: setting.SearchSettings{}}, + nil, nil, accesscontrolmock.New(), tracing.InitializeTracerForTest(), featuremgmt.WithFeatures(), + nil, nil).(*StandardSearchService) require.True(t, ok) return service } diff --git a/pkg/services/searchV2/service.go b/pkg/services/searchV2/service.go index daf68125dd8..8887c2ddac2 100644 --- a/pkg/services/searchV2/service.go +++ b/pkg/services/searchV2/service.go @@ -6,6 +6,10 @@ import ( "fmt" "time" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/models" @@ -17,10 +21,6 @@ import ( "github.com/grafana/grafana/pkg/services/store" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - - "github.com/grafana/grafana-plugin-sdk-go/backend" ) var ( @@ -63,11 +63,12 @@ var ( type StandardSearchService struct { registry.BackgroundService - cfg *setting.Cfg - sql *sqlstore.SQLStore - auth FutureAuthService // eventually injected from elsewhere - ac accesscontrol.Service - orgService org.Service + cfg *setting.Cfg + sql *sqlstore.SQLStore + auth FutureAuthService // eventually injected from elsewhere + ac accesscontrol.Service + orgService org.Service + userService user.Service logger log.Logger dashboardIndex *searchIndex @@ -79,7 +80,9 @@ func (s *StandardSearchService) IsReady(ctx context.Context, orgId int64) IsSear return s.dashboardIndex.isInitialized(ctx, orgId) } -func ProvideService(cfg *setting.Cfg, sql *sqlstore.SQLStore, entityEventStore store.EntityEventsService, ac accesscontrol.Service, tracer tracing.Tracer, features featuremgmt.FeatureToggles, orgService org.Service) SearchService { +func ProvideService(cfg *setting.Cfg, sql *sqlstore.SQLStore, entityEventStore store.EntityEventsService, + ac accesscontrol.Service, tracer tracing.Tracer, features featuremgmt.FeatureToggles, orgService org.Service, + userService user.Service) SearchService { extender := &NoopExtender{} s := &StandardSearchService{ cfg: cfg, @@ -98,10 +101,11 @@ func ProvideService(cfg *setting.Cfg, sql *sqlstore.SQLStore, entityEventStore s features, cfg.Search, ), - logger: log.New("searchV2"), - extender: extender, - reIndexCh: make(chan struct{}, 1), - orgService: orgService, + logger: log.New("searchV2"), + extender: extender, + reIndexCh: make(chan struct{}, 1), + orgService: orgService, + userService: userService, } return s } @@ -157,23 +161,22 @@ func (s *StandardSearchService) getUser(ctx context.Context, backendUser *backen IsAnonymous: true, } } else { - getSignedInUserQuery := &models.GetSignedInUserQuery{ + getSignedInUserQuery := &user.GetSignedInUserQuery{ Login: backendUser.Login, Email: backendUser.Email, - OrgId: orgId, + OrgID: orgId, } - err := s.sql.GetSignedInUser(ctx, getSignedInUserQuery) + var err error + usr, err = s.userService.GetSignedInUser(ctx, getSignedInUserQuery) if err != nil { s.logger.Error("Error while retrieving user", "error", err, "email", backendUser.Email, "login", getSignedInUserQuery.Login) return nil, errors.New("auth error") } - if getSignedInUserQuery.Result == nil { + if usr == nil { s.logger.Error("No user found", "email", backendUser.Email) return nil, errors.New("auth error") } - - usr = getSignedInUserQuery.Result } if s.ac.IsDisabled() { diff --git a/pkg/services/serviceaccounts/api/api_test.go b/pkg/services/serviceaccounts/api/api_test.go index eae1918f855..a77782f3c83 100644 --- a/pkg/services/serviceaccounts/api/api_test.go +++ b/pkg/services/serviceaccounts/api/api_test.go @@ -11,6 +11,9 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/infra/kvstore" "github.com/grafana/grafana/pkg/infra/log" @@ -29,10 +32,9 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/team/teamimpl" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/userimpl" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var ( @@ -282,7 +284,9 @@ func setupTestServer(t *testing.T, svc *tests.ServiceAccountMock, sqlStore *sqlstore.SQLStore, saStore serviceaccounts.Store) (*web.Mux, *ServiceAccountsAPI) { cfg := setting.NewCfg() teamSvc := teamimpl.ProvideService(sqlStore, cfg) - saPermissionService, err := ossaccesscontrol.ProvideServiceAccountPermissions(cfg, routing.NewRouteRegister(), sqlStore, acmock, &licensing.OSSLicensingService{}, saStore, acmock, teamSvc) + userSvc := userimpl.ProvideService(sqlStore, nil, cfg, sqlStore) + saPermissionService, err := ossaccesscontrol.ProvideServiceAccountPermissions( + cfg, routing.NewRouteRegister(), sqlStore, acmock, &licensing.OSSLicensingService{}, saStore, acmock, teamSvc, userSvc) require.NoError(t, err) a := NewServiceAccountsAPI(cfg, svc, acmock, routerRegister, saStore, saPermissionService) diff --git a/pkg/services/sqlstore/stats_test.go b/pkg/services/sqlstore/stats_test.go index 3bf6d22aabe..3ff9bd3bd51 100644 --- a/pkg/services/sqlstore/stats_test.go +++ b/pkg/services/sqlstore/stats_test.go @@ -5,11 +5,12 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestIntegrationStatsDataAccess(t *testing.T) { @@ -118,13 +119,6 @@ func populateDB(t *testing.T, sqlStore *SQLStore) { err = sqlStore.AddOrgUser(context.Background(), cmd) require.NoError(t, err) - // update 1st user last seen at - updateUserLastSeenAtCmd := &models.UpdateUserLastSeenAtCommand{ - UserId: users[0].ID, - } - err = sqlStore.UpdateUserLastSeenAt(context.Background(), updateUserLastSeenAtCmd) - require.NoError(t, err) - // force renewal of user stats err = sqlStore.updateUserRoleCountsIfNecessary(context.Background(), true) require.NoError(t, err) diff --git a/pkg/services/sqlstore/user.go b/pkg/services/sqlstore/user.go index a0fb2b21376..c4edd6fc0cb 100644 --- a/pkg/services/sqlstore/user.go +++ b/pkg/services/sqlstore/user.go @@ -17,27 +17,6 @@ import ( "github.com/grafana/grafana/pkg/util" ) -type ErrCaseInsensitiveLoginConflict struct { - users []user.User -} - -func (e *ErrCaseInsensitiveLoginConflict) Unwrap() error { - return user.ErrCaseInsensitive -} - -func (e *ErrCaseInsensitiveLoginConflict) Error() string { - n := len(e.users) - - userStrings := make([]string, 0, n) - for _, v := range e.users { - userStrings = append(userStrings, fmt.Sprintf("%s (email:%s, id:%d)", v.Login, v.Email, v.ID)) - } - - return fmt.Sprintf( - "Found a conflict in user login information. %d users already exist with either the same login or email: [%s].", - n, strings.Join(userStrings, ", ")) -} - func (ss *SQLStore) getOrgIDForNewUser(sess *DBSession, args user.CreateUserCommand) (int64, error) { if ss.Cfg.AutoAssignOrg && args.OrgID != 0 { if err := verifyExistingOrg(sess, args.OrgID); err != nil { @@ -63,7 +42,7 @@ func (ss *SQLStore) userCaseInsensitiveLoginConflict(ctx context.Context, sess * } if len(users) > 1 { - return &ErrCaseInsensitiveLoginConflict{users} + return &user.ErrCaseInsensitiveLoginConflict{Users: users} } return nil diff --git a/pkg/services/user/userimpl/store.go b/pkg/services/user/userimpl/store.go index 2ccce72a595..f215789a8da 100644 --- a/pkg/services/user/userimpl/store.go +++ b/pkg/services/user/userimpl/store.go @@ -5,10 +5,12 @@ import ( "fmt" "github.com/grafana/grafana/pkg/events" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/db" "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" ) type store interface { @@ -23,6 +25,17 @@ type store interface { type sqlStore struct { db db.DB dialect migrator.Dialect + logger log.Logger + cfg *setting.Cfg +} + +func ProvideStore(db db.DB, cfg *setting.Cfg) sqlStore { + return sqlStore{ + db: db, + dialect: db.GetDialect(), + cfg: cfg, + logger: log.New("user.store"), + } } func (ss *sqlStore) Insert(ctx context.Context, cmd *user.User) (int64, error) { diff --git a/pkg/services/user/userimpl/store_test.go b/pkg/services/user/userimpl/store_test.go index 8b1bbe7e7e7..3df8e61c9b8 100644 --- a/pkg/services/user/userimpl/store_test.go +++ b/pkg/services/user/userimpl/store_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" - - "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/setting" ) func TestIntegrationUserDataAccess(t *testing.T) { @@ -17,7 +18,7 @@ func TestIntegrationUserDataAccess(t *testing.T) { } ss := sqlstore.InitTestDB(t) - userStore := sqlStore{db: ss} + userStore := ProvideStore(ss, setting.NewCfg()) t.Run("user not found", func(t *testing.T) { _, err := userStore.Get(context.Background(), diff --git a/pkg/services/user/userimpl/user.go b/pkg/services/user/userimpl/user.go index 90eb27714b8..47824d9421e 100644 --- a/pkg/services/user/userimpl/user.go +++ b/pkg/services/user/userimpl/user.go @@ -6,67 +6,34 @@ import ( "time" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/accesscontrol" - "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/org" - pref "github.com/grafana/grafana/pkg/services/preference" - "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/db" - "github.com/grafana/grafana/pkg/services/star" - "github.com/grafana/grafana/pkg/services/teamguardian" "github.com/grafana/grafana/pkg/services/user" - "github.com/grafana/grafana/pkg/services/userauth" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" - - "golang.org/x/sync/errgroup" ) type Service struct { - store store - orgService org.Service - starService star.Service - dashboardService dashboards.DashboardService - preferenceService pref.Service - teamMemberService teamguardian.TeamGuardian - userAuthService userauth.Service - quotaService quota.Service - accessControlStore accesscontrol.Service + store store + orgService org.Service // TODO remove sqlstore sqlStore *sqlstore.SQLStore - - cfg *setting.Cfg + cfg *setting.Cfg } func ProvideService( db db.DB, orgService org.Service, - starService star.Service, - dashboardService dashboards.DashboardService, - preferenceService pref.Service, - teamMemberService teamguardian.TeamGuardian, - userAuthService userauth.Service, - quotaService quota.Service, - accessControlStore accesscontrol.Service, cfg *setting.Cfg, ss *sqlstore.SQLStore, ) user.Service { + store := ProvideStore(db, cfg) return &Service{ - store: &sqlStore{ - db: db, - dialect: db.GetDialect(), - }, - orgService: orgService, - starService: starService, - dashboardService: dashboardService, - preferenceService: preferenceService, - teamMemberService: teamMemberService, - userAuthService: userAuthService, - quotaService: quotaService, - accessControlStore: accessControlStore, - cfg: cfg, - sqlStore: ss, + store: &store, + orgService: orgService, + cfg: cfg, + sqlStore: ss, } } @@ -169,70 +136,7 @@ func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error return err } // delete from all the stores - if err := s.store.Delete(ctx, cmd.UserID); err != nil { - return err - } - - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - if err := s.starService.DeleteByUser(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.orgService.DeleteUserFromAll(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.dashboardService.DeleteACLByUser(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.preferenceService.DeleteByUser(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.teamMemberService.DeleteByUser(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.userAuthService.Delete(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.userAuthService.DeleteToken(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.quotaService.DeleteByUser(ctx, cmd.UserID); err != nil { - return err - } - return nil - }) - g.Go(func() error { - if err := s.accessControlStore.DeleteUserPermissions(ctx, accesscontrol.GlobalOrgID, cmd.UserID); err != nil { - return err - } - return nil - }) - if err := g.Wait(); err != nil { - return err - } - - return nil + return s.store.Delete(ctx, cmd.UserID) } func (s *Service) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*user.User, error) { @@ -331,10 +235,7 @@ func (s *Service) GetSignedInUser(ctx context.Context, query *user.GetSignedInUs OrgId: query.OrgID, } err := s.sqlStore.GetSignedInUser(ctx, q) - if err != nil { - return nil, err - } - return q.Result, nil + return q.Result, err } // TODO: remove wrapper around sqlstore diff --git a/pkg/services/user/userimpl/user_test.go b/pkg/services/user/userimpl/user_test.go index 7cb904b4e98..f8a634d96bb 100644 --- a/pkg/services/user/userimpl/user_test.go +++ b/pkg/services/user/userimpl/user_test.go @@ -2,18 +2,10 @@ package userimpl import ( "context" - "errors" "testing" - "github.com/grafana/grafana/pkg/services/accesscontrol/mock" - "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/org/orgtest" - "github.com/grafana/grafana/pkg/services/preference/preftest" - "github.com/grafana/grafana/pkg/services/quota/quotatest" - "github.com/grafana/grafana/pkg/services/star/startest" - "github.com/grafana/grafana/pkg/services/teamguardian/manager" "github.com/grafana/grafana/pkg/services/user" - "github.com/grafana/grafana/pkg/services/userauth/userauthtest" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" @@ -22,23 +14,9 @@ import ( func TestUserService(t *testing.T) { userStore := newUserStoreFake() orgService := orgtest.NewOrgServiceFake() - starService := startest.NewStarServiceFake() - dashboardService := dashboards.NewFakeDashboardService(t) - preferenceService := preftest.NewPreferenceServiceFake() - teamMemberService := manager.NewTeamGuardianMock() - userAuthService := userauthtest.NewFakeUserAuthService() - quotaService := quotatest.NewQuotaServiceFake() - accessControlStore := mock.New() userService := Service{ - store: userStore, - orgService: orgService, - starService: starService, - dashboardService: dashboardService, - preferenceService: preferenceService, - teamMemberService: teamMemberService, - userAuthService: userAuthService, - quotaService: quotaService, - accessControlStore: accessControlStore, + store: userStore, + orgService: orgService, } t.Run("create user", func(t *testing.T) { @@ -81,26 +59,6 @@ func TestUserService(t *testing.T) { require.Error(t, err, user.ErrUserNotFound) }) - t.Run("delete user returns from team", func(t *testing.T) { - teamMemberService.ExpectedError = errors.New("some error") - t.Cleanup(func() { - teamMemberService.ExpectedError = nil - }) - err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) - require.Error(t, err) - }) - - t.Run("delete user returns from team and pref", func(t *testing.T) { - teamMemberService.ExpectedError = errors.New("some error") - preferenceService.ExpectedError = errors.New("some error 2") - t.Cleanup(func() { - teamMemberService.ExpectedError = nil - preferenceService.ExpectedError = nil - }) - err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) - require.Error(t, err) - }) - t.Run("delete user successfully", func(t *testing.T) { err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) require.NoError(t, err) @@ -115,26 +73,6 @@ func TestUserService(t *testing.T) { require.Error(t, err, user.ErrUserNotFound) }) - t.Run("delete user returns from team", func(t *testing.T) { - teamMemberService.ExpectedError = errors.New("some error") - t.Cleanup(func() { - teamMemberService.ExpectedError = nil - }) - err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) - require.Error(t, err) - }) - - t.Run("delete user returns from team and pref", func(t *testing.T) { - teamMemberService.ExpectedError = errors.New("some error") - preferenceService.ExpectedError = errors.New("some error 2") - t.Cleanup(func() { - teamMemberService.ExpectedError = nil - preferenceService.ExpectedError = nil - }) - err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) - require.Error(t, err) - }) - t.Run("delete user successfully", func(t *testing.T) { err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) require.NoError(t, err) diff --git a/pkg/services/user/usertest/fake.go b/pkg/services/user/usertest/fake.go index 0705854a59f..47ce819f436 100644 --- a/pkg/services/user/usertest/fake.go +++ b/pkg/services/user/usertest/fake.go @@ -13,6 +13,8 @@ type FakeUserService struct { ExpectedSetUsingOrgError error ExpectedSearchUsers user.SearchUserQueryResult ExpectedUSerProfileDTO user.UserProfileDTO + + GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) } func NewUserServiceFake() *FakeUserService { @@ -60,6 +62,9 @@ func (f *FakeUserService) GetSignedInUserWithCacheCtx(ctx context.Context, query } func (f *FakeUserService) GetSignedInUser(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) { + if f.GetSignedInUserFn != nil { + return f.GetSignedInUserFn(ctx, query) + } if f.ExpectedSignedInUser == nil { return &user.SignedInUser{}, f.ExpectedError }