Chore: Add context to authinfo (#42096)

* Add context to authinfo

* Replace Dispatch with DispatchCtx
This commit is contained in:
idafurjes
2021-11-25 14:22:40 +01:00
committed by GitHub
parent d0c9564e1a
commit ac6867c3bb
13 changed files with 88 additions and 83 deletions

View File

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -80,7 +79,7 @@ func AdminUpdateUserPassword(c *models.ReqContext, form dtos.AdminUpdateUserPass
NewPassword: passwordHashed, NewPassword: passwordHashed,
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to update user password", err) return response.Error(500, "Failed to update user password", err)
} }
@ -108,7 +107,7 @@ func AdminDeleteUser(c *models.ReqContext) response.Response {
cmd := models.DeleteUserCommand{UserId: userID} cmd := models.DeleteUserCommand{UserId: userID}
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }
@ -124,12 +123,12 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API // External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { if err := bus.DispatchCtx(c.Req.Context(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not disable external user", nil) return response.Error(500, "Could not disable external user", nil)
} }
disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: true} disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: true}
if err := bus.Dispatch(&disableCmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &disableCmd); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }
@ -150,12 +149,12 @@ func AdminEnableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API // External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { if err := bus.DispatchCtx(c.Req.Context(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not enable external user", nil) return response.Error(500, "Could not enable external user", nil)
} }
disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: false} disableCmd := models.DisableUserCommand{UserId: userID, IsDisabled: false}
if err := bus.Dispatch(&disableCmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &disableCmd); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }

View File

