package shared

import (
	"crypto/subtle"
	"errors"
	"fmt"
	"time"

	"github.com/teamhanko/hanko/backend/v2/flowpilot"
	"github.com/teamhanko/hanko/backend/v2/persistence/models"
	"github.com/teamhanko/hanko/backend/v2/rate_limiter"
)

// ExchangeToken is the flow action that redeems a one-time token for the
// user it was issued to. It then routes the flow into whatever onboarding
// states still apply (email verification, username acquisition) before
// success.
type ExchangeToken struct {
	Action
}

// GetName returns the flowpilot action name of this action.
func (a ExchangeToken) GetName() flowpilot.ActionName {
	return ActionExchangeToken
}

// GetDescription returns a human-readable description of this action.
func (a ExchangeToken) GetDescription() string {
	return "Exchange a one time token."
}

// Initialize declares the action's inputs: the required, hidden "token" and
// the optional, hidden PKCE "code_verifier".
func (a ExchangeToken) Initialize(c flowpilot.InitializationContext) {
	c.AddInputs(
		flowpilot.StringInput("token").Hidden(true).Required(true),
		flowpilot.StringInput("code_verifier").Hidden(true),
	)
}

// Execute redeems the submitted one-time token.
//
// It rate-limits by client IP (when enabled). It then loads the token and
// checks the PKCE code verifier and the expiry. Next it stashes the user ID,
// login method and provider for later hooks, and deletes the token so it
// cannot be reused. Finally it continues with the onboarding states the
// identity still requires.
//
// Missing, expired or identity-less tokens yield plain errors. A mismatching
// code verifier yields flowpilot.ErrorFormDataInvalid. An exhausted rate limit
// yields ErrorRateLimitExceeded, with "retry_after" set in the payload.
func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error {
	if valid := c.ValidateInputData(); !valid {
		return c.Error(flowpilot.ErrorFormDataInvalid)
	}

	deps := a.GetDeps(c)

	if deps.Cfg.RateLimiter.Enabled {
		rateLimitKey := rate_limiter.CreateRateLimitTokenExchangeKey(deps.HttpContext.RealIP())
		retryAfterSeconds, ok, err := rate_limiter.Limit2(deps.TokenExchangeRateLimiter, rateLimitKey)
		if err != nil {
			return fmt.Errorf("rate limiter failed: %w", err)
		}

		if !ok {
			err = c.Payload().Set("retry_after", retryAfterSeconds)
			if err != nil {
				return fmt.Errorf("failed to set a value for retry_after to the payload: %w", err)
			}
			return c.Error(ErrorRateLimitExceeded.Wrap(fmt.Errorf("rate limit exceeded for: %s", rateLimitKey)))
		}
	}

	tokenPersister := deps.Persister.GetTokenPersisterWithConnection(deps.Tx)

	tokenModel, err := tokenPersister.GetByValue(c.Input().Get("token").String())
	if err != nil {
		return fmt.Errorf("failed to fetch token from db: %w", err)
	}

	if tokenModel == nil {
		return errors.New("token not found")
	}

	if !a.codeVerifierMatches(tokenModel.PKCECodeVerifier, c.Input().Get("code_verifier").String()) {
		return c.Error(flowpilot.ErrorFormDataInvalid.Wrap(errors.New("code_verifier does not match")))
	}

	if time.Now().UTC().After(tokenModel.ExpiresAt) {
		return errors.New("token expired")
	}

	// IdentityID is a pointer; guard it instead of panicking on dereference.
	if tokenModel.IdentityID == nil {
		return errors.New("token is not associated with an identity")
	}

	identity, err := deps.Persister.GetIdentityPersisterWithConnection(deps.Tx).GetByID(*tokenModel.IdentityID)
	if err != nil {
		return fmt.Errorf("failed to fetch identity from db: %w", err)
	}

	if identity == nil {
		return errors.New("identity not found")
	}

	// Set so the issue_session hook knows who to create the session for.
	if err := c.Stash().Set(StashPathUserID, tokenModel.UserID.String()); err != nil {
		return fmt.Errorf("failed to set user_id to stash: %w", err)
	}

	// Set because the thirdparty/callback endpoint already creates a user.
	if err := c.Stash().Set(StashPathSkipUserCreation, true); err != nil {
		return fmt.Errorf("failed to set skip_user_creation to stash: %w", err)
	}

	// Delete the token so it cannot be exchanged again.
	if err := tokenPersister.Delete(*tokenModel); err != nil {
		return fmt.Errorf("failed to delete token from db: %w", err)
	}

	// Decide email verification in a helper, so the SAML provider lookup cannot
	// shadow err and swallow a failure from determineOnboardingStates.
	mustDoEmailVerification, err := a.requiresEmailVerification(c, identity)
	if err != nil {
		return err
	}

	onboardingStates, err := a.determineOnboardingStates(c, identity, tokenModel.UserCreated, mustDoEmailVerification)
	if err != nil {
		return fmt.Errorf("failed to determine onboarding states: %w", err)
	}

	if err = c.Stash().Set(StashPathLoginMethod, "third_party"); err != nil {
		return fmt.Errorf("failed to set login_method to the stash: %w", err)
	}

	if err = c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil {
		return fmt.Errorf("failed to set third_party_provider to the stash: %w", err)
	}

	// NOTE(review): presumably prevents navigating back to before the exchange,
	// since the token has already been consumed above — confirm with flowpilot.
	c.PreventRevert()

	return c.Continue(onboardingStates...)
}

