mirror of
https://github.com/grafana/grafana.git
synced 2025-07-31 19:22:34 +08:00
Chore: Add context to authinfo (#42096)
* Add context to authinfo * Replace Dispatch with DispatchCtx
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
@ -80,7 +79,7 @@ func AdminUpdateUserPassword(c *models.ReqContext, form dtos.AdminUpdateUserPass
|
||||
NewPassword: passwordHashed,
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(&cmd); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
|
||||
return response.Error(500, "Failed to update user password", err)
|
||||
}
|
||||
|
||||
@ -108,7 +107,7 @@ func AdminDeleteUser(c *models.ReqContext) response.Response {
|
||||
|
||||
cmd := models.DeleteUserCommand{UserId: userID}
|
||||
|
||||
if err := bus.Dispatch(&cmd); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(404, models.ErrUserNotFound.Error(), nil)
|
||||
}
|
||||
@ -124,12 +123,12 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response {
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
|
||||
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not disable external user", nil)
|
||||
}
|
||||
|
||||
disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: true}
|
||||
if err := bus.Dispatch(&disableCmd); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &disableCmd); err != nil {
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(404, models.ErrUserNotFound.Error(), nil)
|
||||
}
|
||||
@ -150,12 +149,12 @@ func AdminEnableUser(c *models.ReqContext) response.Response {
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
|
||||
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not enable external user", nil)
|
||||
}
|
||||
|
||||
disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: false}
|
||||
if err := bus.Dispatch(&disableCmd); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &disableCmd); err != nil {
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(404, models.ErrUserNotFound.Error(), nil)
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ type LDAPServerDTO struct {
|
||||
}
|
||||
|
||||
// FetchOrgs fetches the organization(s) information by executing a single query to the database. Then, populating the DTO with the information retrieved.
|
||||
func (user *LDAPUserDTO) FetchOrgs() error {
|
||||
func (user *LDAPUserDTO) FetchOrgs(ctx context.Context) error {
|
||||
orgIds := []int64{}
|
||||
|
||||
for _, or := range user.OrgRoles {
|
||||
@ -73,7 +73,7 @@ func (user *LDAPUserDTO) FetchOrgs() error {
|
||||
q := &models.SearchOrgsQuery{}
|
||||
q.Ids = orgIds
|
||||
|
||||
if err := bus.Dispatch(q); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, q); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -196,7 +196,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
|
||||
}
|
||||
|
||||
// Since the user was not in the LDAP server. Let's disable it.
|
||||
err := login.DisableExternalUser(query.Result.Login)
|
||||
err := login.DisableExternalUser(c.Req.Context(), query.Result.Login)
|
||||
if err != nil {
|
||||
return response.Error(http.StatusInternalServerError, "Failed to disable the user", err)
|
||||
}
|
||||
@ -219,7 +219,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
|
||||
SignupAllowed: hs.Cfg.LDAPAllowSignup,
|
||||
}
|
||||
|
||||
err = bus.Dispatch(upsertCmd)
|
||||
err = bus.DispatchCtx(c.Req.Context(), upsertCmd)
|
||||
if err != nil {
|
||||
return response.Error(http.StatusInternalServerError, "Failed to update the user", err)
|
||||
}
|
||||
@ -302,13 +302,13 @@ func (hs *HTTPServer) GetUserFromLDAP(c *models.ReqContext) response.Response {
|
||||
u.OrgRoles = orgRoles
|
||||
|
||||
ldapLogger.Debug("mapping org roles", "orgsRoles", u.OrgRoles)
|
||||
err = u.FetchOrgs()
|
||||
err = u.FetchOrgs(c.Req.Context())
|
||||
if err != nil {
|
||||
return response.Error(http.StatusBadRequest, "An organization was not found - Please verify your LDAP configuration", err)
|
||||
}
|
||||
|
||||
cmd := &models.GetTeamsForLDAPGroupCommand{Groups: user.Groups}
|
||||
err = bus.Dispatch(cmd)
|
||||
err = bus.DispatchCtx(c.Req.Context(), cmd)
|
||||
if err != nil && !errors.Is(err, bus.ErrHandlerNotFound) {
|
||||
return response.Error(http.StatusBadRequest, "Unable to find the teams for this user", err)
|
||||
}
|
||||
|
@ -311,7 +311,7 @@ func syncUser(
|
||||
ExternalUser: extUser,
|
||||
SignupAllowed: connect.IsSignupAllowed(),
|
||||
}
|
||||
if err := bus.Dispatch(cmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx.Req.Context(), cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bo
|
||||
if err != nil {
|
||||
if errors.Is(err, ldap.ErrCouldNotFindUser) {
|
||||
// Ignore the error since user might not be present anyway
|
||||
if err := DisableExternalUser(query.Username); err != nil {
|
||||
if err := DisableExternalUser(ctx, query.Username); err != nil {
|
||||
ldapLogger.Debug("Failed to disable external user", "err", err)
|
||||
}
|
||||
|
||||
@ -68,13 +68,13 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bo
|
||||
}
|
||||
|
||||
// DisableExternalUser marks external user as disabled in Grafana db
|
||||
func DisableExternalUser(username string) error {
|
||||
func DisableExternalUser(ctx context.Context, username string) error {
|
||||
// Check if external user exist in Grafana
|
||||
userQuery := &models.GetExternalUserInfoByLoginQuery{
|
||||
LoginOrEmail: username,
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(userQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, userQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -92,7 +92,7 @@ func DisableExternalUser(username string) error {
|
||||
IsDisabled: true,
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(disableUserCmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, disableUserCmd); err != nil {
|
||||
ldapLogger.Debug(
|
||||
"Error disabling external user",
|
||||
"user",
|
||||
|
@ -248,7 +248,7 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
|
||||
SignupAllowed: auth.cfg.LDAPAllowSignup,
|
||||
ExternalUser: extUser,
|
||||
}
|
||||
if err := bus.Dispatch(upsert); err != nil {
|
||||
if err := bus.DispatchCtx(auth.ctx.Req.Context(), upsert); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -305,7 +305,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
||||
ExternalUser: extUser,
|
||||
}
|
||||
|
||||
err := bus.Dispatch(upsert)
|
||||
err := bus.DispatchCtx(auth.ctx.Req.Context(), upsert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -1,7 +1,11 @@
|
||||
package login
|
||||
|
||||
import "github.com/grafana/grafana/pkg/models"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
)
|
||||
|
||||
type AuthInfoService interface {
|
||||
LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error)
|
||||
LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error)
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||
func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
|
||||
userAuth := &models.UserAuth{
|
||||
UserId: query.UserId,
|
||||
AuthModule: query.AuthModule,
|
||||
@ -79,7 +79,7 @@ func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||
func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
|
||||
authUser := &models.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
@ -113,7 +113,7 @@ func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
||||
func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
|
||||
authUser := &models.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
@ -153,7 +153,7 @@ func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
|
||||
func (s *Implementation) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
|
||||
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
_, err := sess.Delete(cmd.UserAuth)
|
||||
return err
|
||||
|
@ -33,10 +33,10 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio
|
||||
}
|
||||
|
||||
s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin)
|
||||
s.Bus.AddHandler(s.GetAuthInfo)
|
||||
s.Bus.AddHandler(s.SetAuthInfo)
|
||||
s.Bus.AddHandler(s.UpdateAuthInfo)
|
||||
s.Bus.AddHandler(s.DeleteAuthInfo)
|
||||
s.Bus.AddHandlerCtx(s.GetAuthInfo)
|
||||
s.Bus.AddHandlerCtx(s.SetAuthInfo)
|
||||
s.Bus.AddHandlerCtx(s.UpdateAuthInfo)
|
||||
s.Bus.AddHandlerCtx(s.DeleteAuthInfo)
|
||||
|
||||
return s
|
||||
}
|
||||
@ -70,7 +70,7 @@ func (s *Implementation) getUser(user *models.User) (bool, error) {
|
||||
return has, err
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) {
|
||||
func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) {
|
||||
authQuery := &models.GetAuthInfoQuery{}
|
||||
|
||||
// Try to find the user by auth module and id first
|
||||
@ -78,7 +78,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
authQuery.AuthId = query.AuthId
|
||||
|
||||
err := s.GetAuthInfo(authQuery)
|
||||
err := s.GetAuthInfo(ctx, authQuery)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
@ -86,7 +86,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
|
||||
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
|
||||
err := s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
err := s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
@ -102,7 +102,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
|
||||
|
||||
if !has {
|
||||
// if the user has been deleted then remove the entry
|
||||
err = s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
err = s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
@ -158,13 +158,13 @@ func (s *Implementation) LookupByOneOf(userId int64, email string, login string)
|
||||
return foundUser, user, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GenericOAuthLookup(authModule string, authId string, userID int64) (*models.UserAuth, error) {
|
||||
func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) {
|
||||
if authModule == genericOAuthModule && userID != 0 {
|
||||
authQuery := &models.GetAuthInfoQuery{}
|
||||
authQuery.AuthModule = authModule
|
||||
authQuery.AuthId = authId
|
||||
authQuery.UserId = userID
|
||||
err := s.GetAuthInfo(authQuery)
|
||||
err := s.GetAuthInfo(ctx, authQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -174,10 +174,10 @@ func (s *Implementation) GenericOAuthLookup(authModule string, authId string, us
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
// 1. LookupAndFix = auth info, user, error
|
||||
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
|
||||
foundUser, user, authInfo, err := s.LookupAndFix(query)
|
||||
foundUser, user, authInfo, err := s.LookupAndFix(ctx, query)
|
||||
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
@ -195,7 +195,7 @@ func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (
|
||||
}
|
||||
|
||||
// Special case for generic oauth duplicates
|
||||
ai, err := s.GenericOAuthLookup(query.AuthModule, query.AuthId, user.Id)
|
||||
ai, err := s.GenericOAuthLookup(ctx, query.AuthModule, query.AuthId, user.Id)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -211,7 +211,7 @@ func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
if err := s.SetAuthInfo(cmd); err != nil {
|
||||
if err := s.SetAuthInfo(ctx, cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func TestUserAuth(t *testing.T) {
|
||||
login := "loginuser0"
|
||||
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login}
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, login)
|
||||
@ -73,7 +73,7 @@ func TestUserAuth(t *testing.T) {
|
||||
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
|
||||
// get nonexistent user_auth entry
|
||||
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Equal(t, models.ErrUserNotFound, err)
|
||||
require.Nil(t, user)
|
||||
@ -82,14 +82,14 @@ func TestUserAuth(t *testing.T) {
|
||||
login := "loginuser0"
|
||||
|
||||
query.Login = login
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// get via user_auth
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, login)
|
||||
@ -98,14 +98,14 @@ func TestUserAuth(t *testing.T) {
|
||||
id := user.Id
|
||||
|
||||
query.UserId = id + 1
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, "loginuser1")
|
||||
|
||||
// get via user_auth
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, "loginuser1")
|
||||
@ -120,7 +120,7 @@ func TestUserAuth(t *testing.T) {
|
||||
|
||||
// get via user_auth for deleted user
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Equal(t, err, models.ErrUserNotFound)
|
||||
require.Nil(t, user)
|
||||
@ -139,7 +139,7 @@ func TestUserAuth(t *testing.T) {
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, login)
|
||||
@ -150,7 +150,7 @@ func TestUserAuth(t *testing.T) {
|
||||
AuthModule: query.AuthModule,
|
||||
OAuthToken: token,
|
||||
}
|
||||
err = srv.UpdateAuthInfo(cmd)
|
||||
err = srv.UpdateAuthInfo(context.Background(), cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
@ -158,7 +158,7 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
|
||||
@ -187,7 +187,7 @@ func TestUserAuth(t *testing.T) {
|
||||
// Make the first log-in during the past
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
@ -197,7 +197,7 @@ func TestUserAuth(t *testing.T) {
|
||||
// Have this module's last log-in be more recent
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
@ -208,14 +208,14 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: user.Id, AuthModule: "test1", AuthId: "test1"}
|
||||
err = srv.UpdateAuthInfo(updateAuthCmd)
|
||||
err = srv.UpdateAuthInfo(context.Background(), updateAuthCmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
@ -224,7 +224,7 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
|
||||
@ -237,7 +237,7 @@ func TestUserAuth(t *testing.T) {
|
||||
// Expect to pass since there's a matching login user
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
@ -246,7 +246,7 @@ func TestUserAuth(t *testing.T) {
|
||||
// Should throw a "user not found" error since there's no matching login user
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
getTime = time.Now
|
||||
|
||||
require.NotNil(t, err)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
@ -16,6 +17,6 @@ type TeamSyncFunc func(user *models.User, externalUser *models.ExternalUserInfo)
|
||||
|
||||
type Service interface {
|
||||
CreateUser(cmd models.CreateUserCommand) (*models.User, error)
|
||||
UpsertUser(cmd *models.UpsertUserCommand) error
|
||||
UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error
|
||||
SetTeamSyncFunc(TeamSyncFunc)
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func ProvideService(sqlStore *sqlstore.SQLStore, bus bus.Bus, quotaService *quot
|
||||
QuotaService: quotaService,
|
||||
AuthInfoService: authInfoService,
|
||||
}
|
||||
bus.AddHandler(s.UpsertUser)
|
||||
bus.AddHandlerCtx(s.UpsertUser)
|
||||
return s
|
||||
}
|
||||
|
||||
@ -41,10 +41,10 @@ func (ls *Implementation) CreateUser(cmd models.CreateUserCommand) (*models.User
|
||||
}
|
||||
|
||||
// UpsertUser updates an existing user, or if it doesn't exist, inserts a new one.
|
||||
func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
extUser := cmd.ExternalUser
|
||||
|
||||
user, err := ls.AuthInfoService.LookupAndUpdate(&models.GetUserByAuthInfoQuery{
|
||||
user, err := ls.AuthInfoService.LookupAndUpdate(ctx, &models.GetUserByAuthInfoQuery{
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
UserId: extUser.UserId,
|
||||
@ -81,21 +81,21 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
AuthId: extUser.AuthId,
|
||||
OAuthToken: extUser.OAuthToken,
|
||||
}
|
||||
if err := ls.Bus.Dispatch(cmd2); err != nil {
|
||||
if err := ls.Bus.DispatchCtx(ctx, cmd2); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cmd.Result = user
|
||||
|
||||
err = updateUser(cmd.Result, extUser)
|
||||
err = updateUser(ctx, cmd.Result, extUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Always persist the latest token at log-in
|
||||
if extUser.AuthModule != "" && extUser.OAuthToken != nil {
|
||||
err = updateUserAuth(cmd.Result, extUser)
|
||||
err = updateUserAuth(ctx, cmd.Result, extUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -103,13 +103,13 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
|
||||
if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled {
|
||||
// Re-enable user when it found in LDAP
|
||||
if err := ls.Bus.Dispatch(&models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
|
||||
if err := ls.Bus.DispatchCtx(ctx, &models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := syncOrgRoles(cmd.Result, extUser); err != nil {
|
||||
if err := syncOrgRoles(ctx, cmd.Result, extUser); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -146,7 +146,7 @@ func (ls *Implementation) createUser(extUser *models.ExternalUserInfo) (*models.
|
||||
return ls.CreateUser(cmd)
|
||||
}
|
||||
|
||||
func updateUser(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
func updateUser(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
// sync user info
|
||||
updateCmd := &models.UpdateUserCommand{
|
||||
UserId: user.Id,
|
||||
@ -176,10 +176,10 @@ func updateUser(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
}
|
||||
|
||||
logger.Debug("Syncing user info", "id", user.Id, "update", updateCmd)
|
||||
return bus.Dispatch(updateCmd)
|
||||
return bus.DispatchCtx(ctx, updateCmd)
|
||||
}
|
||||
|
||||
func updateUserAuth(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
func updateUserAuth(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
updateCmd := &models.UpdateAuthInfoCommand{
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
@ -188,10 +188,10 @@ func updateUserAuth(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
}
|
||||
|
||||
logger.Debug("Updating user_auth info", "user_id", user.Id)
|
||||
return bus.Dispatch(updateCmd)
|
||||
return bus.DispatchCtx(ctx, updateCmd)
|
||||
}
|
||||
|
||||
func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
logger.Debug("Syncing organization roles", "id", user.Id, "extOrgRoles", extUser.OrgRoles)
|
||||
|
||||
// don't sync org roles if none is specified
|
||||
@ -201,7 +201,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
}
|
||||
|
||||
orgsQuery := &models.GetUserOrgListQuery{UserId: user.Id}
|
||||
if err := bus.Dispatch(orgsQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, orgsQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -218,7 +218,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
} else if extRole != org.Role {
|
||||
// update role
|
||||
cmd := &models.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extRole}
|
||||
if err := bus.Dispatch(cmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -232,7 +232,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
|
||||
// add role
|
||||
cmd := &models.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId}
|
||||
err := bus.Dispatch(cmd)
|
||||
err := bus.DispatchCtx(ctx, cmd)
|
||||
if err != nil && !errors.Is(err, models.ErrOrgNotFound) {
|
||||
return err
|
||||
}
|
||||
@ -243,7 +243,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
logger.Debug("Removing user's organization membership as part of syncing with OAuth login",
|
||||
"userId", user.Id, "orgId", orgId)
|
||||
cmd := &models.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id}
|
||||
if err := bus.Dispatch(cmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, cmd); err != nil {
|
||||
if errors.Is(err, models.ErrLastOrgAdmin) {
|
||||
logger.Error(err.Error(), "userId", cmd.UserId, "orgId", cmd.OrgId)
|
||||
continue
|
||||
@ -260,7 +260,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
|
||||
break
|
||||
}
|
||||
|
||||
return bus.Dispatch(&models.SetUsingOrgCommand{
|
||||
return bus.DispatchCtx(ctx, &models.SetUsingOrgCommand{
|
||||
UserId: user.Id,
|
||||
OrgId: user.OrgId,
|
||||
})
|
||||
|
@ -1,6 +1,7 @@
|
||||
package loginservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@ -36,7 +37,7 @@ func Test_syncOrgRoles_doesNotBreakWhenTryingToRemoveLastOrgAdmin(t *testing.T)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := syncOrgRoles(&user, &externalUser)
|
||||
err := syncOrgRoles(context.Background(), &user, &externalUser)
|
||||
require.Empty(t, remResp)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@ -71,7 +72,7 @@ func Test_syncOrgRoles_whenTryingToRemoveLastOrgLogsError(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
err := syncOrgRoles(&user, &externalUser)
|
||||
err := syncOrgRoles(context.Background(), &user, &externalUser)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, logs, models.ErrLastOrgAdmin.Error())
|
||||
}
|
||||
@ -81,7 +82,7 @@ type authInfoServiceMock struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (a *authInfoServiceMock) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
func (a *authInfoServiceMock) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
return a.user, a.err
|
||||
}
|
||||
|
||||
@ -109,7 +110,7 @@ func Test_teamSync(t *testing.T) {
|
||||
var actualExternalUser *models.ExternalUserInfo
|
||||
|
||||
t.Run("login.TeamSync should not be called when nil", func(t *testing.T) {
|
||||
err := login.UpsertUser(upserCmd)
|
||||
err := login.UpsertUser(context.Background(), upserCmd)
|
||||
require.Nil(t, err)
|
||||
assert.Nil(t, actualUser)
|
||||
assert.Nil(t, actualExternalUser)
|
||||
@ -121,7 +122,7 @@ func Test_teamSync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
login.TeamSync = teamSyncFunc
|
||||
err := login.UpsertUser(upserCmd)
|
||||
err := login.UpsertUser(context.Background(), upserCmd)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, actualUser, expectedUser)
|
||||
assert.Equal(t, actualExternalUser, upserCmd.ExternalUser)
|
||||
@ -132,7 +133,7 @@ func Test_teamSync(t *testing.T) {
|
||||
return errors.New("teamsync test error")
|
||||
}
|
||||
login.TeamSync = teamSyncFunc
|
||||
err := login.UpsertUser(upserCmd)
|
||||
err := login.UpsertUser(context.Background(), upserCmd)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
@ -83,7 +83,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
|
||||
AuthId: authInfoQuery.Result.AuthId,
|
||||
OAuthToken: token,
|
||||
}
|
||||
if err := bus.Dispatch(updateAuthCommand); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, updateAuthCommand); err != nil {
|
||||
logger.Error("failed to update auth info during token refresh", "userId", user.UserId, "username", user.Login, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user