@ -63,7 +63,7 @@ type LDAPServerDTO struct {
} }
// FetchOrgs fetches the organization(s) information by executing a single query to the database. Then, populating the DTO with the information retrieved. // FetchOrgs fetches the organization(s) information by executing a single query to the database. Then, populating the DTO with the information retrieved.
func (user *LDAPUserDTO) FetchOrgs() error { func (user *LDAPUserDTO) FetchOrgs(ctx context.Context) error {
orgIds := []int64{} orgIds := []int64{}
for _, or := range user.OrgRoles { for _, or := range user.OrgRoles {
@ -73,7 +73,7 @@ func (user *LDAPUserDTO) FetchOrgs() error {
q := &models.SearchOrgsQuery{} q := &models.SearchOrgsQuery{}
q.Ids = orgIds q.Ids = orgIds
if err := bus.Dispatch(q); err != nil { if err := bus.DispatchCtx(ctx, q); err != nil {
return err return err
} }
@ -196,7 +196,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
} }
// Since the user was not in the LDAP server. Let's disable it. // Since the user was not in the LDAP server. Let's disable it.
err := login.DisableExternalUser(query.Result.Login) err := login.DisableExternalUser(c.Req.Context(), query.Result.Login)
if err != nil { if err != nil {
return response.Error(http.StatusInternalServerError, "Failed to disable the user", err) return response.Error(http.StatusInternalServerError, "Failed to disable the user", err)
} }
@ -219,7 +219,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
SignupAllowed: hs.Cfg.LDAPAllowSignup, SignupAllowed: hs.Cfg.LDAPAllowSignup,
} }
err = bus.Dispatch(upsertCmd) err = bus.DispatchCtx(c.Req.Context(), upsertCmd)
if err != nil { if err != nil {
return response.Error(http.StatusInternalServerError, "Failed to update the user", err) return response.Error(http.StatusInternalServerError, "Failed to update the user", err)
} }
@ -302,13 +302,13 @@ func (hs *HTTPServer) GetUserFromLDAP(c *models.ReqContext) response.Response {
u.OrgRoles = orgRoles u.OrgRoles = orgRoles
ldapLogger.Debug("mapping org roles", "orgsRoles", u.OrgRoles) ldapLogger.Debug("mapping org roles", "orgsRoles", u.OrgRoles)
err = u.FetchOrgs() err = u.FetchOrgs(c.Req.Context())
if err != nil { if err != nil {
return response.Error(http.StatusBadRequest, "An organization was not found - Please verify your LDAP configuration", err) return response.Error(http.StatusBadRequest, "An organization was not found - Please verify your LDAP configuration", err)
} }
cmd := &models.GetTeamsForLDAPGroupCommand{Groups: user.Groups} cmd := &models.GetTeamsForLDAPGroupCommand{Groups: user.Groups}
err = bus.Dispatch(cmd) err = bus.DispatchCtx(c.Req.Context(), cmd)
if err != nil && !errors.Is(err, bus.ErrHandlerNotFound) { if err != nil && !errors.Is(err, bus.ErrHandlerNotFound) {
return response.Error(http.StatusBadRequest, "Unable to find the teams for this user", err) return response.Error(http.StatusBadRequest, "Unable to find the teams for this user", err)
} }

View File

@ -311,7 +311,7 @@ func syncUser(
ExternalUser: extUser, ExternalUser: extUser,
SignupAllowed: connect.IsSignupAllowed(), SignupAllowed: connect.IsSignupAllowed(),
} }
if err := bus.Dispatch(cmd); err != nil { if err := bus.DispatchCtx(ctx.Req.Context(), cmd); err != nil {
return nil, err return nil, err
} }

View File

@ -43,7 +43,7 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bo
if err != nil { if err != nil {
if errors.Is(err, ldap.ErrCouldNotFindUser) { if errors.Is(err, ldap.ErrCouldNotFindUser) {
// Ignore the error since user might not be present anyway // Ignore the error since user might not be present anyway
if err := DisableExternalUser(query.Username); err != nil { if err := DisableExternalUser(ctx, query.Username); err != nil {
ldapLogger.Debug("Failed to disable external user", "err", err) ldapLogger.Debug("Failed to disable external user", "err", err)
} }
@ -68,13 +68,13 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bo
} }
// DisableExternalUser marks external user as disabled in Grafana db // DisableExternalUser marks external user as disabled in Grafana db
func DisableExternalUser(username string) error { func DisableExternalUser(ctx context.Context, username string) error {
// Check if external user exist in Grafana // Check if external user exist in Grafana
userQuery := &models.GetExternalUserInfoByLoginQuery{ userQuery := &models.GetExternalUserInfoByLoginQuery{
LoginOrEmail: username, LoginOrEmail: username,
} }
if err := bus.Dispatch(userQuery); err != nil { if err := bus.DispatchCtx(ctx, userQuery); err != nil {
return err return err
} }
@ -92,7 +92,7 @@ func DisableExternalUser(username string) error {
IsDisabled: true, IsDisabled: true,
} }
if err := bus.Dispatch(disableUserCmd); err != nil { if err := bus.DispatchCtx(ctx, disableUserCmd); err != nil {
ldapLogger.Debug( ldapLogger.Debug(
"Error disabling external user", "Error disabling external user",
"user", "user",

View File

@ -248,7 +248,7 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
SignupAllowed: auth.cfg.LDAPAllowSignup, SignupAllowed: auth.cfg.LDAPAllowSignup,
ExternalUser: extUser, ExternalUser: extUser,
} }
if err := bus.Dispatch(upsert); err != nil { if err := bus.DispatchCtx(auth.ctx.Req.Context(), upsert); err != nil {
return 0, err return 0, err
} }
@ -305,7 +305,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
ExternalUser: extUser, ExternalUser: extUser,
} }
err := bus.Dispatch(upsert) err := bus.DispatchCtx(auth.ctx.Req.Context(), upsert)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -1,7 +1,11 @@
package login package login
import "github.com/grafana/grafana/pkg/models" import (
"context"
"github.com/grafana/grafana/pkg/models"
)
type AuthInfoService interface { type AuthInfoService interface {
LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error)
} }

View File

@ -37,7 +37,7 @@ func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *
return nil return nil
} }
func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error { func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
userAuth := &models.UserAuth{ userAuth := &models.UserAuth{
UserId: query.UserId, UserId: query.UserId,
AuthModule: query.AuthModule, AuthModule: query.AuthModule,
@ -79,7 +79,7 @@ func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
return nil return nil
} }
func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error { func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
authUser := &models.UserAuth{ authUser := &models.UserAuth{
UserId: cmd.UserId, UserId: cmd.UserId,
AuthModule: cmd.AuthModule, AuthModule: cmd.AuthModule,
@ -113,7 +113,7 @@ func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
}) })
} }
func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error { func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
authUser := &models.UserAuth{ authUser := &models.UserAuth{
UserId: cmd.UserId, UserId: cmd.UserId,
AuthModule: cmd.AuthModule, AuthModule: cmd.AuthModule,
@ -153,7 +153,7 @@ func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error
}) })
} }
func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error { func (s *Implementation) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error { return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
_, err := sess.Delete(cmd.UserAuth) _, err := sess.Delete(cmd.UserAuth)
return err return err

View File

@ -33,10 +33,10 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio
} }
s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin) s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin)
s.Bus.AddHandler(s.GetAuthInfo) s.Bus.AddHandlerCtx(s.GetAuthInfo)
s.Bus.AddHandler(s.SetAuthInfo) s.Bus.AddHandlerCtx(s.SetAuthInfo)
s.Bus.AddHandler(s.UpdateAuthInfo) s.Bus.AddHandlerCtx(s.UpdateAuthInfo)
s.Bus.AddHandler(s.DeleteAuthInfo) s.Bus.AddHandlerCtx(s.DeleteAuthInfo)
return s return s
} }
@ -70,7 +70,7 @@ func (s *Implementation) getUser(user *models.User) (bool, error) {
return has, err return has, err
} }
func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) { func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) {
authQuery := &models.GetAuthInfoQuery{} authQuery := &models.GetAuthInfoQuery{}
// Try to find the user by auth module and id first // Try to find the user by auth module and id first
@ -78,7 +78,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
authQuery.AuthModule = query.AuthModule authQuery.AuthModule = query.AuthModule
authQuery.AuthId = query.AuthId authQuery.AuthId = query.AuthId
err := s.GetAuthInfo(authQuery) err := s.GetAuthInfo(ctx, authQuery)
if !errors.Is(err, models.ErrUserNotFound) { if !errors.Is(err, models.ErrUserNotFound) {
if err != nil { if err != nil {
return false, nil, nil, err return false, nil, nil, err
@ -86,7 +86,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
// if user id was specified and doesn't match the user_auth entry, remove it // if user id was specified and doesn't match the user_auth entry, remove it
if query.UserId != 0 && query.UserId != authQuery.Result.UserId { if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
err := s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{ err := s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result, UserAuth: authQuery.Result,
}) })
if err != nil { if err != nil {
@ -102,7 +102,7 @@ func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (boo
if !has { if !has {
// if the user has been deleted then remove the entry // if the user has been deleted then remove the entry
err = s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{ err = s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result, UserAuth: authQuery.Result,
}) })
if err != nil { if err != nil {
@ -158,13 +158,13 @@ func (s *Implementation) LookupByOneOf(userId int64, email string, login string)
return foundUser, user, nil return foundUser, user, nil
} }
func (s *Implementation) GenericOAuthLookup(authModule string, authId string, userID int64) (*models.UserAuth, error) { func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) {
if authModule == genericOAuthModule && userID != 0 { if authModule == genericOAuthModule && userID != 0 {
authQuery := &models.GetAuthInfoQuery{} authQuery := &models.GetAuthInfoQuery{}
authQuery.AuthModule = authModule authQuery.AuthModule = authModule
authQuery.AuthId = authId authQuery.AuthId = authId
authQuery.UserId = userID authQuery.UserId = userID
err := s.GetAuthInfo(authQuery) err := s.GetAuthInfo(ctx, authQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -174,10 +174,10 @@ func (s *Implementation) GenericOAuthLookup(authModule string, authId string, us
return nil, nil return nil, nil
} }
func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) { func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error) {
// 1. LookupAndFix = auth info, user, error // 1. LookupAndFix = auth info, user, error
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that // TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
foundUser, user, authInfo, err := s.LookupAndFix(query) foundUser, user, authInfo, err := s.LookupAndFix(ctx, query)
if err != nil && !errors.Is(err, models.ErrUserNotFound) { if err != nil && !errors.Is(err, models.ErrUserNotFound) {
return nil, err return nil, err
} }
@ -195,7 +195,7 @@ func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (
} }
// Special case for generic oauth duplicates // Special case for generic oauth duplicates
ai, err := s.GenericOAuthLookup(query.AuthModule, query.AuthId, user.Id) ai, err := s.GenericOAuthLookup(ctx, query.AuthModule, query.AuthId, user.Id)
if !errors.Is(err, models.ErrUserNotFound) { if !errors.Is(err, models.ErrUserNotFound) {
if err != nil { if err != nil {
return nil, err return nil, err
@ -211,7 +211,7 @@ func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (
AuthModule: query.AuthModule, AuthModule: query.AuthModule,
AuthId: query.AuthId, AuthId: query.AuthId,
} }
if err := s.SetAuthInfo(cmd); err != nil { if err := s.SetAuthInfo(ctx, cmd); err != nil {
return nil, err return nil, err
} }
} }

View File

@ -40,7 +40,7 @@ func TestUserAuth(t *testing.T) {
login := "loginuser0" login := "loginuser0"
query := &models.GetUserByAuthInfoQuery{Login: login} query := &models.GetUserByAuthInfoQuery{Login: login}
user, err := srv.LookupAndUpdate(query) user, err := srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
@ -73,7 +73,7 @@ func TestUserAuth(t *testing.T) {
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) { t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
// get nonexistent user_auth entry // get nonexistent user_auth entry
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
user, err := srv.LookupAndUpdate(query) user, err := srv.LookupAndUpdate(context.Background(), query)
require.Equal(t, models.ErrUserNotFound, err) require.Equal(t, models.ErrUserNotFound, err)
require.Nil(t, user) require.Nil(t, user)
@ -82,14 +82,14 @@ func TestUserAuth(t *testing.T) {
login := "loginuser0" login := "loginuser0"
query.Login = login query.Login = login
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
// get via user_auth // get via user_auth
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
@ -98,14 +98,14 @@ func TestUserAuth(t *testing.T) {
id := user.Id id := user.Id
query.UserId = id + 1 query.UserId = id + 1
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, "loginuser1") require.Equal(t, user.Login, "loginuser1")
// get via user_auth // get via user_auth
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, "loginuser1") require.Equal(t, user.Login, "loginuser1")
@ -120,7 +120,7 @@ func TestUserAuth(t *testing.T) {
// get via user_auth for deleted user // get via user_auth for deleted user
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
require.Equal(t, err, models.ErrUserNotFound) require.Equal(t, err, models.ErrUserNotFound)
require.Nil(t, user) require.Nil(t, user)
@ -139,7 +139,7 @@ func TestUserAuth(t *testing.T) {
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"} query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
user, err := srv.LookupAndUpdate(query) user, err := srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
@ -150,7 +150,7 @@ func TestUserAuth(t *testing.T) {
AuthModule: query.AuthModule, AuthModule: query.AuthModule,
OAuthToken: token, OAuthToken: token,
} }
err = srv.UpdateAuthInfo(cmd) err = srv.UpdateAuthInfo(context.Background(), cmd)
require.Nil(t, err) require.Nil(t, err)
@ -158,7 +158,7 @@ func TestUserAuth(t *testing.T) {
UserId: user.Id, UserId: user.Id,
} }
err = srv.GetAuthInfo(getAuthQuery) err = srv.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken) require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
@ -187,7 +187,7 @@ func TestUserAuth(t *testing.T) {
// Make the first log-in during the past // Make the first log-in during the past
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"} query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
user, err := srv.LookupAndUpdate(query) user, err := srv.LookupAndUpdate(context.Background(), query)
getTime = time.Now getTime = time.Now
require.Nil(t, err) require.Nil(t, err)
@ -197,7 +197,7 @@ func TestUserAuth(t *testing.T) {
// Have this module's last log-in be more recent // Have this module's last log-in be more recent
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) } getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"} query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
getTime = time.Now getTime = time.Now
require.Nil(t, err) require.Nil(t, err)
@ -208,14 +208,14 @@ func TestUserAuth(t *testing.T) {
UserId: user.Id, UserId: user.Id,
} }
err = srv.GetAuthInfo(getAuthQuery) err = srv.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test2") require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
// "log in" again with the first auth module // "log in" again with the first auth module
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: user.Id, AuthModule: "test1", AuthId: "test1"} updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: user.Id, AuthModule: "test1", AuthId: "test1"}
err = srv.UpdateAuthInfo(updateAuthCmd) err = srv.UpdateAuthInfo(context.Background(), updateAuthCmd)
require.Nil(t, err) require.Nil(t, err)
@ -224,7 +224,7 @@ func TestUserAuth(t *testing.T) {
UserId: user.Id, UserId: user.Id,
} }
err = srv.GetAuthInfo(getAuthQuery) err = srv.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test1") require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
@ -237,7 +237,7 @@ func TestUserAuth(t *testing.T) {
// Expect to pass since there's a matching login user // Expect to pass since there's a matching login user
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""} query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
user, err := srv.LookupAndUpdate(query) user, err := srv.LookupAndUpdate(context.Background(), query)
getTime = time.Now getTime = time.Now
require.Nil(t, err) require.Nil(t, err)
@ -246,7 +246,7 @@ func TestUserAuth(t *testing.T) {
// Should throw a "user not found" error since there's no matching login user // Should throw a "user not found" error since there's no matching login user
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""} query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
user, err = srv.LookupAndUpdate(query) user, err = srv.LookupAndUpdate(context.Background(), query)
getTime = time.Now getTime = time.Now
require.NotNil(t, err) require.NotNil(t, err)

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
@ -16,6 +17,6 @@ type TeamSyncFunc func(user *models.User, externalUser *models.ExternalUserInfo)
type Service interface { type Service interface {
CreateUser(cmd models.CreateUserCommand) (*models.User, error) CreateUser(cmd models.CreateUserCommand) (*models.User, error)
UpsertUser(cmd *models.UpsertUserCommand) error UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error
SetTeamSyncFunc(TeamSyncFunc) SetTeamSyncFunc(TeamSyncFunc)
} }

View File

@ -23,7 +23,7 @@ func ProvideService(sqlStore *sqlstore.SQLStore, bus bus.Bus, quotaService *quot
QuotaService: quotaService, QuotaService: quotaService,
AuthInfoService: authInfoService, AuthInfoService: authInfoService,
} }
bus.AddHandler(s.UpsertUser) bus.AddHandlerCtx(s.UpsertUser)
return s return s
} }
@ -41,10 +41,10 @@ func (ls *Implementation) CreateUser(cmd models.CreateUserCommand) (*models.User
} }
// UpsertUser updates an existing user, or if it doesn't exist, inserts a new one. // UpsertUser updates an existing user, or if it doesn't exist, inserts a new one.
func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error { func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
extUser := cmd.ExternalUser extUser := cmd.ExternalUser
user, err := ls.AuthInfoService.LookupAndUpdate(&models.GetUserByAuthInfoQuery{ user, err := ls.AuthInfoService.LookupAndUpdate(ctx, &models.GetUserByAuthInfoQuery{
AuthModule: extUser.AuthModule, AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId, AuthId: extUser.AuthId,
UserId: extUser.UserId, UserId: extUser.UserId,
@ -81,21 +81,21 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
AuthId: extUser.AuthId, AuthId: extUser.AuthId,
OAuthToken: extUser.OAuthToken, OAuthToken: extUser.OAuthToken,
} }
if err := ls.Bus.Dispatch(cmd2); err != nil { if err := ls.Bus.DispatchCtx(ctx, cmd2); err != nil {
return err return err
} }
} }
} else { } else {
cmd.Result = user cmd.Result = user
err = updateUser(cmd.Result, extUser) err = updateUser(ctx, cmd.Result, extUser)
if err != nil { if err != nil {
return err return err
} }
// Always persist the latest token at log-in // Always persist the latest token at log-in
if extUser.AuthModule != "" && extUser.OAuthToken != nil { if extUser.AuthModule != "" && extUser.OAuthToken != nil {
err = updateUserAuth(cmd.Result, extUser) err = updateUserAuth(ctx, cmd.Result, extUser)
if err != nil { if err != nil {
return err return err
} }
@ -103,13 +103,13 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled { if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled {
// Re-enable user when it found in LDAP // Re-enable user when it found in LDAP
if err := ls.Bus.Dispatch(&models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil { if err := ls.Bus.DispatchCtx(ctx, &models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
return err return err
} }
} }
} }
if err := syncOrgRoles(cmd.Result, extUser); err != nil { if err := syncOrgRoles(ctx, cmd.Result, extUser); err != nil {
return err return err
} }
@ -146,7 +146,7 @@ func (ls *Implementation) createUser(extUser *models.ExternalUserInfo) (*models.
return ls.CreateUser(cmd) return ls.CreateUser(cmd)
} }
func updateUser(user *models.User, extUser *models.ExternalUserInfo) error { func updateUser(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
// sync user info // sync user info
updateCmd := &models.UpdateUserCommand{ updateCmd := &models.UpdateUserCommand{
UserId: user.Id, UserId: user.Id,
@ -176,10 +176,10 @@ func updateUser(user *models.User, extUser *models.ExternalUserInfo) error {
} }
logger.Debug("Syncing user info", "id", user.Id, "update", updateCmd) logger.Debug("Syncing user info", "id", user.Id, "update", updateCmd)
return bus.Dispatch(updateCmd) return bus.DispatchCtx(ctx, updateCmd)
} }
func updateUserAuth(user *models.User, extUser *models.ExternalUserInfo) error { func updateUserAuth(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
updateCmd := &models.UpdateAuthInfoCommand{ updateCmd := &models.UpdateAuthInfoCommand{
AuthModule: extUser.AuthModule, AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId, AuthId: extUser.AuthId,
@ -188,10 +188,10 @@ func updateUserAuth(user *models.User, extUser *models.ExternalUserInfo) error {
} }
logger.Debug("Updating user_auth info", "user_id", user.Id) logger.Debug("Updating user_auth info", "user_id", user.Id)
return bus.Dispatch(updateCmd) return bus.DispatchCtx(ctx, updateCmd)
} }
func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error { func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error {
logger.Debug("Syncing organization roles", "id", user.Id, "extOrgRoles", extUser.OrgRoles) logger.Debug("Syncing organization roles", "id", user.Id, "extOrgRoles", extUser.OrgRoles)
// don't sync org roles if none is specified // don't sync org roles if none is specified
@ -201,7 +201,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
} }
orgsQuery := &models.GetUserOrgListQuery{UserId: user.Id} orgsQuery := &models.GetUserOrgListQuery{UserId: user.Id}
if err := bus.Dispatch(orgsQuery); err != nil { if err := bus.DispatchCtx(ctx, orgsQuery); err != nil {
return err return err
} }
@ -218,7 +218,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
} else if extRole != org.Role { } else if extRole != org.Role {
// update role // update role
cmd := &models.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extRole} cmd := &models.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extRole}
if err := bus.Dispatch(cmd); err != nil { if err := bus.DispatchCtx(ctx, cmd); err != nil {
return err return err
} }
} }
@ -232,7 +232,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
// add role // add role
cmd := &models.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId} cmd := &models.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId}
err := bus.Dispatch(cmd) err := bus.DispatchCtx(ctx, cmd)
if err != nil && !errors.Is(err, models.ErrOrgNotFound) { if err != nil && !errors.Is(err, models.ErrOrgNotFound) {
return err return err
} }
@ -243,7 +243,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
logger.Debug("Removing user's organization membership as part of syncing with OAuth login", logger.Debug("Removing user's organization membership as part of syncing with OAuth login",
"userId", user.Id, "orgId", orgId) "userId", user.Id, "orgId", orgId)
cmd := &models.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id} cmd := &models.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id}
if err := bus.Dispatch(cmd); err != nil { if err := bus.DispatchCtx(ctx, cmd); err != nil {
if errors.Is(err, models.ErrLastOrgAdmin) { if errors.Is(err, models.ErrLastOrgAdmin) {
logger.Error(err.Error(), "userId", cmd.UserId, "orgId", cmd.OrgId) logger.Error(err.Error(), "userId", cmd.UserId, "orgId", cmd.OrgId)
continue continue
@ -260,7 +260,7 @@ func syncOrgRoles(user *models.User, extUser *models.ExternalUserInfo) error {
break break
} }
return bus.Dispatch(&models.SetUsingOrgCommand{ return bus.DispatchCtx(ctx, &models.SetUsingOrgCommand{
UserId: user.Id, UserId: user.Id,
OrgId: user.OrgId, OrgId: user.OrgId,
}) })

View File

@ -1,6 +1,7 @@
package loginservice package loginservice
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -36,7 +37,7 @@ func Test_syncOrgRoles_doesNotBreakWhenTryingToRemoveLastOrgAdmin(t *testing.T)
return nil return nil
}) })
err := syncOrgRoles(&user, &externalUser) err := syncOrgRoles(context.Background(), &user, &externalUser)
require.Empty(t, remResp) require.Empty(t, remResp)
require.NoError(t, err) require.NoError(t, err)
} }
@ -71,7 +72,7 @@ func Test_syncOrgRoles_whenTryingToRemoveLastOrgLogsError(t *testing.T) {
return nil return nil
}) })
err := syncOrgRoles(&user, &externalUser) err := syncOrgRoles(context.Background(), &user, &externalUser)
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, logs, models.ErrLastOrgAdmin.Error()) assert.Contains(t, logs, models.ErrLastOrgAdmin.Error())
} }
@ -81,7 +82,7 @@ type authInfoServiceMock struct {
err error err error
} }
func (a *authInfoServiceMock) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) { func (a *authInfoServiceMock) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error) {
return a.user, a.err return a.user, a.err
} }
@ -109,7 +110,7 @@ func Test_teamSync(t *testing.T) {
var actualExternalUser *models.ExternalUserInfo var actualExternalUser *models.ExternalUserInfo
t.Run("login.TeamSync should not be called when nil", func(t *testing.T) { t.Run("login.TeamSync should not be called when nil", func(t *testing.T) {
err := login.UpsertUser(upserCmd) err := login.UpsertUser(context.Background(), upserCmd)
require.Nil(t, err) require.Nil(t, err)
assert.Nil(t, actualUser) assert.Nil(t, actualUser)
assert.Nil(t, actualExternalUser) assert.Nil(t, actualExternalUser)
@ -121,7 +122,7 @@ func Test_teamSync(t *testing.T) {
return nil return nil
} }
login.TeamSync = teamSyncFunc login.TeamSync = teamSyncFunc
err := login.UpsertUser(upserCmd) err := login.UpsertUser(context.Background(), upserCmd)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, actualUser, expectedUser) assert.Equal(t, actualUser, expectedUser)
assert.Equal(t, actualExternalUser, upserCmd.ExternalUser) assert.Equal(t, actualExternalUser, upserCmd.ExternalUser)
@ -132,7 +133,7 @@ func Test_teamSync(t *testing.T) {
return errors.New("teamsync test error") return errors.New("teamsync test error")
} }
login.TeamSync = teamSyncFunc login.TeamSync = teamSyncFunc
err := login.UpsertUser(upserCmd) err := login.UpsertUser(context.Background(), upserCmd)
require.Error(t, err) require.Error(t, err)
}) })
}) })

View File

@ -83,7 +83,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
AuthId: authInfoQuery.Result.AuthId, AuthId: authInfoQuery.Result.AuthId,
OAuthToken: token, OAuthToken: token,
} }
if err := bus.Dispatch(updateAuthCommand); err != nil { if err := bus.DispatchCtx(ctx, updateAuthCommand); err != nil {
logger.Error("failed to update auth info during token refresh", "userId", user.UserId, "username", user.Login, "error", err) logger.Error("failed to update auth info during token refresh", "userId", user.UserId, "username", user.Login, "error", err)
return nil return nil
} }