mirror of
https://github.com/grafana/grafana.git
synced 2025-08-02 23:53:10 +08:00
User: email verification completion (#85259)
* TempUser: Include InvitedById in TempUserDTO * Extract email verfication completion flow to service
This commit is contained in:
@ -4,11 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/dtos"
|
||||
"github.com/grafana/grafana/pkg/api/response"
|
||||
@ -17,7 +15,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/team"
|
||||
tempuser "github.com/grafana/grafana/pkg/services/temp_user"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
@ -275,7 +272,7 @@ func (hs *HTTPServer) handleUpdateUser(ctx context.Context, cmd user.UpdateUserC
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) verifyEmailUpdate(ctx context.Context, email string, field user.UpdateEmailActionType, usr *user.User) response.Response {
|
||||
if err := hs.userVerifier.VerifyEmail(ctx, user.VerifyEmailCommand{
|
||||
if err := hs.userVerifier.Start(ctx, user.StartVerifyEmailCommand{
|
||||
User: *usr,
|
||||
Email: email,
|
||||
Action: field,
|
||||
@ -295,37 +292,15 @@ func (hs *HTTPServer) verifyEmailUpdate(ctx context.Context, email string, field
|
||||
// Responses:
|
||||
// 302: okResponse
|
||||
func (hs *HTTPServer) UpdateUserEmail(c *contextmodel.ReqContext) response.Response {
|
||||
var err error
|
||||
|
||||
q := c.Req.URL.Query()
|
||||
code, err := url.QueryUnescape(q.Get("code"))
|
||||
code, err := url.QueryUnescape(c.Req.URL.Query().Get("code"))
|
||||
if err != nil || code == "" {
|
||||
return hs.RedirectResponseWithError(c, errors.New("bad request data"))
|
||||
}
|
||||
|
||||
tempUser, err := hs.validateEmailCode(c.Req.Context(), code)
|
||||
if err != nil {
|
||||
if err := hs.userVerifier.Complete(c.Req.Context(), user.CompleteEmailVerifyCommand{Code: code}); err != nil {
|
||||
return hs.RedirectResponseWithError(c, err)
|
||||
}
|
||||
|
||||
cmd, err := hs.updateCmdFromEmailVerification(c.Req.Context(), tempUser)
|
||||
if err != nil {
|
||||
return hs.RedirectResponseWithError(c, err)
|
||||
}
|
||||
|
||||
if err := hs.userService.Update(c.Req.Context(), cmd); err != nil {
|
||||
if errors.Is(err, user.ErrCaseInsensitive) {
|
||||
return hs.RedirectResponseWithError(c, errors.New("update would result in user login conflict"))
|
||||
}
|
||||
return hs.RedirectResponseWithError(c, errors.New("failed to update user"))
|
||||
}
|
||||
|
||||
// Mark temp user as completed
|
||||
updateTmpUserCmd := tempuser.UpdateTempUserStatusCommand{Code: code, Status: tempuser.TmpUserEmailUpdateCompleted}
|
||||
if err := hs.tempUserService.UpdateTempUserStatus(c.Req.Context(), &updateTmpUserCmd); err != nil {
|
||||
return hs.RedirectResponseWithError(c, errors.New("failed to update verification status"))
|
||||
}
|
||||
|
||||
return response.Redirect(hs.Cfg.AppSubURL + "/profile")
|
||||
}
|
||||
|
||||
@ -694,57 +669,6 @@ func getUserID(c *contextmodel.ReqContext) (int64, *response.NormalResponse) {
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) updateCmdFromEmailVerification(ctx context.Context, tempUser *tempuser.TempUserDTO) (*user.UpdateUserCommand, error) {
|
||||
userQuery := user.GetUserByLoginQuery{LoginOrEmail: tempUser.InvitedByLogin}
|
||||
usr, err := hs.userService.GetByLogin(ctx, &userQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
return nil, errors.New("failed to get user")
|
||||
}
|
||||
|
||||
cmd := &user.UpdateUserCommand{UserID: usr.ID, Email: tempUser.Email}
|
||||
|
||||
switch tempUser.Name {
|
||||
case string(user.EmailUpdateAction):
|
||||
// User updated the email field
|
||||
if _, err := mail.ParseAddress(usr.Login); err == nil {
|
||||
// If username was also an email, we update it to keep it in sync with the email field
|
||||
cmd.Login = tempUser.Email
|
||||
}
|
||||
case string(user.LoginUpdateAction):
|
||||
// User updated the username field with a new email
|
||||
cmd.Login = tempUser.Email
|
||||
default:
|
||||
return nil, errors.New("trying to update email on unknown field")
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) validateEmailCode(ctx context.Context, code string) (*tempuser.TempUserDTO, error) {
|
||||
tempUserQuery := tempuser.GetTempUserByCodeQuery{Code: code}
|
||||
tempUser, err := hs.tempUserService.GetTempUserByCode(ctx, &tempUserQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, tempuser.ErrTempUserNotFound) {
|
||||
return nil, errors.New("invalid email verification code")
|
||||
}
|
||||
return nil, errors.New("failed to read temp user")
|
||||
}
|
||||
|
||||
if tempUser.Status != tempuser.TmpUserEmailUpdateStarted {
|
||||
return nil, errors.New("invalid email verification code")
|
||||
}
|
||||
if !tempUser.EmailSent {
|
||||
return nil, errors.New("verification email was not recorded as sent")
|
||||
}
|
||||
if tempUser.EmailSentOn.Add(hs.Cfg.VerificationEmailMaxLifetime).Before(time.Now()) {
|
||||
return nil, errors.New("invalid email verification code")
|
||||
}
|
||||
|
||||
return tempUser, nil
|
||||
}
|
||||
|
||||
// swagger:parameters searchUsers
|
||||
type SearchUsersParams struct {
|
||||
// Limit the maximum number of users to return per page
|
||||
|
@ -397,7 +397,7 @@ func setupUpdateEmailTests(t *testing.T, cfg *setting.Cfg) (*user.User, *HTTPSer
|
||||
require.NoError(t, err)
|
||||
|
||||
nsMock := notifications.MockNotificationService()
|
||||
verifier := userimpl.ProvideVerifier(userSvc, tempUserService, nsMock)
|
||||
verifier := userimpl.ProvideVerifier(cfg, userSvc, tempUserService, nsMock)
|
||||
|
||||
hs := &HTTPServer{
|
||||
Cfg: cfg,
|
||||
@ -620,7 +620,7 @@ func TestUser_UpdateEmail(t *testing.T) {
|
||||
hs.tempUserService = tempUserSvc
|
||||
hs.NotificationService = nsMock
|
||||
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||
hs.userVerifier = userimpl.ProvideVerifier(userSvc, tempUserSvc, nsMock)
|
||||
hs.userVerifier = userimpl.ProvideVerifier(settings, userSvc, tempUserSvc, nsMock)
|
||||
// User is internal
|
||||
hs.authInfoService = &authinfotest.FakeService{ExpectedError: user.ErrUserNotFound}
|
||||
})
|
||||
|
@ -96,6 +96,7 @@ type TempUserDTO struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Role org.RoleType `json:"role"`
|
||||
InvitedByID int64 `json:"-" xorm:"invited_by_id"`
|
||||
InvitedByLogin string `json:"invitedByLogin"`
|
||||
InvitedByEmail string `json:"invitedByEmail"`
|
||||
InvitedByName string `json:"invitedByName"`
|
||||
|
@ -129,18 +129,19 @@ func (ss *xormStore) GetTempUserByCode(ctx context.Context, query *tempuser.GetT
|
||||
tu.id as id,
|
||||
tu.org_id as org_id,
|
||||
tu.email as email,
|
||||
tu.name as name,
|
||||
tu.role as role,
|
||||
tu.code as code,
|
||||
tu.status as status,
|
||||
tu.email_sent as email_sent,
|
||||
tu.email_sent_on as email_sent_on,
|
||||
tu.created as created,
|
||||
u.login as invited_by_login,
|
||||
u.name as invited_by_name,
|
||||
u.email as invited_by_email
|
||||
tu.name as name,
|
||||
tu.role as role,
|
||||
tu.code as code,
|
||||
tu.status as status,
|
||||
tu.email_sent as email_sent,
|
||||
tu.email_sent_on as email_sent_on,
|
||||
tu.created as created,
|
||||
tu.invited_by_user_id as invited_by_id,
|
||||
u.login as invited_by_login,
|
||||
u.name as invited_by_name,
|
||||
u.email as invited_by_email
|
||||
FROM ` + ss.db.GetDialect().Quote("temp_user") + ` as tu
|
||||
LEFT OUTER JOIN ` + ss.db.GetDialect().Quote("user") + ` as u on u.id = tu.invited_by_user_id
|
||||
LEFT OUTER JOIN ` + ss.db.GetDialect().Quote("user") + ` as u on u.id = tu.invited_by_user_id
|
||||
WHERE tu.code=?`
|
||||
|
||||
var tempUser tempuser.TempUserDTO
|
||||
|
@ -10,11 +10,27 @@ var _ tempuser.Service = (*FakeTempUserService)(nil)
|
||||
|
||||
type FakeTempUserService struct {
|
||||
tempuser.Service
|
||||
GetTempUserByCodeFN func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error)
|
||||
UpdateTempUserStatusFN func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error
|
||||
CreateTempUserFN func(ctx context.Context, cmd *tempuser.CreateTempUserCommand) (*tempuser.TempUser, error)
|
||||
ExpirePreviousVerificationsFN func(ctx context.Context, cmd *tempuser.ExpirePreviousVerificationsCommand) error
|
||||
UpdateTempUserWithEmailSentFN func(ctx context.Context, cmd *tempuser.UpdateTempUserWithEmailSentCommand) error
|
||||
}
|
||||
|
||||
func (f *FakeTempUserService) GetTempUserByCode(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
if f.GetTempUserByCodeFN != nil {
|
||||
return f.GetTempUserByCodeFN(ctx, query)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *FakeTempUserService) UpdateTempUserStatus(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error {
|
||||
if f.UpdateTempUserStatusFN != nil {
|
||||
return f.UpdateTempUserStatusFN(ctx, cmd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FakeTempUserService) CreateTempUser(ctx context.Context, cmd *tempuser.CreateTempUserCommand) (*tempuser.TempUser, error) {
|
||||
if f.CreateTempUserFN != nil {
|
||||
return f.CreateTempUserFN(ctx, cmd)
|
||||
|
@ -82,7 +82,8 @@ type UpdateUserCommand struct {
|
||||
Login string `json:"login"`
|
||||
Theme string `json:"theme"`
|
||||
|
||||
UserID int64 `json:"-"`
|
||||
UserID int64 `json:"-"`
|
||||
EmailVerified *bool `json:"-"`
|
||||
}
|
||||
|
||||
type ChangeUserPasswordCommand struct {
|
||||
@ -220,12 +221,16 @@ type GetUserByIDQuery struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type VerifyEmailCommand struct {
|
||||
type StartVerifyEmailCommand struct {
|
||||
User User
|
||||
Email string
|
||||
Action UpdateEmailActionType
|
||||
}
|
||||
|
||||
type CompleteEmailVerifyCommand struct {
|
||||
Code string
|
||||
}
|
||||
|
||||
type ErrCaseInsensitiveLoginConflict struct {
|
||||
Users []User
|
||||
}
|
||||
|
@ -31,5 +31,6 @@ type Service interface {
|
||||
}
|
||||
|
||||
type Verifier interface {
|
||||
VerifyEmail(ctx context.Context, cmd VerifyEmailCommand) error
|
||||
Start(ctx context.Context, cmd StartVerifyEmailCommand) error
|
||||
Complete(ctx context.Context, cmd CompleteEmailVerifyCommand) error
|
||||
}
|
||||
|
@ -315,7 +315,14 @@ func (ss *sqlStore) Update(ctx context.Context, cmd *user.UpdateUserCommand) err
|
||||
Updated: time.Now(),
|
||||
}
|
||||
|
||||
if _, err := sess.ID(cmd.UserID).Where(ss.notServiceAccountFilter()).Update(&user); err != nil {
|
||||
q := sess.ID(cmd.UserID).Where(ss.notServiceAccountFilter())
|
||||
|
||||
if cmd.EmailVerified != nil {
|
||||
q.UseBool("email_verified")
|
||||
user.EmailVerified = *cmd.EmailVerified
|
||||
}
|
||||
|
||||
if _, err := q.Update(&user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -4,26 +4,36 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/notifications"
|
||||
tempuser "github.com/grafana/grafana/pkg/services/temp_user"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidCode = errutil.BadRequest("user.code.invalid", errutil.WithPublicMessage("Invalid verification code"))
|
||||
errExpiredCode = errutil.BadRequest("user.code.expired", errutil.WithPublicMessage("Verification code has expired"))
|
||||
)
|
||||
|
||||
var _ user.Verifier = (*Verifier)(nil)
|
||||
|
||||
func ProvideVerifier(us user.Service, ts tempuser.Service, ns notifications.Service) *Verifier {
|
||||
return &Verifier{us, ts, ns}
|
||||
func ProvideVerifier(cfg *setting.Cfg, us user.Service, ts tempuser.Service, ns notifications.Service) *Verifier {
|
||||
return &Verifier{cfg, us, ts, ns}
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
us user.Service
|
||||
ts tempuser.Service
|
||||
ns notifications.Service
|
||||
cfg *setting.Cfg
|
||||
us user.Service
|
||||
ts tempuser.Service
|
||||
ns notifications.Service
|
||||
}
|
||||
|
||||
func (s *Verifier) VerifyEmail(ctx context.Context, cmd user.VerifyEmailCommand) error {
|
||||
func (s *Verifier) Start(ctx context.Context, cmd user.StartVerifyEmailCommand) error {
|
||||
usr, err := s.us.GetByLogin(ctx, &user.GetUserByLoginQuery{
|
||||
LoginOrEmail: cmd.Email,
|
||||
})
|
||||
@ -80,3 +90,60 @@ func (s *Verifier) VerifyEmail(ctx context.Context, cmd user.VerifyEmailCommand)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Verifier) Complete(ctx context.Context, cmd user.CompleteEmailVerifyCommand) error {
|
||||
tmpUsr, err := s.ts.GetTempUserByCode(ctx, &tempuser.GetTempUserByCodeQuery{Code: cmd.Code})
|
||||
if err != nil {
|
||||
return errInvalidCode.Errorf("failed to verify code: %w", err)
|
||||
}
|
||||
|
||||
if tmpUsr.Status != tempuser.TmpUserEmailUpdateStarted {
|
||||
return errInvalidCode.Errorf("wrong status for verification code: %s", tmpUsr.Status)
|
||||
}
|
||||
|
||||
if !tmpUsr.EmailSent {
|
||||
return errInvalidCode.Errorf("email was not marked as sent")
|
||||
}
|
||||
|
||||
if tmpUsr.EmailSentOn.Add(s.cfg.VerificationEmailMaxLifetime).Before(time.Now()) {
|
||||
return errExpiredCode.Errorf("verification code has expired")
|
||||
}
|
||||
|
||||
usr, err := s.us.GetByID(ctx, &user.GetUserByIDQuery{ID: tmpUsr.InvitedByID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
verified := true
|
||||
update := &user.UpdateUserCommand{
|
||||
Email: tmpUsr.Email,
|
||||
UserID: tmpUsr.InvitedByID,
|
||||
EmailVerified: &verified,
|
||||
}
|
||||
switch tmpUsr.Name {
|
||||
case string(user.EmailUpdateAction):
|
||||
// User updated the email field
|
||||
if _, err := mail.ParseAddress(usr.Login); err == nil {
|
||||
// If username was also an email, we update it to keep it in sync with the email field
|
||||
update.Login = tmpUsr.Email
|
||||
}
|
||||
case string(user.LoginUpdateAction):
|
||||
// User updated the username field with a new email
|
||||
update.Login = tmpUsr.Email
|
||||
default:
|
||||
return errors.New("trying to update email on unknown field")
|
||||
}
|
||||
|
||||
if err := s.us.Update(ctx, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.ts.UpdateTempUserStatus(
|
||||
ctx,
|
||||
&tempuser.UpdateTempUserStatusCommand{Code: cmd.Code, Status: tempuser.TmpUserEmailUpdateCompleted},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package userimpl
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@ -11,9 +12,10 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/temp_user/tempusertest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
func TestVerifier_VerifyEmail(t *testing.T) {
|
||||
func TestVerifier_Start(t *testing.T) {
|
||||
ts := &tempusertest.FakeTempUserService{}
|
||||
us := &usertest.FakeUserService{}
|
||||
ns := notifications.MockNotificationService()
|
||||
@ -24,10 +26,10 @@ func TestVerifier_VerifyEmail(t *testing.T) {
|
||||
updateCalled bool
|
||||
}
|
||||
|
||||
verifier := ProvideVerifier(us, ts, ns)
|
||||
verifier := ProvideVerifier(setting.NewCfg(), us, ts, ns)
|
||||
t.Run("should error if email already exist for other user", func(t *testing.T) {
|
||||
us.ExpectedUser = &user.User{ID: 1}
|
||||
err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{
|
||||
err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{
|
||||
User: user.User{ID: 2},
|
||||
Email: "some@email.com",
|
||||
Action: user.EmailUpdateAction,
|
||||
@ -59,13 +61,13 @@ func TestVerifier_VerifyEmail(t *testing.T) {
|
||||
c.updateCalled = true
|
||||
return nil
|
||||
}
|
||||
err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{
|
||||
err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{
|
||||
User: user.User{ID: 2},
|
||||
Email: "some@email.com",
|
||||
Action: user.EmailUpdateAction,
|
||||
})
|
||||
|
||||
assert.ErrorIs(t, err, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, c.expireCalled)
|
||||
assert.True(t, c.createCalled)
|
||||
assert.True(t, c.updateCalled)
|
||||
@ -94,7 +96,7 @@ func TestVerifier_VerifyEmail(t *testing.T) {
|
||||
c.updateCalled = true
|
||||
return nil
|
||||
}
|
||||
err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{
|
||||
err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{
|
||||
User: user.User{ID: 2},
|
||||
Email: "some@email.com",
|
||||
Action: user.EmailUpdateAction,
|
||||
@ -106,3 +108,142 @@ func TestVerifier_VerifyEmail(t *testing.T) {
|
||||
assert.True(t, c.updateCalled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier_Complete(t *testing.T) {
|
||||
ts := &tempusertest.FakeTempUserService{}
|
||||
us := &usertest.FakeUserService{}
|
||||
ns := notifications.MockNotificationService()
|
||||
|
||||
type calls struct {
|
||||
updateCalled bool
|
||||
updateStatusCalled bool
|
||||
}
|
||||
|
||||
cfg := setting.NewCfg()
|
||||
cfg.VerificationEmailMaxLifetime = 1 * time.Hour
|
||||
verifier := ProvideVerifier(cfg, us, ts, ns)
|
||||
t.Run("should return error for invalid code", func(t *testing.T) {
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return nil, tempuser.ErrTempUserNotFound
|
||||
}
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.ErrorIs(t, err, errInvalidCode)
|
||||
})
|
||||
|
||||
t.Run("should return error when verification has wrong status", func(t *testing.T) {
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateCompleted,
|
||||
}, nil
|
||||
}
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.ErrorIs(t, err, errInvalidCode)
|
||||
})
|
||||
|
||||
t.Run("should return error when verification email was never sent", func(t *testing.T) {
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateStarted,
|
||||
EmailSent: false,
|
||||
}, nil
|
||||
}
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.ErrorIs(t, err, errInvalidCode)
|
||||
})
|
||||
|
||||
t.Run("should return error when verification code has expired", func(t *testing.T) {
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateStarted,
|
||||
EmailSent: true,
|
||||
EmailSentOn: time.Now().Add(-10 * time.Hour),
|
||||
}, nil
|
||||
}
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.ErrorIs(t, err, errExpiredCode)
|
||||
})
|
||||
|
||||
t.Run("should return error user connect to code don't exists", func(t *testing.T) {
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateStarted,
|
||||
EmailSent: true,
|
||||
EmailSentOn: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
us.ExpectedError = user.ErrUserNotFound
|
||||
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.ErrorIs(t, err, user.ErrUserNotFound)
|
||||
})
|
||||
|
||||
t.Run("should update user email on valid code", func(t *testing.T) {
|
||||
var c calls
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateStarted,
|
||||
Name: string(user.EmailUpdateAction),
|
||||
InvitedByID: 1,
|
||||
Email: "updated@email.com",
|
||||
EmailSent: true,
|
||||
EmailSentOn: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ts.UpdateTempUserStatusFN = func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error {
|
||||
c.updateStatusCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
us.ExpectedUser = &user.User{Email: "initial@email.com"}
|
||||
us.ExpectedError = nil
|
||||
us.UpdateFn = func(ctx context.Context, cmd *user.UpdateUserCommand) error {
|
||||
c.updateCalled = true
|
||||
assert.True(t, *cmd.EmailVerified)
|
||||
assert.Equal(t, int64(1), cmd.UserID)
|
||||
assert.Equal(t, "", cmd.Login)
|
||||
assert.Equal(t, "updated@email.com", cmd.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, c.updateCalled)
|
||||
assert.True(t, c.updateStatusCalled)
|
||||
})
|
||||
|
||||
t.Run("should update user email and login if login is an email on valid code", func(t *testing.T) {
|
||||
var c calls
|
||||
ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) {
|
||||
return &tempuser.TempUserDTO{
|
||||
Status: tempuser.TmpUserEmailUpdateStarted,
|
||||
Name: string(user.EmailUpdateAction),
|
||||
InvitedByID: 1,
|
||||
Email: "updated@email.com",
|
||||
EmailSent: true,
|
||||
EmailSentOn: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ts.UpdateTempUserStatusFN = func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error {
|
||||
c.updateStatusCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
us.ExpectedUser = &user.User{Email: "initial@email.com", Login: "other@email.com"}
|
||||
us.ExpectedError = nil
|
||||
us.UpdateFn = func(ctx context.Context, cmd *user.UpdateUserCommand) error {
|
||||
c.updateCalled = true
|
||||
assert.True(t, *cmd.EmailVerified)
|
||||
assert.Equal(t, int64(1), cmd.UserID)
|
||||
assert.Equal(t, "updated@email.com", cmd.Email)
|
||||
assert.Equal(t, "updated@email.com", cmd.Login)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, c.updateCalled)
|
||||
assert.True(t, c.updateStatusCalled)
|
||||
})
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ type FakeUserService struct {
|
||||
ExpectedUserProfileDTOs []*user.UserProfileDTO
|
||||
ExpectedUsageStats map[string]any
|
||||
|
||||
UpdateFn func(ctx context.Context, cmd *user.UpdateUserCommand) error
|
||||
GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error)
|
||||
CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error)
|
||||
DisableFn func(ctx context.Context, cmd *user.DisableUserCommand) error
|
||||
@ -61,6 +62,9 @@ func (f *FakeUserService) GetByEmail(ctx context.Context, query *user.GetUserByE
|
||||
}
|
||||
|
||||
func (f *FakeUserService) Update(ctx context.Context, cmd *user.UpdateUserCommand) error {
|
||||
if f.UpdateFn != nil {
|
||||
return f.UpdateFn(ctx, cmd)
|
||||
}
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user