mirror of
https://github.com/teamhanko/hanko.git
synced 2025-10-30 16:16:05 +08:00
419 lines
14 KiB
Go
419 lines
14 KiB
Go
package saml
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gobuffalo/pop/v6"
|
|
"github.com/labstack/echo/v4"
|
|
saml2 "github.com/russellhaering/gosaml2"
|
|
auditlog "github.com/teamhanko/hanko/backend/audit_log"
|
|
"github.com/teamhanko/hanko/backend/ee/saml/dto"
|
|
"github.com/teamhanko/hanko/backend/ee/saml/provider"
|
|
samlUtils "github.com/teamhanko/hanko/backend/ee/saml/utils"
|
|
"github.com/teamhanko/hanko/backend/persistence/models"
|
|
"github.com/teamhanko/hanko/backend/session"
|
|
"github.com/teamhanko/hanko/backend/thirdparty"
|
|
"github.com/teamhanko/hanko/backend/utils"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type Handler struct {
|
|
auditLogger auditlog.Logger
|
|
sessionManager session.Manager
|
|
samlService Service
|
|
}
|
|
|
|
func NewSamlHandler(sessionManager session.Manager, auditLogger auditlog.Logger, samlService Service) *Handler {
|
|
return &Handler{
|
|
auditLogger: auditLogger,
|
|
sessionManager: sessionManager,
|
|
samlService: samlService,
|
|
}
|
|
}
|
|
|
|
func (handler *Handler) Metadata(c echo.Context) error {
|
|
var request dto.SamlMetadataRequest
|
|
err := c.Bind(&request)
|
|
if err != nil {
|
|
return c.JSON(http.StatusBadRequest, thirdparty.ErrorInvalidRequest("domain is missing"))
|
|
}
|
|
|
|
foundProvider, err := handler.samlService.GetProviderByDomain(request.Domain)
|
|
if err != nil {
|
|
return c.NoContent(http.StatusNotFound)
|
|
}
|
|
|
|
if request.CertOnly {
|
|
cert, err := handler.samlService.Persister().GetSamlCertificatePersister().GetFirst()
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, thirdparty.ErrorServer("unable to provide metadata").WithCause(err))
|
|
}
|
|
|
|
if cert == nil {
|
|
return c.NoContent(http.StatusNotFound)
|
|
}
|
|
|
|
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%s-service-provider.pem", handler.samlService.Config().Service.Name))
|
|
return c.Blob(http.StatusOK, echo.MIMEOctetStream, []byte(cert.CertData))
|
|
}
|
|
|
|
xmlMetadata, err := foundProvider.ProvideMetadataAsXml()
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, thirdparty.ErrorServer("unable to provide metadata").WithCause(err))
|
|
}
|
|
|
|
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%s-metadata.xml", handler.samlService.Config().Service.Name))
|
|
return c.Blob(http.StatusOK, echo.MIMEOctetStream, xmlMetadata)
|
|
}
|
|
|
|
func (handler *Handler) Auth(c echo.Context) error {
|
|
errorRedirectTo := c.Request().Header.Get("Referer")
|
|
if errorRedirectTo == "" {
|
|
errorRedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl
|
|
}
|
|
|
|
var request dto.SamlAuthRequest
|
|
err := c.Bind(&request)
|
|
if err != nil {
|
|
return handler.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo)
|
|
}
|
|
|
|
err = c.Validate(request)
|
|
if err != nil {
|
|
return handler.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo)
|
|
}
|
|
|
|
foundProvider, err := handler.samlService.GetProviderByDomain(request.Domain)
|
|
if err != nil {
|
|
return handler.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo)
|
|
}
|
|
|
|
redirectUrl, err := handler.samlService.GetAuthUrl(foundProvider, request.RedirectTo, false)
|
|
if err != nil {
|
|
return handler.redirectError(c, thirdparty.ErrorServer("could not generate auth url").WithCause(err), errorRedirectTo)
|
|
}
|
|
|
|
return c.Redirect(http.StatusTemporaryRedirect, redirectUrl)
|
|
}
|
|
|
|
func (handler *Handler) callbackPostIdPInitiated(c echo.Context, samlResponse string) error {
|
|
// ignore URL parse error because config validation already ensures it is a parseable URL
|
|
redirectTo, _ := url.Parse(handler.samlService.Config().Saml.DefaultRedirectUrl)
|
|
|
|
// We need to already parse the response to be able to extract information (a response's ID, Issuer, InResponseTo
|
|
// nodes/values) to ensure protection against replaying IDP initiated responses as well as using service provider
|
|
// issued responses as IDP initiated responses, even though we later also use the gosaml2 library to parse (and then
|
|
// also validate) the response _again_. The reason is that the gosaml2 library does not make this information
|
|
// easily/publicly accessible through its API.
|
|
parsedSamlResponseDocument, _, err := samlUtils.ParseSamlResponse(samlResponse)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("could not parse saml response").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
responseElement := parsedSamlResponseDocument.FindElement("/Response")
|
|
if responseElement == nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("invalid saml response: no response node present"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
issuerElement := parsedSamlResponseDocument.FindElement("/Response/Issuer")
|
|
if issuerElement == nil || issuerElement.Text() == "" {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("invalid saml response: no issuer node present"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
issuer := issuerElement.Text()
|
|
|
|
serviceProvider, err := handler.samlService.GetProviderByIssuer(issuer)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest(
|
|
fmt.Sprintf("could not get provider for issuer %s", issuer)).
|
|
WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
// We need to check whether this is an unsolicited request, otherwise SP initiated responses could
|
|
// be used as IDP initiated responses.
|
|
if responseElement.SelectAttr("InResponseTo") != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("saml request is not unsolicited"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
assertionInfo, err := handler.getAssertionInfo(serviceProvider, samlResponse)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("could not get assertion info").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
samlResponseIDAttr := responseElement.SelectAttr("ID")
|
|
if samlResponseIDAttr == nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("invalid saml response: no ID for response present"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
samlResponseID := samlResponseIDAttr.Value
|
|
|
|
samlIDPInitiatedRequestPersister := handler.samlService.Persister().GetSamlIDPInitiatedRequestPersister()
|
|
|
|
// We use the SAML response's ID to prevent replay attacks by persisting every IDP initiated request and
|
|
// checking whether an IDP initiated request already exists for this request.
|
|
existingSamlIDPInitiatedRequest, err := samlIDPInitiatedRequestPersister.GetByResponseIDAndIssuer(samlResponseID, issuer)
|
|
if existingSamlIDPInitiatedRequest != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest("attempting to replay unsolicited saml request"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
// We assume only one assertion, and we assume it is present because we already validated it using the gosaml2
|
|
// library (which also consumes only one/the first assertion). We also assume assertion conditions are present
|
|
// because validation assures it is not nil (or else it returns an error).
|
|
expiresAtString := assertionInfo.Assertions[0].Conditions.NotOnOrAfter
|
|
|
|
expiresAt, err := time.Parse(time.RFC3339, expiresAtString)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("could not parse saml assertion conditions' NotOnOrAfter value").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
// If no request exists we create a new IDP initiated request model and persist it.
|
|
samlIDPInitiatedRequest, err := models.NewSamlIDPInitiatedRequest(samlResponseID, issuer, expiresAt)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("could not instantiate saml idp initiated request model").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
err = samlIDPInitiatedRequestPersister.Create(*samlIDPInitiatedRequest)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("could not persist saml idp initiated request"),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
redirectUrl, samlError := handler.linkAccount(c, redirectTo, true, serviceProvider, assertionInfo)
|
|
if samlError != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
samlError,
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
// Add hint to the redirect URL that this is an IDP initiated request so that a token exchange can
|
|
// eventually be performed through the dedicated flow API handler.
|
|
values := redirectUrl.Query()
|
|
values.Add("saml_hint", "idp_initiated")
|
|
redirectUrl.RawQuery = values.Encode()
|
|
|
|
return c.Redirect(http.StatusFound, redirectUrl.String())
|
|
}
|
|
|
|
func (handler *Handler) CallbackPost(c echo.Context) error {
|
|
relayState := c.FormValue("RelayState")
|
|
samlResponse := c.FormValue("SAMLResponse")
|
|
|
|
if handler.isIDPInitiated(relayState) {
|
|
return handler.callbackPostIdPInitiated(c, samlResponse)
|
|
} else {
|
|
state, err := VerifyState(
|
|
handler.samlService.Config(),
|
|
handler.samlService.Persister().GetSamlStatePersister(),
|
|
strings.TrimPrefix(relayState, statePrefixServiceProviderInitiated),
|
|
)
|
|
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err),
|
|
handler.samlService.Config().Saml.DefaultRedirectUrl,
|
|
)
|
|
}
|
|
|
|
if strings.TrimSpace(state.RedirectTo) == "" {
|
|
state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl
|
|
}
|
|
|
|
redirectTo, err := url.Parse(state.RedirectTo)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("unable to parse redirect url").WithCause(err),
|
|
handler.samlService.Config().Saml.DefaultRedirectUrl,
|
|
)
|
|
}
|
|
|
|
foundProvider, err := handler.samlService.GetProviderByDomain(state.Provider)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("unable to find provider by domain").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
assertionInfo, err := handler.getAssertionInfo(foundProvider, samlResponse)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
thirdparty.ErrorServer("unable to parse saml response").WithCause(err),
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
redirectUrl, err := handler.linkAccount(c, redirectTo, state.IsFlow, foundProvider, assertionInfo)
|
|
if err != nil {
|
|
return handler.redirectError(
|
|
c,
|
|
err,
|
|
redirectTo.String(),
|
|
)
|
|
}
|
|
|
|
return c.Redirect(http.StatusFound, redirectUrl.String())
|
|
}
|
|
}
|
|
|
|
func (handler *Handler) isIDPInitiated(relayState string) bool {
|
|
return !strings.HasPrefix(relayState, statePrefixServiceProviderInitiated)
|
|
}
|
|
|
|
func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, isFlow bool, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) {
|
|
var accountLinkingResult *thirdparty.AccountLinkingResult
|
|
var err error
|
|
err = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error {
|
|
userdata := provider.GetUserData(assertionInfo)
|
|
identityProviderIssuer := assertionInfo.Assertions[0].Issuer
|
|
samlDomain := provider.GetDomain()
|
|
linkResult, errTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, isFlow)
|
|
if errTx != nil {
|
|
return errTx
|
|
}
|
|
|
|
accountLinkingResult = linkResult
|
|
|
|
emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email)
|
|
identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject)
|
|
|
|
token, errTx := models.NewToken(
|
|
linkResult.User.ID,
|
|
models.TokenWithIdentityID(identityModel.ID),
|
|
models.TokenForFlowAPI(isFlow),
|
|
models.TokenUserCreated(linkResult.UserCreated))
|
|
if errTx != nil {
|
|
return thirdparty.ErrorServer("could not create token").WithCause(errTx)
|
|
}
|
|
|
|
errTx = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token)
|
|
if errTx != nil {
|
|
return thirdparty.ErrorServer("could not save token to db").WithCause(errTx)
|
|
}
|
|
|
|
query := redirectTo.Query()
|
|
query.Add(utils.HankoTokenQuery, token.Value)
|
|
redirectTo.RawQuery = query.Encode()
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return redirectTo, nil
|
|
}
|
|
|
|
func (handler *Handler) getAssertionInfo(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) {
|
|
assertionInfo, err := provider.GetService().RetrieveAssertionInfo(samlResponse)
|
|
if err != nil {
|
|
return nil, thirdparty.ErrorServer("unable to parse SAML response").WithCause(err)
|
|
}
|
|
|
|
if assertionInfo.WarningInfo.InvalidTime {
|
|
return nil, thirdparty.ErrorServer("SAMLAssertion expired")
|
|
}
|
|
|
|
if assertionInfo.WarningInfo.NotInAudience {
|
|
return nil, thirdparty.ErrorServer("not in SAML audience")
|
|
}
|
|
|
|
return assertionInfo, nil
|
|
}
|
|
|
|
func (handler *Handler) redirectError(c echo.Context, error error, to string) error {
|
|
c.Logger().Error(error)
|
|
|
|
err := handler.auditError(c, error)
|
|
if err != nil {
|
|
error = err
|
|
}
|
|
|
|
redirectURL := thirdparty.GetErrorUrl(to, error)
|
|
return c.Redirect(http.StatusSeeOther, redirectURL)
|
|
}
|
|
|
|
func (handler *Handler) auditError(c echo.Context, err error) error {
|
|
var e *thirdparty.ThirdPartyError
|
|
ok := errors.As(err, &e)
|
|
|
|
var auditLogError error
|
|
if ok && e.Code != thirdparty.ErrorCodeServerError {
|
|
auditLogError = handler.auditLogger.Create(c, models.AuditLogThirdPartySignInSignUpFailed, nil, err)
|
|
}
|
|
return auditLogError
|
|
}
|
|
|
|
func (handler *Handler) GetProvider(c echo.Context) error {
|
|
var request dto.SamlRequest
|
|
err := c.Bind(&request)
|
|
if err != nil {
|
|
return c.JSON(http.StatusBadRequest, err)
|
|
}
|
|
|
|
foundProvider, err := handler.samlService.GetProviderByDomain(request.Domain)
|
|
if err != nil {
|
|
return c.NoContent(http.StatusNotFound)
|
|
}
|
|
|
|
return c.JSON(http.StatusOK, foundProvider.GetConfig())
|
|
}
|