mirror of
				https://github.com/teamhanko/hanko.git
				synced 2025-11-01 00:58:16 +08:00 
			
		
		
		
	feat(ee): saml idp initiated sso
This commit is contained in:
		 Lennart Fleischmann
					Lennart Fleischmann
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							55d6efb879
						
					
				
				
					commit
					983000d94e
				
			| @ -9,6 +9,7 @@ import ( | ||||
| 	auditlog "github.com/teamhanko/hanko/backend/audit_log" | ||||
| 	"github.com/teamhanko/hanko/backend/ee/saml/dto" | ||||
| 	"github.com/teamhanko/hanko/backend/ee/saml/provider" | ||||
| 	samlUtils "github.com/teamhanko/hanko/backend/ee/saml/utils" | ||||
| 	"github.com/teamhanko/hanko/backend/persistence/models" | ||||
| 	"github.com/teamhanko/hanko/backend/session" | ||||
| 	"github.com/teamhanko/hanko/backend/thirdparty" | ||||
| @ -16,6 +17,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type Handler struct { | ||||
| @ -97,48 +99,132 @@ func (handler *Handler) Auth(c echo.Context) error { | ||||
| 	return c.Redirect(http.StatusTemporaryRedirect, redirectUrl) | ||||
| } | ||||
|  | ||||
| func (handler *Handler) CallbackPost(c echo.Context) error { | ||||
| 	state, samlError := VerifyState(handler.samlService.Config(), handler.samlService.Persister().GetSamlStatePersister(), c.FormValue("RelayState")) | ||||
| 	if samlError != nil { | ||||
| func (handler *Handler) callbackPostIdPInitiated(c echo.Context, samlResponse string) error { | ||||
| 	// ignore URL parse error because config validation already ensures it is a parseable URL | ||||
| 	redirectTo, _ := url.Parse(handler.samlService.Config().Saml.DefaultRedirectUrl) | ||||
|  | ||||
| 	// We need to already parse the response to be able to extract information (a response's ID, Issuer, InResponseTo | ||||
| 	// nodes/values) to ensure protection against replaying IDP initiated responses as well as using service provider | ||||
| 	// issued responses as IDP initiated responses, even though we later also use the gosaml2 library to parse (and then | ||||
| 	// also validate) the response _again_. The reason is that the gosaml2 library does not make this information | ||||
| 	// easily/publicly accessible through its API. | ||||
| 	parsedSamlResponseDocument, _, err := samlUtils.ParseSamlResponse(samlResponse) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest(samlError.Error()).WithCause(samlError), | ||||
| 			handler.samlService.Config().Saml.DefaultRedirectUrl, | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	if strings.TrimSpace(state.RedirectTo) == "" { | ||||
| 		state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl | ||||
| 	} | ||||
|  | ||||
| 	redirectTo, samlError := url.Parse(state.RedirectTo) | ||||
| 	if samlError != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("unable to parse redirect url").WithCause(samlError), | ||||
| 			handler.samlService.Config().Saml.DefaultRedirectUrl, | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	foundProvider, samlError := handler.samlService.GetProviderByDomain(state.Provider) | ||||
| 	if samlError != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("unable to find provider by domain").WithCause(samlError), | ||||
| 			thirdparty.ErrorInvalidRequest("could not parse saml response").WithCause(err), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	assertionInfo, samlError := handler.parseSamlResponse(foundProvider, c.FormValue("SAMLResponse")) | ||||
| 	if samlError != nil { | ||||
| 	responseElement := parsedSamlResponseDocument.FindElement("/Response") | ||||
| 	if responseElement == nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("unable to parse saml response").WithCause(samlError), | ||||
| 			thirdparty.ErrorInvalidRequest("invalid saml response: no response node present"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	redirectUrl, samlError := handler.linkAccount(c, redirectTo, state, foundProvider, assertionInfo) | ||||
| 	issuerElement := parsedSamlResponseDocument.FindElement("/Response/Issuer") | ||||
| 	if issuerElement == nil || issuerElement.Text() == "" { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest("invalid saml response: no issuer node present"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	issuer := issuerElement.Text() | ||||
|  | ||||
| 	serviceProvider, err := handler.samlService.GetProviderByIssuer(issuer) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest( | ||||
| 				fmt.Sprintf("could not get provider for issuer %s", issuer)). | ||||
| 				WithCause(err), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	// We need to check whether this is an unsolicited request, otherwise SP initiated responses could | ||||
| 	// be used as IDP initiated responses. | ||||
| 	if responseElement.SelectAttr("InResponseTo") != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest("saml request is not unsolicited"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	assertionInfo, err := handler.getAssertionInfo(serviceProvider, samlResponse) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest("could not get assertion info").WithCause(err), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	samlResponseIDAttr := responseElement.SelectAttr("ID") | ||||
| 	if samlResponseIDAttr == nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest("invalid saml response: no ID for response present"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	samlResponseID := samlResponseIDAttr.Value | ||||
|  | ||||
| 	samlIDPInitiatedRequestPersister := handler.samlService.Persister().GetSamlIDPInitiatedRequestPersister() | ||||
|  | ||||
| 	// We use the SAML response's ID to prevent replay attacks by persisting every IDP initiated request and | ||||
| 	// checking whether an IDP initiated request already exists for this request. | ||||
| 	existingSamlIDPInitiatedRequest, err := samlIDPInitiatedRequestPersister.GetByResponseIDAndIssuer(samlResponseID, issuer) | ||||
| 	if existingSamlIDPInitiatedRequest != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorInvalidRequest("attempting to replay unsolicited saml request"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	// We assume only one assertion, and we assume it is present because we already validated it using the gosaml2 | ||||
| 	// library (which also consumes only one/the first assertion). We also assume assertion conditions are present | ||||
| 	// because validation assures it is not nil (or else it returns an error). | ||||
| 	expiresAtString := assertionInfo.Assertions[0].Conditions.NotOnOrAfter | ||||
|  | ||||
| 	expiresAt, err := time.Parse(time.RFC3339, expiresAtString) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("could not parse saml assertion conditions' NotOnOrAfter value").WithCause(err), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	// If no request exists we create a new IDP initiated request model and persist it. | ||||
| 	samlIDPInitiatedRequest, err := models.NewSamlIDPInitiatedRequest(samlResponseID, issuer, expiresAt) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("could not instantiate saml idp initiated request model").WithCause(err), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	err = samlIDPInitiatedRequestPersister.Create(*samlIDPInitiatedRequest) | ||||
| 	if err != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| 			thirdparty.ErrorServer("could not persist saml idp initiated request"), | ||||
| 			redirectTo.String(), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	redirectUrl, samlError := handler.linkAccount(c, redirectTo, true, serviceProvider, assertionInfo) | ||||
| 	if samlError != nil { | ||||
| 		return handler.redirectError( | ||||
| 			c, | ||||
| @ -147,19 +233,94 @@ func (handler *Handler) CallbackPost(c echo.Context) error { | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	// Add hint to the redirect URL that this is an IDP initiated request so that a token exchange can | ||||
| 	// eventually be performed through the dedicated flow API handler. | ||||
| 	values := redirectUrl.Query() | ||||
| 	values.Add("saml_hint", "idp_initiated") | ||||
| 	redirectUrl.RawQuery = values.Encode() | ||||
|  | ||||
| 	return c.Redirect(http.StatusFound, redirectUrl.String()) | ||||
| } | ||||
|  | ||||
| func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state *State, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) { | ||||
| func (handler *Handler) CallbackPost(c echo.Context) error { | ||||
| 	relayState := c.FormValue("RelayState") | ||||
| 	samlResponse := c.FormValue("SAMLResponse") | ||||
|  | ||||
| 	if handler.isIDPInitiated(relayState) { | ||||
| 		return handler.callbackPostIdPInitiated(c, samlResponse) | ||||
| 	} else { | ||||
| 		state, err := VerifyState( | ||||
| 			handler.samlService.Config(), | ||||
| 			handler.samlService.Persister().GetSamlStatePersister(), | ||||
| 			strings.TrimPrefix(relayState, statePrefixServiceProviderInitiated), | ||||
| 		) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return handler.redirectError( | ||||
| 				c, | ||||
| 				thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), | ||||
| 				handler.samlService.Config().Saml.DefaultRedirectUrl, | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		if strings.TrimSpace(state.RedirectTo) == "" { | ||||
| 			state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl | ||||
| 		} | ||||
|  | ||||
| 		redirectTo, err := url.Parse(state.RedirectTo) | ||||
| 		if err != nil { | ||||
| 			return handler.redirectError( | ||||
| 				c, | ||||
| 				thirdparty.ErrorServer("unable to parse redirect url").WithCause(err), | ||||
| 				handler.samlService.Config().Saml.DefaultRedirectUrl, | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		foundProvider, err := handler.samlService.GetProviderByDomain(state.Provider) | ||||
| 		if err != nil { | ||||
| 			return handler.redirectError( | ||||
| 				c, | ||||
| 				thirdparty.ErrorServer("unable to find provider by domain").WithCause(err), | ||||
| 				redirectTo.String(), | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		assertionInfo, err := handler.getAssertionInfo(foundProvider, samlResponse) | ||||
| 		if err != nil { | ||||
| 			return handler.redirectError( | ||||
| 				c, | ||||
| 				thirdparty.ErrorServer("unable to parse saml response").WithCause(err), | ||||
| 				redirectTo.String(), | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		redirectUrl, err := handler.linkAccount(c, redirectTo, state.IsFlow, foundProvider, assertionInfo) | ||||
| 		if err != nil { | ||||
| 			return handler.redirectError( | ||||
| 				c, | ||||
| 				err, | ||||
| 				redirectTo.String(), | ||||
| 			) | ||||
| 		} | ||||
|  | ||||
| 		return c.Redirect(http.StatusFound, redirectUrl.String()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (handler *Handler) isIDPInitiated(relayState string) bool { | ||||
| 	return !strings.HasPrefix(relayState, statePrefixServiceProviderInitiated) | ||||
| } | ||||
|  | ||||
| func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, isFlow bool, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) { | ||||
| 	var accountLinkingResult *thirdparty.AccountLinkingResult | ||||
| 	var samlError error | ||||
| 	samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { | ||||
| 	var err error | ||||
| 	err = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { | ||||
| 		userdata := provider.GetUserData(assertionInfo) | ||||
| 		identityProviderIssuer := assertionInfo.Assertions[0].Issuer | ||||
| 		samlDomain := provider.GetDomain() | ||||
| 		linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow) | ||||
| 		if samlErrorTx != nil { | ||||
| 			return samlErrorTx | ||||
| 		linkResult, errTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, isFlow) | ||||
| 		if errTx != nil { | ||||
| 			return errTx | ||||
| 		} | ||||
|  | ||||
| 		accountLinkingResult = linkResult | ||||
| @ -167,18 +328,18 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * | ||||
| 		emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email) | ||||
| 		identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject) | ||||
|  | ||||
| 		token, tokenError := models.NewToken( | ||||
| 		token, errTx := models.NewToken( | ||||
| 			linkResult.User.ID, | ||||
| 			models.TokenWithIdentityID(identityModel.ID), | ||||
| 			models.TokenForFlowAPI(state.IsFlow), | ||||
| 			models.TokenForFlowAPI(isFlow), | ||||
| 			models.TokenUserCreated(linkResult.UserCreated)) | ||||
| 		if tokenError != nil { | ||||
| 			return thirdparty.ErrorServer("could not create token").WithCause(tokenError) | ||||
| 		if errTx != nil { | ||||
| 			return thirdparty.ErrorServer("could not create token").WithCause(errTx) | ||||
| 		} | ||||
|  | ||||
| 		tokenError = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token) | ||||
| 		if tokenError != nil { | ||||
| 			return thirdparty.ErrorServer("could not save token to db").WithCause(tokenError) | ||||
| 		errTx = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token) | ||||
| 		if errTx != nil { | ||||
| 			return thirdparty.ErrorServer("could not save token to db").WithCause(errTx) | ||||
| 		} | ||||
|  | ||||
| 		query := redirectTo.Query() | ||||
| @ -188,20 +349,20 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	if samlError != nil { | ||||
| 		return nil, samlError | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	samlError = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil) | ||||
| 	err = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil) | ||||
|  | ||||
| 	if samlError != nil { | ||||
| 		return nil, samlError | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return redirectTo, nil | ||||
| } | ||||
|  | ||||
| func (handler *Handler) parseSamlResponse(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) { | ||||
| func (handler *Handler) getAssertionInfo(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) { | ||||
| 	assertionInfo, err := provider.GetService().RetrieveAssertionInfo(samlResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, thirdparty.ErrorServer("unable to parse SAML response").WithCause(err) | ||||
|  | ||||
| @ -15,6 +15,7 @@ type Service interface { | ||||
| 	Persister() persistence.Persister | ||||
| 	Providers() []provider.ServiceProvider | ||||
| 	GetProviderByDomain(domain string) (provider.ServiceProvider, error) | ||||
| 	GetProviderByIssuer(issuer string) (provider.ServiceProvider, error) | ||||
| 	GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error) | ||||
| } | ||||
|  | ||||
| @ -83,6 +84,16 @@ func (s *defaultService) GetProviderByDomain(domain string) (provider.ServicePro | ||||
| 	return nil, fmt.Errorf("unknown provider for domain %s", domain) | ||||
| } | ||||
|  | ||||
| func (s *defaultService) GetProviderByIssuer(issuer string) (provider.ServiceProvider, error) { | ||||
| 	for _, availableProvider := range s.providers { | ||||
| 		if availableProvider.GetService().IdentityProviderIssuer == issuer { | ||||
| 			return availableProvider, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("unknown provider for issuer %s", issuer) | ||||
| } | ||||
|  | ||||
| func (s *defaultService) GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error) { | ||||
| 	if ok := samlUtils.IsAllowedRedirect(s.config.Saml, redirectTo); !ok { | ||||
| 		return "", thirdparty.ErrorInvalidRequest(fmt.Sprintf("redirect to '%s' not allowed", redirectTo)) | ||||
|  | ||||
| @ -22,6 +22,8 @@ type State struct { | ||||
| 	IsFlow     bool      `json:"is_flow"` | ||||
| } | ||||
|  | ||||
| const statePrefixServiceProviderInitiated = "hanko_spi_" | ||||
|  | ||||
| func GenerateStateForFlowAPI(isFlow bool) func(*State) { | ||||
| 	return func(state *State) { | ||||
| 		state.IsFlow = isFlow | ||||
| @ -77,7 +79,9 @@ func GenerateState(config *config.Config, persister persistence.SamlStatePersist | ||||
| 		return nil, fmt.Errorf("could not save state to db: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	return []byte(encryptedState), nil | ||||
| 	// Add prefix to distinguish between SP initiated and IDP initiated requests in callback handler. | ||||
| 	result := fmt.Sprintf("%s%s", statePrefixServiceProviderInitiated, encryptedState) | ||||
| 	return []byte(result), nil | ||||
| } | ||||
|  | ||||
| func VerifyState(config *config.Config, persister persistence.SamlStatePersister, state string) (*State, error) { | ||||
|  | ||||
							
								
								
									
										76
									
								
								backend/ee/saml/utils/response.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								backend/ee/saml/utils/response.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,76 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"compress/flate" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/beevik/etree" | ||||
| 	rtvalidator "github.com/mattermost/xml-roundtrip-validator" | ||||
| 	"io" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	defaultMaxDecompressedResponseSize = 5 * 1024 * 1024 | ||||
| ) | ||||
|  | ||||
| func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error { | ||||
| 	err := decoder(data) | ||||
| 	if err == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Default to 5MB max size | ||||
| 	if maxSize == 0 { | ||||
| 		maxSize = defaultMaxDecompressedResponseSize | ||||
| 	} | ||||
|  | ||||
| 	lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1) | ||||
|  | ||||
| 	deflated, err := io.ReadAll(lr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if int64(len(deflated)) > maxSize { | ||||
| 		return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize) | ||||
| 	} | ||||
|  | ||||
| 	return decoder(deflated) | ||||
| } | ||||
|  | ||||
| func ParseSamlResponse(samlResponse string) (*etree.Document, *etree.Element, error) { | ||||
| 	raw, err := base64.StdEncoding.DecodeString(samlResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, fmt.Errorf("could not decode saml response: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	return parseResponseXml(raw) | ||||
| } | ||||
|  | ||||
| func parseResponseXml(xml []byte) (*etree.Document, *etree.Element, error) { | ||||
| 	var doc *etree.Document | ||||
| 	var rawXML []byte | ||||
|  | ||||
| 	err := maybeDeflate(xml, defaultMaxDecompressedResponseSize, func(xml []byte) error { | ||||
| 		doc = etree.NewDocument() | ||||
| 		rawXML = xml | ||||
| 		return doc.ReadFromBytes(xml) | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	el := doc.Root() | ||||
| 	if el == nil { | ||||
| 		return nil, nil, fmt.Errorf("unable to parse response") | ||||
| 	} | ||||
|  | ||||
| 	// Examine the response for attempts to exploit weaknesses in Go's encoding/xml | ||||
| 	err = rtvalidator.Validate(bytes.NewReader(rawXML)) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	return doc, el, nil | ||||
| } | ||||
| @ -215,3 +215,22 @@ func NewProfileFlow(debug bool) flowpilot.Flow { | ||||
| 		Debug(debug). | ||||
| 		MustBuild() | ||||
| } | ||||
|  | ||||
| func NewTokenExchangeFlow(debug bool) flowpilot.Flow { | ||||
| 	return flowpilot.NewFlow("token_exchange"). | ||||
| 		State(shared.StateThirdParty, | ||||
| 			shared.ExchangeToken{}). | ||||
| 		State(shared.StateSuccess). | ||||
| 		BeforeState(shared.StateSuccess, | ||||
| 			shared.IssueSession{}, | ||||
| 			shared.GetUserData{}). | ||||
| 		SubFlows( | ||||
| 			CredentialUsageSubFlow, | ||||
| 			UserDetailsSubFlow). | ||||
| 		AfterState(shared.StatePasscodeConfirmation, | ||||
| 			shared.EmailPersistVerifiedStatus{}). | ||||
| 		InitialState(shared.StateThirdParty). | ||||
| 		ErrorState(shared.StateError). | ||||
| 		Debug(debug). | ||||
| 		MustBuild() | ||||
| } | ||||
|  | ||||
| @ -81,16 +81,30 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { | ||||
| 		return fmt.Errorf("failed to delete token from db: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	onboardingStates, err := a.determineOnboardingStates(c, identity, tokenModel.UserCreated) | ||||
| 	isSaml := identity.SamlIdentity != nil | ||||
|  | ||||
| 	var onboardingStates []flowpilot.StateName | ||||
| 	if isSaml { | ||||
| 		samlProvider, err := deps.SamlService.GetProviderByIssuer(identity.ProviderID) | ||||
| 		if err != nil { | ||||
| 		return fmt.Errorf("failed to determine onboarding stattes: %w", err) | ||||
| 			return fmt.Errorf("could not fetch saml provider for identity: %w", err) | ||||
| 		} | ||||
| 		mustDoEmailVerification := !samlProvider.GetConfig().SkipEmailVerification && identity.Email != nil && !identity.Email.Verified | ||||
| 		onboardingStates, err = a.determineOnboardingStates(c, identity, tokenModel.UserCreated, mustDoEmailVerification) | ||||
| 	} else { | ||||
| 		mustDoEmailVerification := deps.Cfg.Email.RequireVerification && identity.Email != nil && !identity.Email.Verified | ||||
| 		onboardingStates, err = a.determineOnboardingStates(c, identity, tokenModel.UserCreated, mustDoEmailVerification) | ||||
| 	} | ||||
|  | ||||
| 	if err := c.Stash().Set(StashPathLoginMethod, "third_party"); err != nil { | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to determine onboarding states: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err = c.Stash().Set(StashPathLoginMethod, "third_party"); err != nil { | ||||
| 		return fmt.Errorf("failed to set login_method to the stash: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { | ||||
| 	if err = c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { | ||||
| 		return fmt.Errorf("failed to set third_party_provider to the stash: %w", err) | ||||
| 	} | ||||
|  | ||||
| @ -99,11 +113,11 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { | ||||
| 	return c.Continue(onboardingStates...) | ||||
| } | ||||
|  | ||||
| func (a ExchangeToken) determineOnboardingStates(c flowpilot.ExecutionContext, identity *models.Identity, userCreated bool) ([]flowpilot.StateName, error) { | ||||
| func (a ExchangeToken) determineOnboardingStates(c flowpilot.ExecutionContext, identity *models.Identity, userCreated bool, mustDoEmailVerification bool) ([]flowpilot.StateName, error) { | ||||
| 	deps := a.GetDeps(c) | ||||
| 	result := make([]flowpilot.StateName, 0) | ||||
|  | ||||
| 	if deps.Cfg.Email.RequireVerification && identity.Email != nil && !identity.Email.Verified { | ||||
| 	if mustDoEmailVerification { | ||||
| 		if err := c.Stash().Set(StashPathEmail, identity.Email.Address); err != nil { | ||||
| 			return nil, fmt.Errorf("failed to stash email: %w", err) | ||||
| 		} | ||||
|  | ||||
| @ -62,6 +62,11 @@ func (h *FlowPilotHandler) ProfileFlowHandler(c echo.Context) error { | ||||
| 	return h.executeFlow(c, profileFlow) | ||||
| } | ||||
|  | ||||
| func (h *FlowPilotHandler) TokenExchangeFlowHandler(c echo.Context) error { | ||||
| 	samlIdPInitiatedLoginFlow := flow.NewTokenExchangeFlow(h.Cfg.Debug) | ||||
| 	return h.executeFlow(c, samlIdPInitiatedLoginFlow) | ||||
| } | ||||
|  | ||||
| func (h *FlowPilotHandler) validateSession(c echo.Context) error { | ||||
| 	lookup := fmt.Sprintf("header:Authorization:Bearer,cookie:%s", h.Cfg.Session.Cookie.GetName()) | ||||
| 	extractors, err := echojwt.CreateExtractors(lookup) | ||||
|  | ||||
| @ -3,6 +3,7 @@ module github.com/teamhanko/hanko/backend | ||||
| go 1.20 | ||||
|  | ||||
| require ( | ||||
| 	github.com/beevik/etree v1.1.0 | ||||
| 	github.com/brianvoe/gofakeit/v6 v6.28.0 | ||||
| 	github.com/coreos/go-oidc/v3 v3.9.0 | ||||
| 	github.com/fatih/structs v1.1.0 | ||||
| @ -28,6 +29,7 @@ require ( | ||||
| 	github.com/labstack/gommon v0.4.2 | ||||
| 	github.com/lestrrat-go/jwx/v2 v2.1.0 | ||||
| 	github.com/lib/pq v1.10.9 | ||||
| 	github.com/mattermost/xml-roundtrip-validator v0.1.0 | ||||
| 	github.com/mileusna/useragent v1.3.5 | ||||
| 	github.com/mitchellh/mapstructure v1.5.0 | ||||
| 	github.com/nicksnyder/go-i18n/v2 v2.4.0 | ||||
| @ -63,7 +65,6 @@ require ( | ||||
| 	github.com/andybalholm/brotli v1.0.5 // indirect | ||||
| 	github.com/aymerick/douceur v0.2.0 // indirect | ||||
| 	github.com/bahlo/generic-list-go v0.2.0 // indirect | ||||
| 	github.com/beevik/etree v1.1.0 // indirect | ||||
| 	github.com/beorn7/perks v1.0.1 // indirect | ||||
| 	github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect | ||||
| 	github.com/buger/jsonparser v1.1.1 // indirect | ||||
| @ -123,7 +124,6 @@ require ( | ||||
| 	github.com/lestrrat-go/option v1.0.1 // indirect | ||||
| 	github.com/luna-duclos/instrumentedsql v1.1.3 // indirect | ||||
| 	github.com/mailru/easyjson v0.7.7 // indirect | ||||
| 	github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect | ||||
| 	github.com/mattn/go-colorable v0.1.13 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.20 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
|  | ||||
| @ -86,6 +86,10 @@ func NewPublicRouter(cfg *config.Config, persister persistence.Persister, promet | ||||
| 	e.POST("/login", flowAPIHandler.LoginFlowHandler, webhookMiddleware) | ||||
| 	e.POST("/profile", flowAPIHandler.ProfileFlowHandler, webhookMiddleware) | ||||
|  | ||||
| 	if cfg.Saml.Enabled { | ||||
| 		e.POST("/token_exchange", flowAPIHandler.TokenExchangeFlowHandler, webhookMiddleware) | ||||
| 	} | ||||
|  | ||||
| 	e.HideBanner = true | ||||
| 	g := e.Group("") | ||||
|  | ||||
|  | ||||
| @ -23,7 +23,7 @@ type identityPersister struct { | ||||
|  | ||||
| func (p identityPersister) GetByID(identityID uuid.UUID) (*models.Identity, error) { | ||||
| 	identity := &models.Identity{} | ||||
| 	if err := p.db.EagerPreload("Email", "Email.User", "Email.User.Username").Find(identity, identityID); err != nil { | ||||
| 	if err := p.db.EagerPreload("Email", "Email.User", "Email.User.Username", "SamlIdentity").Find(identity, identityID); err != nil { | ||||
| 		if errors.Is(err, sql.ErrNoRows) { | ||||
| 			return nil, nil | ||||
| 		} | ||||
|  | ||||
| @ -0,0 +1 @@ | ||||
| drop_table("saml_idp_initiated_requests") | ||||
| @ -0,0 +1,9 @@ | ||||
| create_table("saml_idp_initiated_requests") { | ||||
| 	t.Column("id", "uuid", {primary: true}) | ||||
| 	t.Column("response_id", "string", { "null": false }) | ||||
| 	t.Column("issuer", "string", { "null": false }) | ||||
| 	t.Column("expires_at", "timestamp", { "null": false }) | ||||
|     t.Column("created_at", "timestamp", { "null": false }) | ||||
|     t.DisableTimestamps() | ||||
|     t.Index(["response_id", "issuer"], {"unique": true}) | ||||
| } | ||||
							
								
								
									
										43
									
								
								backend/persistence/models/saml_idp_initiated_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								backend/persistence/models/saml_idp_initiated_request.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| package models | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gobuffalo/pop/v6" | ||||
| 	"github.com/gobuffalo/validate/v3" | ||||
| 	"github.com/gobuffalo/validate/v3/validators" | ||||
| 	"github.com/gofrs/uuid" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type SamlIDPInitiatedRequest struct { | ||||
| 	ID         uuid.UUID `db:"id"` | ||||
| 	ResponseID string    `db:"response_id"` | ||||
| 	Issuer     string    `db:"issuer"` | ||||
| 	ExpiresAt  time.Time `db:"expires_at"` | ||||
| 	CreatedAt  time.Time `db:"created_at"` | ||||
| } | ||||
|  | ||||
| func NewSamlIDPInitiatedRequest(responseID, issuer string, expiresAt time.Time) (*SamlIDPInitiatedRequest, error) { | ||||
| 	id, _ := uuid.NewV4() | ||||
|  | ||||
| 	return &SamlIDPInitiatedRequest{ | ||||
| 		ID:         id, | ||||
| 		ResponseID: responseID, | ||||
| 		Issuer:     issuer, | ||||
| 		ExpiresAt:  expiresAt, | ||||
| 		CreatedAt:  time.Now().UTC(), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (samlIDPInitiatedRequest SamlIDPInitiatedRequest) TableName() string { | ||||
| 	return "saml_idp_initiated_requests" | ||||
| } | ||||
|  | ||||
| func (r *SamlIDPInitiatedRequest) Validate(tx *pop.Connection) (*validate.Errors, error) { | ||||
| 	return validate.Validate( | ||||
| 		&validators.UUIDIsPresent{Name: "ID", Field: r.ID}, | ||||
| 		&validators.StringIsPresent{Name: "ResponseID", Field: r.ResponseID}, | ||||
| 		&validators.StringIsPresent{Name: "Issuer", Field: r.Issuer}, | ||||
| 		&validators.TimeIsPresent{Name: "ExpiresAt", Field: r.ExpiresAt}, | ||||
| 		&validators.TimeIsPresent{Name: "CreatedAt", Field: r.CreatedAt}, | ||||
| 	), nil | ||||
| } | ||||
| @ -39,6 +39,8 @@ type Persister interface { | ||||
| 	GetSamlStatePersisterWithConnection(tx *pop.Connection) SamlStatePersister | ||||
| 	GetSamlIdentityPersister() SamlIdentityPersister | ||||
| 	GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister | ||||
| 	GetSamlIDPInitiatedRequestPersister() SamlIDPInitiatedRequestPersister | ||||
| 	GetSamlIDPInitiatedRequestPersisterWithConnection(tx *pop.Connection) SamlIDPInitiatedRequestPersister | ||||
| 	GetTokenPersister() TokenPersister | ||||
| 	GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister | ||||
| 	GetUserPersister() UserPersister | ||||
| @ -288,6 +290,14 @@ func (p *persister) GetSamlIdentityPersisterWithConnection(tx *pop.Connection) S | ||||
| 	return NewSamlIdentityPersister(tx) | ||||
| } | ||||
|  | ||||
| func (p *persister) GetSamlIDPInitiatedRequestPersister() SamlIDPInitiatedRequestPersister { | ||||
| 	return NewSamlIDPInitiatedRequestPersister(p.DB) | ||||
| } | ||||
|  | ||||
| func (p *persister) GetSamlIDPInitiatedRequestPersisterWithConnection(tx *pop.Connection) SamlIDPInitiatedRequestPersister { | ||||
| 	return NewSamlIDPInitiatedRequestPersister(tx) | ||||
| } | ||||
|  | ||||
| func (p *persister) GetWebhookPersister(tx *pop.Connection) WebhookPersister { | ||||
| 	if tx != nil { | ||||
| 		return NewWebhookPersister(tx) | ||||
|  | ||||
							
								
								
									
										48
									
								
								backend/persistence/saml_idp_inititated_request_persister.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								backend/persistence/saml_idp_inititated_request_persister.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,48 @@ | ||||
| package persistence | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gobuffalo/pop/v6" | ||||
| 	"github.com/teamhanko/hanko/backend/persistence/models" | ||||
| ) | ||||
|  | ||||
| type SamlIDPInitiatedRequestPersister interface { | ||||
| 	Create(samlIDPInitiatedRequest models.SamlIDPInitiatedRequest) error | ||||
| 	GetByResponseIDAndIssuer(responseID, entityID string) (*models.SamlIDPInitiatedRequest, error) | ||||
| } | ||||
|  | ||||
| type samlIDPInitiatedRequestPersister struct { | ||||
| 	db *pop.Connection | ||||
| } | ||||
|  | ||||
| func (p samlIDPInitiatedRequestPersister) GetByResponseIDAndIssuer(responseID, entityID string) (*models.SamlIDPInitiatedRequest, error) { | ||||
| 	samlIDPInitiatedRequest := models.SamlIDPInitiatedRequest{} | ||||
| 	query := p.db.Where("response_id = ? AND idp_entity_id = ?", responseID, entityID) | ||||
| 	err := query.First(&samlIDPInitiatedRequest) | ||||
| 	if err != nil && errors.Is(err, sql.ErrNoRows) { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get credential: %w", err) | ||||
| 	} | ||||
| 	return &samlIDPInitiatedRequest, nil | ||||
| } | ||||
|  | ||||
| func NewSamlIDPInitiatedRequestPersister(db *pop.Connection) SamlIDPInitiatedRequestPersister { | ||||
| 	return &samlIDPInitiatedRequestPersister{db: db} | ||||
| } | ||||
|  | ||||
| func (p samlIDPInitiatedRequestPersister) Create(samlIDPInitiatedRequest models.SamlIDPInitiatedRequest) error { | ||||
| 	vErr, err := p.db.ValidateAndCreate(&samlIDPInitiatedRequest) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to store saml idp initiated request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if vErr != nil && vErr.HasAny() { | ||||
| 		return fmt.Errorf("saml idp initated request object validation failed: %w", vErr) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @ -464,6 +464,7 @@ const AppProvider = ({ | ||||
|             .run(); | ||||
|  | ||||
|           searchParams.delete("hanko_token"); | ||||
|           searchParams.delete("saml_hint"); | ||||
|  | ||||
|           history.replaceState( | ||||
|             null, | ||||
| @ -524,7 +525,17 @@ const AppProvider = ({ | ||||
|         "hanko_token", | ||||
|       ); | ||||
|       const cachedState = localStorage.getItem(localStorageCacheStateKey); | ||||
|       if (cachedState && cachedState.length > 0 && token && token.length > 0) { | ||||
|       const samlHint = new URLSearchParams(window.location.search).get( | ||||
|         "saml_hint", | ||||
|       ); | ||||
|       if (samlHint === "idp_initiated") { | ||||
|         await hanko.flow.init("/token_exchange", { ...stateHandler }); | ||||
|       } else if ( | ||||
|         cachedState && | ||||
|         cachedState.length > 0 && | ||||
|         token && | ||||
|         token.length > 0 | ||||
|       ) { | ||||
|         await hanko.flow.fromString( | ||||
|           localStorage.getItem(localStorageCacheStateKey), | ||||
|           { ...stateHandler }, | ||||
|  | ||||
| @ -121,7 +121,11 @@ export interface Payloads { | ||||
|   readonly webauthn_credential_verification: OnboardingVerifyPasskeyAttestationPayload; | ||||
| } | ||||
|  | ||||
| export type FlowPath = "/login" | "/registration" | "/profile"; | ||||
| export type FlowPath = | ||||
|   | "/login" | ||||
|   | "/registration" | ||||
|   | "/profile" | ||||
|   | "/token_exchange"; | ||||
|  | ||||
| export type FetchNextState = ( | ||||
|   // eslint-disable-next-line no-unused-vars | ||||
|  | ||||
		Reference in New Issue
	
	Block a user