Files
hanko/backend/flow_api/flow/shared/hook_issue_session.go
Lennart Fleischmann efeca4a76f feat: custom jwt claims
2025-04-16 15:17:48 +02:00

176 lines
5.2 KiB
Go

package shared
import (
"errors"
"fmt"
"time"
"github.com/gobuffalo/nulls"
"github.com/gofrs/uuid"
auditlog "github.com/teamhanko/hanko/backend/audit_log"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/flowpilot"
"github.com/teamhanko/hanko/backend/persistence/models"
)
// IssueSession is a flow hook that issues a session for the user whose ID is
// stored in the flow stash. It embeds Action and adds no fields of its own.
type IssueSession struct{ Action }
// Execute issues a session for the user whose ID is stored in the flow stash.
//
// In order, it:
//   - loads the user identified by the stashed user_id,
//   - generates a signed session JWT and exposes its claims in the payload
//     under "claims",
//   - deletes active server-side sessions so that, together with the new one,
//     at most Cfg.Session.Limit sessions remain,
//   - persists a models.Session built from the token's session_id, issued-at
//     and expiration claims (plus IP address / user agent when configured),
//   - sets the X-Session-Lifetime and X-Session-Retention response headers and
//     delivers the token via the X-Auth-Token header or a cookie,
//   - if the stash contains a login method, writes a login-success audit log
//     entry and fills the "last_login.*" payload fields.
//
// Any failure is returned as a wrapped error.
// NOTE(review): presumably the caller rolls back deps.Tx on error — confirm.
func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
	deps := h.GetDeps(c)

	userIDValue := c.Stash().Get(StashPathUserID)
	if !userIDValue.Exists() {
		return errors.New("user_id not found in stash")
	}
	userID, err := uuid.FromString(userIDValue.String())
	if err != nil {
		return fmt.Errorf("failed to parse stashed user_id into a uuid: %w", err)
	}

	userModel, err := deps.Persister.GetUserPersisterWithConnection(deps.Tx).Get(userID)
	if err != nil {
		return fmt.Errorf("failed to fetch user from db: %w", err)
	}
	if userModel == nil {
		// Guard against a nil user (presumably how the persister reports
		// "not found" — confirm) before building JWT data from it.
		return fmt.Errorf("user %s not found", userID)
	}

	userJWT := dto.UserJWTFromUserModel(userModel)
	signedSessionToken, rawToken, err := deps.SessionManager.GenerateJWT(userJWT)
	if err != nil {
		return fmt.Errorf("failed to generate JWT: %w", err)
	}

	claims, err := dto.GetClaimsFromToken(rawToken)
	if err != nil {
		return fmt.Errorf("failed to get token claims: %w", err)
	}
	if err = c.Payload().Set("claims", claims); err != nil {
		return fmt.Errorf("failed to set token claims to payload: %w", err)
	}

	sessionPersister := deps.Persister.GetSessionPersisterWithConnection(deps.Tx)

	activeSessions, err := sessionPersister.ListActive(userID)
	if err != nil {
		return fmt.Errorf("failed to list active sessions: %w", err)
	}

	// Make room for the session created below: when the user already holds
	// Limit or more active sessions, delete every session from index Limit-1
	// onward so that at most Limit remain afterwards.
	// NOTE(review): which sessions are removed depends on the order returned by
	// ListActive (not visible here). A Limit <= 0 would index at -1 and panic;
	// presumably config validation prevents that — confirm.
	if limit := deps.Cfg.Session.Limit; len(activeSessions) >= limit {
		for i := limit - 1; i < len(activeSessions); i++ {
			if err = sessionPersister.Delete(activeSessions[i]); err != nil {
				return fmt.Errorf("failed to remove latest session: %w", err)
			}
		}
	}

	// Derive the server-side session ID from the token's "session_id" claim.
	// A missing or malformed claim is reported as an error instead of panicking
	// on the type assertion or silently storing a nil UUID.
	sessionIDClaim, ok := rawToken.Get("session_id")
	if !ok {
		return errors.New("session_id claim not found in token")
	}
	sessionIDString, ok := sessionIDClaim.(string)
	if !ok {
		return fmt.Errorf("unexpected type %T for session_id claim", sessionIDClaim)
	}
	sessionID, err := uuid.FromString(sessionIDString)
	if err != nil {
		return fmt.Errorf("failed to parse session_id claim into a uuid: %w", err)
	}

	issuedAt := rawToken.IssuedAt()
	expiresAt := rawToken.Expiration()
	sessionModel := models.Session{
		ID:        sessionID,
		UserID:    userID,
		CreatedAt: issuedAt,
		UpdatedAt: issuedAt,
		ExpiresAt: &expiresAt,
		LastUsed:  issuedAt,
	}
	if deps.Cfg.Session.AcquireIPAddress {
		sessionModel.IpAddress = nulls.NewString(deps.HttpContext.RealIP())
	}
	if deps.Cfg.Session.AcquireUserAgent {
		sessionModel.UserAgent = nulls.NewString(deps.HttpContext.Request().UserAgent())
	}
	if err = sessionPersister.Create(sessionModel); err != nil {
		return fmt.Errorf("failed to store session: %w", err)
	}

	cookie, err := deps.SessionManager.GenerateCookie(signedSessionToken)
	if err != nil {
		return fmt.Errorf("failed to generate auth cookie: %w", err)
	}

	lifespan, err := time.ParseDuration(deps.Cfg.Session.Lifespan)
	if err != nil {
		return fmt.Errorf("failed to parse session lifespan: %w", err)
	}

	// Retention "session" always yields a browser-session cookie; "prompt" does
	// so only when the user did not select "remember me".
	rememberMeSelected := c.Stash().Get(StashPathRememberMeSelected).Bool()
	sessionRetention := "persistent"
	if deps.Cfg.Session.Cookie.Retention == "session" ||
		(deps.Cfg.Session.Cookie.Retention == "prompt" && !rememberMeSelected) {
		cookie.MaxAge = 0 // no Max-Age attribute: a browser-session cookie
		sessionRetention = "session"
	}

	header := deps.HttpContext.Response().Header()
	header.Set("X-Session-Lifetime", fmt.Sprintf("%d", int(lifespan.Seconds())))
	header.Set("X-Session-Retention", sessionRetention)
	if deps.Cfg.Session.EnableAuthTokenHeader {
		header.Set("X-Auth-Token", signedSessionToken)
	} else {
		deps.HttpContext.SetCookie(cookie)
	}

	// Audit log logins only: a registration also ends with the user logged in,
	// but only login actions set the "login_method" stash entry.
	loginMethod := c.Stash().Get(StashPathLoginMethod)
	if !loginMethod.Exists() {
		return nil
	}
	mfaMethod := c.Stash().Get(StashPathMFAUsageMethod)
	thirdPartyProvider := c.Stash().Get(StashPathThirdPartyProvider)

	auditLogDetails := []auditlog.DetailOption{
		auditlog.Detail("login_method", loginMethod.String()),
		auditlog.Detail("flow_id", c.GetFlowID()),
	}
	if mfaMethod.Exists() {
		auditLogDetails = append(auditLogDetails, auditlog.Detail("mfa_method", mfaMethod.String()))
	}
	// Every preceding step succeeded, so no error is attached to the entry.
	if err = deps.AuditLogger.CreateWithConnection(
		deps.Tx,
		deps.HttpContext,
		models.AuditLogLoginSuccess,
		&models.User{ID: userID},
		nil,
		auditLogDetails...,
	); err != nil {
		return fmt.Errorf("could not create audit log: %w", err)
	}

	if err = c.Payload().Set("last_login.login_method", loginMethod.String()); err != nil {
		return fmt.Errorf("failed to set login_method to the payload: %w", err)
	}
	if thirdPartyProvider.Exists() {
		if err = c.Payload().Set("last_login.third_party_provider", thirdPartyProvider.String()); err != nil {
			return fmt.Errorf("failed to set third_party_provider to the payload: %w", err)
		}
	}
	if mfaMethod.Exists() {
		if err = c.Payload().Set("last_login.mfa_method", mfaMethod.String()); err != nil {
			return fmt.Errorf("failed to set mfa_method to the payload: %w", err)
		}
	}

	return nil
}