mirror of
https://github.com/teamhanko/hanko.git
synced 2025-10-26 21:57:14 +08:00
203 lines
6.2 KiB
Go
203 lines
6.2 KiB
Go
package flow_api
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gobuffalo/pop/v6"
|
|
"github.com/gofrs/uuid"
|
|
echojwt "github.com/labstack/echo-jwt/v4"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/rs/zerolog"
|
|
zeroLogger "github.com/rs/zerolog/log"
|
|
"github.com/sethvargo/go-limiter"
|
|
auditlog "github.com/teamhanko/hanko/backend/v2/audit_log"
|
|
"github.com/teamhanko/hanko/backend/v2/config"
|
|
"github.com/teamhanko/hanko/backend/v2/dto"
|
|
"github.com/teamhanko/hanko/backend/v2/ee/saml"
|
|
"github.com/teamhanko/hanko/backend/v2/flow_api/flow"
|
|
"github.com/teamhanko/hanko/backend/v2/flow_api/flow/shared"
|
|
"github.com/teamhanko/hanko/backend/v2/flow_api/services"
|
|
"github.com/teamhanko/hanko/backend/v2/flowpilot"
|
|
"github.com/teamhanko/hanko/backend/v2/mapper"
|
|
"github.com/teamhanko/hanko/backend/v2/persistence"
|
|
"github.com/teamhanko/hanko/backend/v2/session"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
type FlowPilotHandler struct {
|
|
Persister persistence.Persister
|
|
Cfg config.Config
|
|
PasscodeService services.Passcode
|
|
PasswordService services.Password
|
|
WebauthnService services.WebauthnService
|
|
SamlService saml.Service
|
|
SessionManager session.Manager
|
|
OTPRateLimiter limiter.Store
|
|
PasscodeRateLimiter limiter.Store
|
|
PasswordRateLimiter limiter.Store
|
|
TokenExchangeRateLimiter limiter.Store
|
|
AuthenticatorMetadata mapper.AuthenticatorMetadata
|
|
AuditLogger auditlog.Logger
|
|
}
|
|
|
|
func (h *FlowPilotHandler) RegistrationFlowHandler(c echo.Context) error {
|
|
registrationFlow := flow.NewRegistrationFlow(h.Cfg.Debug)
|
|
return h.executeFlow(c, registrationFlow)
|
|
}
|
|
|
|
func (h *FlowPilotHandler) LoginFlowHandler(c echo.Context) error {
|
|
loginFlow := flow.NewLoginFlow(h.Cfg.Debug)
|
|
return h.executeFlow(c, loginFlow)
|
|
}
|
|
|
|
func (h *FlowPilotHandler) ProfileFlowHandler(c echo.Context) error {
|
|
profileFlow := flow.NewProfileFlow(h.Cfg.Debug)
|
|
|
|
if err := h.validateSession(c); err != nil {
|
|
flowResult := profileFlow.ResultFromError(err)
|
|
return c.JSON(flowResult.GetStatus(), flowResult.GetResponse())
|
|
}
|
|
|
|
return h.executeFlow(c, profileFlow)
|
|
}
|
|
|
|
func (h *FlowPilotHandler) TokenExchangeFlowHandler(c echo.Context) error {
|
|
samlIdPInitiatedLoginFlow := flow.NewTokenExchangeFlow(h.Cfg.Debug)
|
|
return h.executeFlow(c, samlIdPInitiatedLoginFlow)
|
|
}
|
|
|
|
func (h *FlowPilotHandler) validateSession(c echo.Context) error {
|
|
lookup := fmt.Sprintf("header:Authorization:Bearer,cookie:%s", h.Cfg.Session.Cookie.GetName())
|
|
extractors, err := echojwt.CreateExtractors(lookup)
|
|
|
|
if err != nil {
|
|
return flowpilot.ErrorTechnical.Wrap(err)
|
|
}
|
|
|
|
var lastExtractorErr, lastTokenErr error
|
|
for _, extractor := range extractors {
|
|
auths, extractorErr := extractor(c)
|
|
if extractorErr != nil {
|
|
lastExtractorErr = extractorErr
|
|
continue
|
|
}
|
|
for _, auth := range auths {
|
|
token, tokenErr := h.SessionManager.Verify(auth)
|
|
if tokenErr != nil {
|
|
lastTokenErr = tokenErr
|
|
continue
|
|
}
|
|
|
|
// check that the session id is stored in the database
|
|
sessionId, ok := token.Get("session_id")
|
|
if !ok {
|
|
lastTokenErr = errors.New("no session id found in token")
|
|
continue
|
|
}
|
|
sessionID, err := uuid.FromString(sessionId.(string))
|
|
if err != nil {
|
|
lastTokenErr = errors.New("session id has wrong format")
|
|
continue
|
|
}
|
|
|
|
sessionModel, err := h.Persister.GetSessionPersister().Get(sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session from database: %w", err)
|
|
}
|
|
if sessionModel == nil {
|
|
lastTokenErr = fmt.Errorf("session id not found in database")
|
|
continue
|
|
}
|
|
|
|
// Update lastUsed field
|
|
sessionModel.LastUsed = time.Now().UTC()
|
|
err = h.Persister.GetSessionPersister().Update(*sessionModel)
|
|
if err != nil {
|
|
return dto.ToHttpError(err)
|
|
}
|
|
|
|
c.Set("session", token)
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if lastTokenErr != nil {
|
|
return shared.ErrorUnauthorized.Wrap(lastTokenErr)
|
|
} else if lastExtractorErr != nil {
|
|
return shared.ErrorUnauthorized.Wrap(lastExtractorErr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FlowPilotHandler) executeFlow(c echo.Context, flow flowpilot.Flow) error {
|
|
const queryParamKey = "action"
|
|
|
|
var err error
|
|
var inputData flowpilot.InputData
|
|
var flowResult flowpilot.FlowResult
|
|
|
|
txFunc := func(tx *pop.Connection) error {
|
|
deps := &shared.Dependencies{
|
|
Cfg: h.Cfg,
|
|
OTPRateLimiter: h.OTPRateLimiter,
|
|
PasscodeRateLimiter: h.PasscodeRateLimiter,
|
|
PasswordRateLimiter: h.PasswordRateLimiter,
|
|
TokenExchangeRateLimiter: h.TokenExchangeRateLimiter,
|
|
Tx: tx,
|
|
Persister: h.Persister,
|
|
HttpContext: c,
|
|
SessionManager: h.SessionManager,
|
|
PasscodeService: h.PasscodeService,
|
|
PasswordService: h.PasswordService,
|
|
WebauthnService: h.WebauthnService,
|
|
SamlService: h.SamlService,
|
|
AuthenticatorMetadata: h.AuthenticatorMetadata,
|
|
AuditLogger: h.AuditLogger,
|
|
}
|
|
|
|
flow.Set("deps", deps)
|
|
|
|
flowResult, err = flow.Execute(persistence.NewFlowPersister(tx),
|
|
flowpilot.WithQueryParamKey(queryParamKey),
|
|
flowpilot.WithQueryParamValue(c.QueryParam(queryParamKey)),
|
|
flowpilot.WithInputData(inputData),
|
|
flowpilot.UseCompression(!h.Cfg.Debug))
|
|
|
|
return err
|
|
}
|
|
|
|
err = c.Bind(&inputData)
|
|
if err != nil {
|
|
flowResult = flow.ResultFromError(flowpilot.ErrorTechnical.Wrap(err))
|
|
} else {
|
|
err = h.Persister.Transaction(txFunc)
|
|
if err != nil {
|
|
flowResult = flow.ResultFromError(err)
|
|
}
|
|
}
|
|
|
|
log := zeroLogger.Info().
|
|
Str("time_unix", strconv.FormatInt(time.Now().Unix(), 10)).
|
|
Str("id", c.Response().Header().Get(echo.HeaderXRequestID)).
|
|
Str("remote_ip", c.RealIP()).Str("host", c.Request().Host).
|
|
Str("method", c.Request().Method).Str("uri", c.Request().RequestURI).
|
|
Str("user_agent", c.Request().UserAgent()).Int("status", flowResult.GetStatus()).
|
|
Str("referer", c.Request().Referer())
|
|
if flowResult.GetResponse().Error != nil {
|
|
log.Str("error", fmt.Sprintf("%s", flowResult.GetResponse().Error.Code))
|
|
if flowResult.GetResponse().Error.Internal != nil {
|
|
log.Str("error_internal", *flowResult.GetResponse().Error.Internal)
|
|
}
|
|
}
|
|
log.Send()
|
|
|
|
return c.JSON(flowResult.GetStatus(), flowResult.GetResponse())
|
|
}
|
|
|
|
func init() {
|
|
zerolog.TimeFieldFormat = time.RFC3339Nano
|
|
}
|