package handler

import (
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/gobuffalo/pop/v6"
	"github.com/gofrs/uuid"
	"github.com/labstack/echo/v4"
	"github.com/sethvargo/go-limiter"
	auditlog "github.com/teamhanko/hanko/backend/v2/audit_log"
	"github.com/teamhanko/hanko/backend/v2/config"
	"github.com/teamhanko/hanko/backend/v2/dto"
	"github.com/teamhanko/hanko/backend/v2/persistence"
	"github.com/teamhanko/hanko/backend/v2/persistence/models"
	rateLimit "github.com/teamhanko/hanko/backend/v2/rate_limiter"
	"github.com/teamhanko/hanko/backend/v2/session"
)

// TokenHandler serves the endpoint that exchanges a one-time token for a
// session.
type TokenHandler struct {
	persister      persistence.Persister
	sessionManager session.Manager
	cfg            *config.Config
	auditLogger    auditlog.Logger
	rateLimiter    limiter.Store // nil when cfg.RateLimiter.Enabled is false
}

// NewTokenHandler returns a TokenHandler. A rate limiter built from
// cfg.RateLimiter.TokenLimits is attached only when cfg.RateLimiter.Enabled
// is set; otherwise Validate is not rate limited.
func NewTokenHandler(cfg *config.Config, persister persistence.Persister, sessionManager session.Manager, auditLogger auditlog.Logger) *TokenHandler {
	var rateLimiter limiter.Store
	if cfg.RateLimiter.Enabled {
		rateLimiter = rateLimit.NewRateLimiter(cfg.RateLimiter, cfg.RateLimiter.TokenLimits)
	}

	return &TokenHandler{
		cfg:            cfg,
		persister:      persister,
		sessionManager: sessionManager,
		auditLogger:    auditLogger,
		rateLimiter:    rateLimiter,
	}
}

// TokenValidationBody is the JSON request body accepted by Validate.
type TokenValidationBody struct {
	Value string `json:"value" validate:"required"`
}

// tokenExchange is the outcome of a successful token exchange, carried out
// of the database transaction so the response is written only after commit.
type tokenExchange struct {
	userID   uuid.UUID
	jwtToken string
	cookie   *http.Cookie
}

// Validate exchanges the one-time token in the request body ("value") for a
// session.
//
// On success it sets the X-Session-Lifetime header (the cookie's MaxAge),
// delivers the session JWT either in the X-Auth-Token header or as a cookie
// (depending on cfg.Session.EnableAuthTokenHeader), and responds 200 with
// {"user_id": "<id>"}. The token is deleted in the same transaction that
// stores the session, so it can be exchanged at most once.
//
// Error responses: 404 if the token does not exist, 422 if it has expired,
// plus whatever the rate limiter, binder and validator return. Failures that
// are *echo.HTTPError values are recorded in the audit log; other
// (internal) errors are returned without an audit entry.
func (h TokenHandler) Validate(c echo.Context) error {
	if h.rateLimiter != nil {
		if err := rateLimit.Limit(h.rateLimiter, uuid.Nil, c); err != nil {
			return err
		}
	}

	result, err := h.exchangeToken(c)
	if err != nil {
		var httpError *echo.HTTPError
		if errors.As(err, &httpError) {
			if aerr := h.auditLogger.Create(c, models.AuditLogTokenExchangeFailed, nil, err); aerr != nil {
				return fmt.Errorf("could not create audit log: %w", aerr)
			}
		}
		return err
	}

	// Only hand out the session credential once the transaction that
	// consumed the token and stored the session has committed.
	c.Response().Header().Set("X-Session-Lifetime", strconv.Itoa(result.cookie.MaxAge))
	if h.cfg.Session.EnableAuthTokenHeader {
		c.Response().Header().Set("X-Auth-Token", result.jwtToken)
	} else {
		c.SetCookie(result.cookie)
	}

	user := &models.User{ID: result.userID}
	if err := h.auditLogger.Create(c, models.AuditLogTokenExchangeSucceeded, user, nil); err != nil {
		return fmt.Errorf("could not create audit log: %w", err)
	}

	return c.JSON(http.StatusOK, map[string]string{"user_id": result.userID.String()})
}

// exchangeToken parses and validates the request body, then, in a single
// transaction, looks up the token, rejects it if expired, deletes it, issues
// a JWT/cookie for the token's user and stores the session. All database
// access goes through tx so the token deletion and session creation commit
// or roll back together.
func (h TokenHandler) exchangeToken(c echo.Context) (*tokenExchange, error) {
	// Parse the request before opening a transaction so no DB connection is
	// held while reading and validating the body.
	var body TokenValidationBody
	if err := (&echo.DefaultBinder{}).BindBody(c, &body); err != nil {
		return nil, dto.ToHttpError(err)
	}
	if err := c.Validate(body); err != nil {
		return nil, dto.ToHttpError(err)
	}

	var result *tokenExchange
	err := h.persister.Transaction(func(tx *pop.Connection) error {
		tokenPersister := h.persister.GetTokenPersisterWithConnection(tx)

		token, err := tokenPersister.GetByValue(body.Value)
		if err != nil {
			return fmt.Errorf("failed to fetch token from db: %w", err)
		}
		if token == nil {
			return echo.NewHTTPError(http.StatusNotFound, "token not found")
		}
		if time.Now().UTC().After(token.ExpiresAt) {
			return echo.NewHTTPError(http.StatusUnprocessableEntity, "token has expired")
		}

		if err := tokenPersister.Delete(*token); err != nil {
			return fmt.Errorf("failed to delete token from db: %w", err)
		}

		emails, err := h.persister.GetEmailPersisterWithConnection(tx).FindByUserId(token.UserID)
		if err != nil {
			return fmt.Errorf("failed to get emails from db: %w", err)
		}

		// The primary email is optional in the JWT.
		var emailJwt *dto.EmailJWT
		if e := emails.GetPrimary(); e != nil {
			emailJwt = dto.EmailJWTFromEmailModel(e)
		}

		jwtToken, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
			UserID: token.UserID.String(),
			Email:  emailJwt,
		})
		if err != nil {
			return fmt.Errorf("failed to generate jwt: %w", err)
		}

		cookie, err := h.sessionManager.GenerateCookie(jwtToken)
		if err != nil {
			return fmt.Errorf("failed to create session token: %w", err)
		}

		// Store the session on tx (not a separate connection) so it is
		// rolled back together with the token deletion on failure.
		if err := storeSession(h.cfg, h.persister, token.UserID, rawToken, c, tx); err != nil {
			return fmt.Errorf("failed to store session in DB: %w", err)
		}

		result = &tokenExchange{
			userID:   token.UserID,
			jwtToken: jwtToken,
			cookie:   cookie,
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	return result, nil
}