mirror of
https://github.com/grafana/grafana.git
synced 2025-09-27 13:03:53 +08:00
Auth: creates a hook in the user mapping flow (#37190)
* wip * Auth Info: refactored out into it's own service * Auth: adds extension point where users are being mapped * Update pkg/services/login/authinfoservice/service.go Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com> * Update pkg/services/login/authinfoservice/service.go Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com> * Auth: simplified code * moved most authinfo stuff to its own package * added back code * linter * simplified Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com>
This commit is contained in:
@ -10,6 +10,7 @@ var (
|
|||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrUserAlreadyExists = errors.New("user already exists")
|
ErrUserAlreadyExists = errors.New("user already exists")
|
||||||
ErrLastGrafanaAdmin = errors.New("cannot remove last grafana admin")
|
ErrLastGrafanaAdmin = errors.New("cannot remove last grafana admin")
|
||||||
|
ErrProtectedUser = errors.New("cannot adopt protected user")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Password string
|
type Password string
|
||||||
|
@ -98,8 +98,6 @@ type GetUserByAuthInfoQuery struct {
|
|||||||
UserId int64
|
UserId int64
|
||||||
Email string
|
Email string
|
||||||
Login string
|
Login string
|
||||||
|
|
||||||
Result *User
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetExternalUserInfoByLoginQuery struct {
|
type GetExternalUserInfoByLoginQuery struct {
|
||||||
|
@ -13,8 +13,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/api"
|
"github.com/grafana/grafana/pkg/api"
|
||||||
"github.com/grafana/grafana/pkg/api/routing"
|
"github.com/grafana/grafana/pkg/api/routing"
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
@ -37,6 +35,7 @@ import (
|
|||||||
_ "github.com/grafana/grafana/pkg/services/auth/jwt"
|
_ "github.com/grafana/grafana/pkg/services/auth/jwt"
|
||||||
_ "github.com/grafana/grafana/pkg/services/cleanup"
|
_ "github.com/grafana/grafana/pkg/services/cleanup"
|
||||||
_ "github.com/grafana/grafana/pkg/services/librarypanels"
|
_ "github.com/grafana/grafana/pkg/services/librarypanels"
|
||||||
|
_ "github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||||
_ "github.com/grafana/grafana/pkg/services/login/loginservice"
|
_ "github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
_ "github.com/grafana/grafana/pkg/services/ngalert"
|
_ "github.com/grafana/grafana/pkg/services/ngalert"
|
||||||
_ "github.com/grafana/grafana/pkg/services/notifications"
|
_ "github.com/grafana/grafana/pkg/services/notifications"
|
||||||
@ -45,6 +44,7 @@ import (
|
|||||||
_ "github.com/grafana/grafana/pkg/services/search"
|
_ "github.com/grafana/grafana/pkg/services/search"
|
||||||
_ "github.com/grafana/grafana/pkg/services/sqlstore"
|
_ "github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config contains parameters for the New function.
|
// Config contains parameters for the New function.
|
||||||
|
7
pkg/services/login/authinfo.go
Normal file
7
pkg/services/login/authinfo.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package login
|
||||||
|
|
||||||
|
import "github.com/grafana/grafana/pkg/models"
|
||||||
|
|
||||||
|
type AuthInfoService interface {
|
||||||
|
LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error)
|
||||||
|
}
|
@ -1,10 +1,12 @@
|
|||||||
package sqlstore
|
package authinfoservice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
@ -13,123 +15,7 @@ import (
|
|||||||
|
|
||||||
var getTime = time.Now
|
var getTime = time.Now
|
||||||
|
|
||||||
const genericOAuthModule = "oauth_generic_oauth"
|
func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
|
||||||
|
|
||||||
func init() {
|
|
||||||
bus.AddHandler("sql", GetUserByAuthInfo)
|
|
||||||
bus.AddHandler("sql", GetExternalUserInfoByLogin)
|
|
||||||
bus.AddHandler("sql", GetAuthInfo)
|
|
||||||
bus.AddHandler("sql", SetAuthInfo)
|
|
||||||
bus.AddHandler("sql", UpdateAuthInfo)
|
|
||||||
bus.AddHandler("sql", DeleteAuthInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserByAuthInfo(query *models.GetUserByAuthInfoQuery) error {
|
|
||||||
user := &models.User{}
|
|
||||||
has := false
|
|
||||||
var err error
|
|
||||||
authQuery := &models.GetAuthInfoQuery{}
|
|
||||||
|
|
||||||
// Try to find the user by auth module and id first
|
|
||||||
if query.AuthModule != "" && query.AuthId != "" {
|
|
||||||
authQuery.AuthModule = query.AuthModule
|
|
||||||
authQuery.AuthId = query.AuthId
|
|
||||||
|
|
||||||
err = GetAuthInfo(authQuery)
|
|
||||||
if !errors.Is(err, models.ErrUserNotFound) {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
|
||||||
if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
|
|
||||||
err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
|
||||||
UserAuth: authQuery.Result,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
sqlog.Error("Error removing user_auth entry", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
authQuery.Result = nil
|
|
||||||
} else {
|
|
||||||
has, err = x.Id(authQuery.Result.UserId).Get(user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !has {
|
|
||||||
// if the user has been deleted then remove the entry
|
|
||||||
err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
|
||||||
UserAuth: authQuery.Result,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
sqlog.Error("Error removing user_auth entry", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
authQuery.Result = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not found, try to find the user by id
|
|
||||||
if !has && query.UserId != 0 {
|
|
||||||
has, err = x.Id(query.UserId).Get(user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not found, try to find the user by email address
|
|
||||||
if !has && query.Email != "" {
|
|
||||||
user = &models.User{Email: query.Email}
|
|
||||||
has, err = x.Get(user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not found, try to find the user by login
|
|
||||||
if !has && query.Login != "" {
|
|
||||||
user = &models.User{Login: query.Login}
|
|
||||||
has, err = x.Get(user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No user found
|
|
||||||
if !has {
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Special case for generic oauth duplicates
|
|
||||||
if query.AuthModule == genericOAuthModule && user.Id != 0 {
|
|
||||||
authQuery.UserId = user.Id
|
|
||||||
authQuery.AuthModule = query.AuthModule
|
|
||||||
err = GetAuthInfo(authQuery)
|
|
||||||
if !errors.Is(err, models.ErrUserNotFound) {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if authQuery.Result == nil && query.AuthModule != "" {
|
|
||||||
cmd2 := &models.SetAuthInfoCommand{
|
|
||||||
UserId: user.Id,
|
|
||||||
AuthModule: query.AuthModule,
|
|
||||||
AuthId: query.AuthId,
|
|
||||||
}
|
|
||||||
if err := SetAuthInfo(cmd2); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
query.Result = user
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
|
|
||||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
||||||
err := bus.Dispatch(&userQuery)
|
err := bus.Dispatch(&userQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -153,13 +39,20 @@ func GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||||
userAuth := &models.UserAuth{
|
userAuth := &models.UserAuth{
|
||||||
UserId: query.UserId,
|
UserId: query.UserId,
|
||||||
AuthModule: query.AuthModule,
|
AuthModule: query.AuthModule,
|
||||||
AuthId: query.AuthId,
|
AuthId: query.AuthId,
|
||||||
}
|
}
|
||||||
has, err := x.Desc("created").Get(userAuth)
|
|
||||||
|
var has bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
|
has, err = sess.Desc("created").Get(userAuth)
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -188,8 +81,8 @@ func GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||||
return inTransaction(func(sess *DBSession) error {
|
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
authUser := &models.UserAuth{
|
authUser := &models.UserAuth{
|
||||||
UserId: cmd.UserId,
|
UserId: cmd.UserId,
|
||||||
AuthModule: cmd.AuthModule,
|
AuthModule: cmd.AuthModule,
|
||||||
@ -222,8 +115,8 @@ func SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
||||||
return inTransaction(func(sess *DBSession) error {
|
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
authUser := &models.UserAuth{
|
authUser := &models.UserAuth{
|
||||||
UserId: cmd.UserId,
|
UserId: cmd.UserId,
|
||||||
AuthModule: cmd.AuthModule,
|
AuthModule: cmd.AuthModule,
|
||||||
@ -256,13 +149,13 @@ func UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
|||||||
AuthModule: cmd.AuthModule,
|
AuthModule: cmd.AuthModule,
|
||||||
}
|
}
|
||||||
upd, err := sess.Update(authUser, cond)
|
upd, err := sess.Update(authUser, cond)
|
||||||
sqlog.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
|
func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
|
||||||
return inTransaction(func(sess *DBSession) error {
|
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
_, err := sess.Delete(cmd.UserAuth)
|
_, err := sess.Delete(cmd.UserAuth)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
223
pkg/services/login/authinfoservice/service.go
Normal file
223
pkg/services/login/authinfoservice/service.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
package authinfoservice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/registry"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const genericOAuthModule = "oauth_generic_oauth"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
srv := &Implementation{}
|
||||||
|
|
||||||
|
registry.Register(®istry.Descriptor{
|
||||||
|
Name: "UserAuthInfo",
|
||||||
|
Instance: srv,
|
||||||
|
InitPriority: registry.MediumHigh,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Implementation struct {
|
||||||
|
Bus bus.Bus `inject:""`
|
||||||
|
SQLStore *sqlstore.SQLStore `inject:""`
|
||||||
|
UserProtectionService login.UserProtectionService `inject:""`
|
||||||
|
|
||||||
|
logger log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) Init() error {
|
||||||
|
s.logger = log.New("login.authinfo")
|
||||||
|
|
||||||
|
s.Bus.AddHandler(s.GetExternalUserInfoByLogin)
|
||||||
|
s.Bus.AddHandler(s.GetAuthInfo)
|
||||||
|
s.Bus.AddHandler(s.SetAuthInfo)
|
||||||
|
s.Bus.AddHandler(s.UpdateAuthInfo)
|
||||||
|
s.Bus.AddHandler(s.DeleteAuthInfo)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) getUserById(id int64) (bool, *models.User, error) {
|
||||||
|
var (
|
||||||
|
has bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
user := &models.User{}
|
||||||
|
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
|
has, err = sess.ID(id).Get(user)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return has, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) getUser(user *models.User) (bool, error) {
|
||||||
|
var err error
|
||||||
|
var has bool
|
||||||
|
|
||||||
|
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
|
has, err = sess.Get(user)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return has, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) {
|
||||||
|
authQuery := &models.GetAuthInfoQuery{}
|
||||||
|
|
||||||
|
// Try to find the user by auth module and id first
|
||||||
|
if query.AuthModule != "" && query.AuthId != "" {
|
||||||
|
authQuery.AuthModule = query.AuthModule
|
||||||
|
authQuery.AuthId = query.AuthId
|
||||||
|
|
||||||
|
err := s.GetAuthInfo(authQuery)
|
||||||
|
if !errors.Is(err, models.ErrUserNotFound) {
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||||
|
if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
|
||||||
|
err := s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||||
|
UserAuth: authQuery.Result,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil, models.ErrUserNotFound
|
||||||
|
} else {
|
||||||
|
has, user, err := s.getUserById(authQuery.Result.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !has {
|
||||||
|
// if the user has been deleted then remove the entry
|
||||||
|
err = s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||||
|
UserAuth: authQuery.Result,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil, models.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, user, authQuery.Result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil, models.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) LookupByOneOf(userId int64, email string, login string) (bool, *models.User, error) {
|
||||||
|
foundUser := false
|
||||||
|
var user *models.User
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// If not found, try to find the user by id
|
||||||
|
if userId != 0 {
|
||||||
|
foundUser, user, err = s.getUserById(userId)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found, try to find the user by email address
|
||||||
|
if !foundUser && email != "" {
|
||||||
|
user = &models.User{Email: email}
|
||||||
|
foundUser, err = s.getUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found, try to find the user by login
|
||||||
|
if !foundUser && login != "" {
|
||||||
|
user = &models.User{Login: login}
|
||||||
|
foundUser, err = s.getUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundUser {
|
||||||
|
return false, nil, models.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return foundUser, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) GenericOAuthLookup(authModule string, authId string, userID int64) (*models.UserAuth, error) {
|
||||||
|
if authModule == genericOAuthModule && userID != 0 {
|
||||||
|
authQuery := &models.GetAuthInfoQuery{}
|
||||||
|
authQuery.AuthModule = authModule
|
||||||
|
authQuery.AuthId = authId
|
||||||
|
authQuery.UserId = userID
|
||||||
|
err := s.GetAuthInfo(authQuery)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return authQuery.Result, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||||
|
// 1. LookupAndFix = auth info, user, error
|
||||||
|
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
|
||||||
|
foundUser, user, authInfo, err := s.LookupAndFix(query)
|
||||||
|
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. FindByUserDetails
|
||||||
|
if !foundUser {
|
||||||
|
_, user, err = s.LookupByOneOf(query.UserId, query.Email, query.Login)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.UserProtectionService.AllowUserMapping(user, query.AuthModule); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special case for generic oauth duplicates
|
||||||
|
ai, err := s.GenericOAuthLookup(query.AuthModule, query.AuthId, user.Id)
|
||||||
|
if !errors.Is(err, models.ErrUserNotFound) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ai != nil {
|
||||||
|
authInfo = ai
|
||||||
|
}
|
||||||
|
|
||||||
|
if authInfo == nil && query.AuthModule != "" {
|
||||||
|
cmd := &models.SetAuthInfoCommand{
|
||||||
|
UserId: user.Id,
|
||||||
|
AuthModule: query.AuthModule,
|
||||||
|
AuthId: query.AuthId,
|
||||||
|
}
|
||||||
|
if err := s.SetAuthInfo(cmd); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
@ -1,21 +1,31 @@
|
|||||||
// +build integration
|
// +build integration
|
||||||
|
|
||||||
package sqlstore
|
package authinfoservice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:goconst
|
//nolint:goconst
|
||||||
func TestUserAuth(t *testing.T) {
|
func TestUserAuth(t *testing.T) {
|
||||||
sqlStore := InitTestDB(t)
|
sqlStore := sqlstore.InitTestDB(t)
|
||||||
|
srv := &Implementation{
|
||||||
|
Bus: bus.New(),
|
||||||
|
SQLStore: sqlStore,
|
||||||
|
UserProtectionService: OSSUserProtectionImpl{},
|
||||||
|
}
|
||||||
|
srv.Init()
|
||||||
|
|
||||||
t.Run("Given 5 users", func(t *testing.T) {
|
t.Run("Given 5 users", func(t *testing.T) {
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
@ -24,7 +34,7 @@ func TestUserAuth(t *testing.T) {
|
|||||||
Name: fmt.Sprint("user", i),
|
Name: fmt.Sprint("user", i),
|
||||||
Login: fmt.Sprint("loginuser", i),
|
Login: fmt.Sprint("loginuser", i),
|
||||||
}
|
}
|
||||||
_, err := sqlStore.CreateUser(context.Background(), cmd)
|
_, err := srv.SQLStore.CreateUser(context.Background(), cmd)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,89 +43,90 @@ func TestUserAuth(t *testing.T) {
|
|||||||
login := "loginuser0"
|
login := "loginuser0"
|
||||||
|
|
||||||
query := &models.GetUserByAuthInfoQuery{Login: login}
|
query := &models.GetUserByAuthInfoQuery{Login: login}
|
||||||
err := GetUserByAuthInfo(query)
|
user, err := srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// By ID
|
// By ID
|
||||||
id := query.Result.Id
|
id := user.Id
|
||||||
|
|
||||||
query = &models.GetUserByAuthInfoQuery{UserId: id}
|
_, user, err = srv.LookupByOneOf(id, "", "")
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Id, id)
|
require.Equal(t, user.Id, id)
|
||||||
|
|
||||||
// By Email
|
// By Email
|
||||||
email := "user1@test.com"
|
email := "user1@test.com"
|
||||||
|
|
||||||
query = &models.GetUserByAuthInfoQuery{Email: email}
|
_, user, err = srv.LookupByOneOf(0, email, "")
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Email, email)
|
require.Equal(t, user.Email, email)
|
||||||
|
|
||||||
// Don't find nonexistent user
|
// Don't find nonexistent user
|
||||||
email = "nonexistent@test.com"
|
email = "nonexistent@test.com"
|
||||||
|
|
||||||
query = &models.GetUserByAuthInfoQuery{Email: email}
|
_, user, err = srv.LookupByOneOf(0, email, "")
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
|
|
||||||
require.Equal(t, err, models.ErrUserNotFound)
|
require.Equal(t, models.ErrUserNotFound, err)
|
||||||
require.Nil(t, query.Result)
|
require.Nil(t, user)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
|
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
|
||||||
// get nonexistent user_auth entry
|
// get nonexistent user_auth entry
|
||||||
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||||
err := GetUserByAuthInfo(query)
|
user, err := srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Equal(t, err, models.ErrUserNotFound)
|
require.Equal(t, models.ErrUserNotFound, err)
|
||||||
require.Nil(t, query.Result)
|
require.Nil(t, user)
|
||||||
|
|
||||||
// create user_auth entry
|
// create user_auth entry
|
||||||
login := "loginuser0"
|
login := "loginuser0"
|
||||||
|
|
||||||
query.Login = login
|
query.Login = login
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// get via user_auth
|
// get via user_auth
|
||||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// get with non-matching id
|
// get with non-matching id
|
||||||
id := query.Result.Id
|
id := user.Id
|
||||||
|
|
||||||
query.UserId = id + 1
|
query.UserId = id + 1
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, "loginuser1")
|
require.Equal(t, user.Login, "loginuser1")
|
||||||
|
|
||||||
// get via user_auth
|
// get via user_auth
|
||||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, "loginuser1")
|
require.Equal(t, user.Login, "loginuser1")
|
||||||
|
|
||||||
// remove user
|
// remove user
|
||||||
_, err = x.Exec("DELETE FROM "+dialect.Quote("user")+" WHERE id=?", query.Result.Id)
|
srv.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||||
require.Nil(t, err)
|
sess.Exec("DELETE FROM "+srv.SQLStore.Dialect.Quote("user")+" WHERE id=?", user.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// get via user_auth for deleted user
|
// get via user_auth for deleted user
|
||||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Equal(t, err, models.ErrUserNotFound)
|
require.Equal(t, err, models.ErrUserNotFound)
|
||||||
require.Nil(t, query.Result)
|
require.Nil(t, user)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
|
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
|
||||||
@ -131,26 +142,26 @@ func TestUserAuth(t *testing.T) {
|
|||||||
|
|
||||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
||||||
err := GetUserByAuthInfo(query)
|
user, err := srv.LookupAndUpdate(query)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
cmd := &models.UpdateAuthInfoCommand{
|
cmd := &models.UpdateAuthInfoCommand{
|
||||||
UserId: query.Result.Id,
|
UserId: user.Id,
|
||||||
AuthId: query.AuthId,
|
AuthId: query.AuthId,
|
||||||
AuthModule: query.AuthModule,
|
AuthModule: query.AuthModule,
|
||||||
OAuthToken: token,
|
OAuthToken: token,
|
||||||
}
|
}
|
||||||
err = UpdateAuthInfo(cmd)
|
err = srv.UpdateAuthInfo(cmd)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
getAuthQuery := &models.GetAuthInfoQuery{
|
getAuthQuery := &models.GetAuthInfoQuery{
|
||||||
UserId: query.Result.Id,
|
UserId: user.Id,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = GetAuthInfo(getAuthQuery)
|
err = srv.GetAuthInfo(getAuthQuery)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
|
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
|
||||||
@ -160,7 +171,7 @@ func TestUserAuth(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
||||||
// Restore after destructive operation
|
// Restore after destructive operation
|
||||||
sqlStore = InitTestDB(t)
|
sqlStore = sqlstore.InitTestDB(t)
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
cmd := models.CreateUserCommand{
|
cmd := models.CreateUserCommand{
|
||||||
@ -175,48 +186,48 @@ func TestUserAuth(t *testing.T) {
|
|||||||
// Find a user to set tokens on
|
// Find a user to set tokens on
|
||||||
login := "loginuser0"
|
login := "loginuser0"
|
||||||
|
|
||||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
|
||||||
// Make the first log-in during the past
|
// Make the first log-in during the past
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
|
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
|
||||||
err := GetUserByAuthInfo(query)
|
user, err := srv.LookupAndUpdate(query)
|
||||||
getTime = time.Now
|
getTime = time.Now
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// Add a second auth module for this user
|
// Add a second auth module for this user
|
||||||
// Have this module's last log-in be more recent
|
// Have this module's last log-in be more recent
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
|
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
getTime = time.Now
|
getTime = time.Now
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// Get the latest entry by not supply an authmodule or authid
|
// Get the latest entry by not supply an authmodule or authid
|
||||||
getAuthQuery := &models.GetAuthInfoQuery{
|
getAuthQuery := &models.GetAuthInfoQuery{
|
||||||
UserId: query.Result.Id,
|
UserId: user.Id,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = GetAuthInfo(getAuthQuery)
|
err = srv.GetAuthInfo(getAuthQuery)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
|
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
|
||||||
|
|
||||||
// "log in" again with the first auth module
|
// "log in" again with the first auth module
|
||||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
|
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: user.Id, AuthModule: "test1", AuthId: "test1"}
|
||||||
err = UpdateAuthInfo(updateAuthCmd)
|
err = srv.UpdateAuthInfo(updateAuthCmd)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Get the latest entry by not supply an authmodule or authid
|
// Get the latest entry by not supply an authmodule or authid
|
||||||
getAuthQuery = &models.GetAuthInfoQuery{
|
getAuthQuery = &models.GetAuthInfoQuery{
|
||||||
UserId: query.Result.Id,
|
UserId: user.Id,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = GetAuthInfo(getAuthQuery)
|
err = srv.GetAuthInfo(getAuthQuery)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
|
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
|
||||||
@ -229,20 +240,20 @@ func TestUserAuth(t *testing.T) {
|
|||||||
// Expect to pass since there's a matching login user
|
// Expect to pass since there's a matching login user
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
|
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
|
||||||
err := GetUserByAuthInfo(query)
|
user, err := srv.LookupAndUpdate(query)
|
||||||
getTime = time.Now
|
getTime = time.Now
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, query.Result.Login, login)
|
require.Equal(t, user.Login, login)
|
||||||
|
|
||||||
// Should throw a "user not found" error since there's no matching login user
|
// Should throw a "user not found" error since there's no matching login user
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||||
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
|
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
|
||||||
err = GetUserByAuthInfo(query)
|
user, err = srv.LookupAndUpdate(query)
|
||||||
getTime = time.Now
|
getTime = time.Now
|
||||||
|
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
require.Nil(t, query.Result)
|
require.Nil(t, user)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
21
pkg/services/login/authinfoservice/userprotection.go
Normal file
21
pkg/services/login/authinfoservice/userprotection.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package authinfoservice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registry.RegisterService(&OSSUserProtectionImpl{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type OSSUserProtectionImpl struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (OSSUserProtectionImpl) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (OSSUserProtectionImpl) AllowUserMapping(_ *models.User, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
@ -22,10 +22,11 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Implementation struct {
|
type Implementation struct {
|
||||||
SQLStore *sqlstore.SQLStore `inject:""`
|
SQLStore *sqlstore.SQLStore `inject:""`
|
||||||
Bus bus.Bus `inject:""`
|
Bus bus.Bus `inject:""`
|
||||||
QuotaService *quota.QuotaService `inject:""`
|
AuthInfoService login.AuthInfoService `inject:""`
|
||||||
TeamSync login.TeamSyncFunc
|
QuotaService *quota.QuotaService `inject:""`
|
||||||
|
TeamSync login.TeamSyncFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ls *Implementation) Init() error {
|
func (ls *Implementation) Init() error {
|
||||||
@ -43,14 +44,14 @@ func (ls *Implementation) CreateUser(cmd models.CreateUserCommand) (*models.User
|
|||||||
func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||||
extUser := cmd.ExternalUser
|
extUser := cmd.ExternalUser
|
||||||
|
|
||||||
userQuery := &models.GetUserByAuthInfoQuery{
|
user, err := ls.AuthInfoService.LookupAndUpdate(&models.GetUserByAuthInfoQuery{
|
||||||
AuthModule: extUser.AuthModule,
|
AuthModule: extUser.AuthModule,
|
||||||
AuthId: extUser.AuthId,
|
AuthId: extUser.AuthId,
|
||||||
UserId: extUser.UserId,
|
UserId: extUser.UserId,
|
||||||
Email: extUser.Email,
|
Email: extUser.Email,
|
||||||
Login: extUser.Login,
|
Login: extUser.Login,
|
||||||
}
|
})
|
||||||
if err := bus.Dispatch(userQuery); err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, models.ErrUserNotFound) {
|
if !errors.Is(err, models.ErrUserNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -85,7 +86,7 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cmd.Result = userQuery.Result
|
cmd.Result = user
|
||||||
|
|
||||||
err = updateUser(cmd.Result, extUser)
|
err = updateUser(cmd.Result, extUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -100,7 +101,7 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if extUser.AuthModule == models.AuthModuleLDAP && userQuery.Result.IsDisabled {
|
if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled {
|
||||||
// Re-enable user when it found in LDAP
|
// Re-enable user when it found in LDAP
|
||||||
if err := ls.Bus.Dispatch(&models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
|
if err := ls.Bus.Dispatch(&models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -76,10 +76,22 @@ func Test_syncOrgRoles_whenTryingToRemoveLastOrgLogsError(t *testing.T) {
|
|||||||
assert.Contains(t, logs, models.ErrLastOrgAdmin.Error())
|
assert.Contains(t, logs, models.ErrLastOrgAdmin.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authInfoServiceMock struct {
|
||||||
|
user *models.User
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authInfoServiceMock) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||||
|
return a.user, a.err
|
||||||
|
}
|
||||||
|
|
||||||
func Test_teamSync(t *testing.T) {
|
func Test_teamSync(t *testing.T) {
|
||||||
|
b := bus.New()
|
||||||
|
authInfoMock := &authInfoServiceMock{}
|
||||||
login := Implementation{
|
login := Implementation{
|
||||||
Bus: bus.New(),
|
Bus: b,
|
||||||
QuotaService: "a.QuotaService{},
|
QuotaService: "a.QuotaService{},
|
||||||
|
AuthInfoService: authInfoMock,
|
||||||
}
|
}
|
||||||
|
|
||||||
upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: "test_user@example.org"}}
|
upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: "test_user@example.org"}}
|
||||||
@ -89,13 +101,9 @@ func Test_teamSync(t *testing.T) {
|
|||||||
Name: "test_user",
|
Name: "test_user",
|
||||||
Login: "test_user",
|
Login: "test_user",
|
||||||
}
|
}
|
||||||
|
authInfoMock.user = expectedUser
|
||||||
bus.ClearBusHandlers()
|
bus.ClearBusHandlers()
|
||||||
t.Cleanup(func() { bus.ClearBusHandlers() })
|
t.Cleanup(func() { bus.ClearBusHandlers() })
|
||||||
bus.AddHandler("test", func(query *models.GetUserByAuthInfoQuery) error {
|
|
||||||
query.Result = expectedUser
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
var actualUser *models.User
|
var actualUser *models.User
|
||||||
var actualExternalUser *models.ExternalUserInfo
|
var actualExternalUser *models.ExternalUserInfo
|
||||||
|
7
pkg/services/login/userprotection.go
Normal file
7
pkg/services/login/userprotection.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package login
|
||||||
|
|
||||||
|
import "github.com/grafana/grafana/pkg/models"
|
||||||
|
|
||||||
|
type UserProtectionService interface {
|
||||||
|
AllowUserMapping(user *models.User, authModule string) error
|
||||||
|
}
|
@ -90,8 +90,6 @@ func TestTeamCommandsAndQueries(t *testing.T) {
|
|||||||
|
|
||||||
Convey("Should return latest auth module for users when getting team members", func() {
|
Convey("Should return latest auth module for users when getting team members", func() {
|
||||||
userId := userIds[1]
|
userId := userIds[1]
|
||||||
err := SetAuthInfo(&models.SetAuthInfoCommand{UserId: userId, AuthModule: "oauth_github", AuthId: "1234567"})
|
|
||||||
So(err, ShouldBeNil)
|
|
||||||
|
|
||||||
teamQuery := &models.SearchTeamsQuery{OrgId: testOrgID, Name: "group1 name", Page: 1, Limit: 10}
|
teamQuery := &models.SearchTeamsQuery{OrgId: testOrgID, Name: "group1 name", Page: 1, Limit: 10}
|
||||||
err = SearchTeams(teamQuery)
|
err = SearchTeams(teamQuery)
|
||||||
@ -111,7 +109,6 @@ func TestTeamCommandsAndQueries(t *testing.T) {
|
|||||||
So(memberQuery.Result[0].Login, ShouldEqual, "loginuser1")
|
So(memberQuery.Result[0].Login, ShouldEqual, "loginuser1")
|
||||||
So(memberQuery.Result[0].OrgId, ShouldEqual, testOrgID)
|
So(memberQuery.Result[0].OrgId, ShouldEqual, testOrgID)
|
||||||
So(memberQuery.Result[0].External, ShouldEqual, true)
|
So(memberQuery.Result[0].External, ShouldEqual, true)
|
||||||
So(memberQuery.Result[0].AuthModule, ShouldEqual, "oauth_github")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("Should be able to update users in a team", func() {
|
Convey("Should be able to update users in a team", func() {
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
@ -119,7 +118,7 @@ func TestUserDataAccess(t *testing.T) {
|
|||||||
t.Run("Testing DB - multiple users", func(t *testing.T) {
|
t.Run("Testing DB - multiple users", func(t *testing.T) {
|
||||||
ss = InitTestDB(t)
|
ss = InitTestDB(t)
|
||||||
|
|
||||||
users := createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
||||||
return &models.CreateUserCommand{
|
return &models.CreateUserCommand{
|
||||||
Email: fmt.Sprint("user", i, "@test.com"),
|
Email: fmt.Sprint("user", i, "@test.com"),
|
||||||
Name: fmt.Sprint("user", i),
|
Name: fmt.Sprint("user", i),
|
||||||
@ -188,48 +187,6 @@ func TestUserDataAccess(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, query.Result.Users, 1)
|
require.Len(t, query.Result.Users, 1)
|
||||||
require.EqualValues(t, query.Result.TotalCount, 1)
|
require.EqualValues(t, query.Result.TotalCount, 1)
|
||||||
|
|
||||||
// Return list users based on their auth type
|
|
||||||
for index, user := range users {
|
|
||||||
authModule := "killa"
|
|
||||||
|
|
||||||
// define every second user as ldap
|
|
||||||
if index%2 == 0 {
|
|
||||||
authModule = "ldap"
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd2 := &models.SetAuthInfoCommand{
|
|
||||||
UserId: user.Id,
|
|
||||||
AuthModule: authModule,
|
|
||||||
AuthId: "gorilla",
|
|
||||||
}
|
|
||||||
err := SetAuthInfo(cmd2)
|
|
||||||
require.Nil(t, err)
|
|
||||||
}
|
|
||||||
query = models.SearchUsersQuery{AuthModule: "ldap"}
|
|
||||||
err = SearchUsers(&query)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Len(t, query.Result.Users, 3)
|
|
||||||
|
|
||||||
zero, second, fourth := false, false, false
|
|
||||||
for _, user := range query.Result.Users {
|
|
||||||
if user.Name == "user0" {
|
|
||||||
zero = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Name == "user2" {
|
|
||||||
second = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Name == "user4" {
|
|
||||||
fourth = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.True(t, zero)
|
|
||||||
require.True(t, second)
|
|
||||||
require.True(t, fourth)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Testing DB - return list users based on their is_disabled flag", func(t *testing.T) {
|
t.Run("Testing DB - return list users based on their is_disabled flag", func(t *testing.T) {
|
||||||
@ -490,107 +447,6 @@ func TestUserDataAccess(t *testing.T) {
|
|||||||
IsDisabled: false,
|
IsDisabled: false,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
// Find a user to set tokens on
|
|
||||||
login := "loginuser0"
|
|
||||||
|
|
||||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
|
||||||
// Make the first log-in during the past
|
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
|
||||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "ldap", AuthId: "ldap0"}
|
|
||||||
err := GetUserByAuthInfo(query)
|
|
||||||
getTime = time.Now
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, query.Result.Login, login)
|
|
||||||
|
|
||||||
// Add a second auth module for this user
|
|
||||||
// Have this module's last log-in be more recent
|
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
|
||||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "oauth", AuthId: "oauth0"}
|
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
getTime = time.Now
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, query.Result.Login, login)
|
|
||||||
|
|
||||||
// Return the only most recently used auth_module
|
|
||||||
searchUserQuery := &models.SearchUsersQuery{}
|
|
||||||
err = SearchUsers(searchUserQuery)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, searchUserQuery.Result.Users, 5)
|
|
||||||
for _, user := range searchUserQuery.Result.Users {
|
|
||||||
if user.Login == login {
|
|
||||||
require.Len(t, user.AuthModule, 1)
|
|
||||||
require.Equal(t, user.AuthModule[0], "oauth")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// "log in" again with the first auth module
|
|
||||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "ldap", AuthId: "ldap1"}
|
|
||||||
err = UpdateAuthInfo(updateAuthCmd)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
searchUserQuery = &models.SearchUsersQuery{}
|
|
||||||
err = SearchUsers(searchUserQuery)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
for _, user := range searchUserQuery.Result.Users {
|
|
||||||
if user.Login == login {
|
|
||||||
require.Len(t, user.AuthModule, 1)
|
|
||||||
require.Equal(t, user.AuthModule[0], "ldap")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-init DB
|
|
||||||
ss = InitTestDB(t)
|
|
||||||
createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
|
||||||
return &models.CreateUserCommand{
|
|
||||||
Email: fmt.Sprint("user", i, "@test.com"),
|
|
||||||
Name: fmt.Sprint("user", i),
|
|
||||||
Login: fmt.Sprint("loginuser", i),
|
|
||||||
IsDisabled: false,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Search LDAP users
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
// Find a user to set tokens on
|
|
||||||
login = fmt.Sprint("loginuser", i)
|
|
||||||
|
|
||||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
|
||||||
// Make the first log-in during the past
|
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
|
||||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "ldap", AuthId: fmt.Sprint("ldap", i)}
|
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
getTime = time.Now
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, query.Result.Login, login)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log in first user with oauth
|
|
||||||
login = "loginuser0"
|
|
||||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
|
||||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "oauth", AuthId: "oauth0"}
|
|
||||||
err = GetUserByAuthInfo(query)
|
|
||||||
getTime = time.Now
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, query.Result.Login, login)
|
|
||||||
|
|
||||||
// Should only return users recently logged in with ldap when filtered by ldap auth module
|
|
||||||
searchUserQuery = &models.SearchUsersQuery{AuthModule: "ldap"}
|
|
||||||
err = SearchUsers(searchUserQuery)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, searchUserQuery.Result.Users, 4)
|
|
||||||
for _, user := range searchUserQuery.Result.Users {
|
|
||||||
if user.Login == login {
|
|
||||||
require.Len(t, user.AuthModule, 1)
|
|
||||||
require.Equal(t, user.AuthModule[0], "ldap")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Testing DB - grafana admin users", func(t *testing.T) {
|
t.Run("Testing DB - grafana admin users", func(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user