// codeVerifierMatches reports whether the supplied PKCE code verifier is
// acceptable for a token whose stored verifier is stored.
//
// A nil or empty stored verifier means the token was not bound to PKCE, so any
// input (including none) is accepted. Otherwise the values are compared in
// constant time, because the verifier is a secret.
func (a ExchangeToken) codeVerifierMatches(stored *string, supplied string) bool {
	if stored == nil || *stored == "" {
		return true
	}
	return subtle.ConstantTimeCompare([]byte(*stored), []byte(supplied)) == 1
}

// requiresEmailVerification reports whether the identity's email address must
// still be verified via passcode before the flow can complete.
//
// Verification is only ever needed when the identity has an email and it is
// unverified. For SAML identities the provider's SkipEmailVerification setting
// decides; the provider is looked up, and any lookup error is returned, even
// when the email is already verified. For all other identities
// Cfg.Email.RequireVerification decides.
func (a ExchangeToken) requiresEmailVerification(c flowpilot.ExecutionContext, identity *models.Identity) (bool, error) {
	deps := a.GetDeps(c)
	emailUnverified := identity.Email != nil && !identity.Email.Verified

	if identity.SamlIdentity == nil {
		return deps.Cfg.Email.RequireVerification && emailUnverified, nil
	}

	samlProvider, err := deps.SamlService.GetProviderByIssuer(identity.ProviderID)
	if err != nil {
		return false, fmt.Errorf("could not fetch saml provider for identity: %w", err)
	}

	return !samlProvider.GetConfig().SkipEmailVerification && emailUnverified, nil
}

// determineOnboardingStates returns the states the flow should continue with,
// always ending in StateSuccess.
//
// When mustDoEmailVerification is true, it stashes the email address and the
// email-verification passcode template, and prepends StatePasscodeConfirmation.
// The caller must only pass true when identity.Email is non-nil.
//
// When usernames are enabled and acquisition applies (on registration for
// newly created users, on login otherwise), it appends StateOnboardingUsername
// if the user has no username yet. The username is read via identity.Email.User,
// so an error is returned if that association is missing rather than panicking.
func (a ExchangeToken) determineOnboardingStates(c flowpilot.ExecutionContext, identity *models.Identity, userCreated bool, mustDoEmailVerification bool) ([]flowpilot.StateName, error) {
	deps := a.GetDeps(c)
	result := make([]flowpilot.StateName, 0, 3)

	if mustDoEmailVerification {
		if err := c.Stash().Set(StashPathEmail, identity.Email.Address); err != nil {
			return nil, fmt.Errorf("failed to stash email: %w", err)
		}

		if err := c.Stash().Set(StashPathPasscodeTemplate, PasscodeTemplateEmailVerification); err != nil {
			return nil, fmt.Errorf("failed to stash passcode_template: %w", err)
		}

		result = append(result, StatePasscodeConfirmation)
	}

	acquireUsername := deps.Cfg.Username.Enabled &&
		((userCreated && deps.Cfg.Username.AcquireOnRegistration) ||
			(!userCreated && deps.Cfg.Username.AcquireOnLogin))

	if acquireUsername {
		// Execute treats identity.Email as optional, so guard it (and its
		// user) before reading the username.
		if identity.Email == nil || identity.Email.User == nil {
			return nil, errors.New("cannot determine username: identity has no associated email user")
		}

		if identity.Email.User.GetUsername() == nil {
			result = append(result, StateOnboardingUsername)
		}
	}

	return append(result, StateSuccess), nil
}