// Mirror of https://github.com/teamhanko/hanko.git
// (synced 2025-10-26 13:27:57 +08:00) — 147 lines, 3.9 KiB, Go.
package handler
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gobuffalo/pop/v6"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/sethvargo/go-limiter"
|
|
auditlog "github.com/teamhanko/hanko/backend/v2/audit_log"
|
|
"github.com/teamhanko/hanko/backend/v2/config"
|
|
"github.com/teamhanko/hanko/backend/v2/dto"
|
|
"github.com/teamhanko/hanko/backend/v2/persistence"
|
|
"github.com/teamhanko/hanko/backend/v2/persistence/models"
|
|
rateLimit "github.com/teamhanko/hanko/backend/v2/rate_limiter"
|
|
"github.com/teamhanko/hanko/backend/v2/session"
|
|
)
|
|
|
|
type TokenHandler struct {
|
|
persister persistence.Persister
|
|
sessionManager session.Manager
|
|
cfg *config.Config
|
|
auditLogger auditlog.Logger
|
|
rateLimiter limiter.Store
|
|
}
|
|
|
|
func NewTokenHandler(cfg *config.Config, persister persistence.Persister, sessionManager session.Manager, auditLogger auditlog.Logger) *TokenHandler {
|
|
var rateLimiter limiter.Store
|
|
if cfg.RateLimiter.Enabled {
|
|
rateLimiter = rateLimit.NewRateLimiter(cfg.RateLimiter, cfg.RateLimiter.TokenLimits)
|
|
}
|
|
|
|
return &TokenHandler{cfg: cfg,
|
|
persister: persister,
|
|
sessionManager: sessionManager,
|
|
auditLogger: auditLogger,
|
|
rateLimiter: rateLimiter,
|
|
}
|
|
}
|
|
|
|
type TokenValidationBody struct {
|
|
Value string `json:"value" validate:"required"`
|
|
}
|
|
|
|
func (h TokenHandler) Validate(c echo.Context) error {
|
|
if h.rateLimiter != nil {
|
|
err := rateLimit.Limit(h.rateLimiter, uuid.Nil, c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
var userID uuid.UUID
|
|
err := h.persister.Transaction(func(tx *pop.Connection) error {
|
|
var body TokenValidationBody
|
|
if terr := (&echo.DefaultBinder{}).BindBody(c, &body); terr != nil {
|
|
return dto.ToHttpError(terr)
|
|
}
|
|
|
|
if terr := c.Validate(body); terr != nil {
|
|
return dto.ToHttpError(terr)
|
|
}
|
|
|
|
tokenPersister := h.persister.GetTokenPersisterWithConnection(tx)
|
|
token, terr := tokenPersister.GetByValue(body.Value)
|
|
if terr != nil {
|
|
return fmt.Errorf("failed to fetch token from db: %w", terr)
|
|
}
|
|
|
|
if token == nil {
|
|
return echo.NewHTTPError(http.StatusNotFound, "token not found")
|
|
}
|
|
|
|
if time.Now().UTC().After(token.ExpiresAt) {
|
|
return echo.NewHTTPError(http.StatusUnprocessableEntity, "token has expired")
|
|
}
|
|
|
|
terr = tokenPersister.Delete(*token)
|
|
if terr != nil {
|
|
return fmt.Errorf("failed to delete token from db: %w", terr)
|
|
}
|
|
|
|
emails, err := h.persister.GetEmailPersister().FindByUserId(token.UserID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get emails from db: %w", err)
|
|
}
|
|
|
|
var emailJwt *dto.EmailJWT
|
|
if e := emails.GetPrimary(); e != nil {
|
|
emailJwt = dto.EmailJWTFromEmailModel(e)
|
|
}
|
|
|
|
jwtToken, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
|
UserID: token.UserID.String(),
|
|
Email: emailJwt,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate jwt: %w", err)
|
|
}
|
|
|
|
cookie, err := h.sessionManager.GenerateCookie(jwtToken)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create session token: %w", err)
|
|
}
|
|
|
|
err = storeSession(h.cfg, h.persister, token.UserID, rawToken, c, h.persister.GetConnection())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to store session in DB: %w", err)
|
|
}
|
|
|
|
c.Response().Header().Set("X-Session-Lifetime", fmt.Sprintf("%d", cookie.MaxAge))
|
|
|
|
if h.cfg.Session.EnableAuthTokenHeader {
|
|
c.Response().Header().Set("X-Auth-Token", jwtToken)
|
|
} else {
|
|
c.SetCookie(cookie)
|
|
}
|
|
|
|
userID = token.UserID
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
var httpError *echo.HTTPError
|
|
if errors.As(err, &httpError) {
|
|
aerr := h.auditLogger.Create(c, models.AuditLogTokenExchangeFailed, nil, err)
|
|
if aerr != nil {
|
|
return fmt.Errorf("could not create audit log: %w", aerr)
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
user := &models.User{ID: userID}
|
|
err = h.auditLogger.Create(c, models.AuditLogTokenExchangeSucceeded, user, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("could not create audit log: %w", err)
|
|
}
|
|
|
|
return c.JSON(http.StatusOK, map[string]string{"user_id": userID.String()})
|
|
|
|
}
|