diff --git a/ee/authn/callbackauthn/oidccallbackauthn/authn.go b/ee/authn/callbackauthn/oidccallbackauthn/authn.go index b1a048fbb5..de904928b1 100644 --- a/ee/authn/callbackauthn/oidccallbackauthn/authn.go +++ b/ee/authn/callbackauthn/oidccallbackauthn/authn.go @@ -129,6 +129,12 @@ func (a *AuthN) HandleCallback(ctx context.Context, query url.Values) (*authtype return authtypes.NewCallbackIdentity("", email, authDomain.StorableAuthDomain().OrgID, state), nil } +func (a *AuthN) ProviderInfo(ctx context.Context, authDomain *authtypes.AuthDomain) *authtypes.AuthNProviderInfo { + return &authtypes.AuthNProviderInfo{ + RelayStatePath: nil, + } +} + func (a *AuthN) oidcProviderAndoauth2Config(ctx context.Context, siteURL *url.URL, authDomain *authtypes.AuthDomain) (*oidc.Provider, *oauth2.Config, error) { if authDomain.AuthDomainConfig().OIDC.IssuerAlias != "" { ctx = oidc.InsecureIssuerURLContext(ctx, authDomain.AuthDomainConfig().OIDC.IssuerAlias) diff --git a/ee/authn/callbackauthn/samlcallbackauthn/authn.go b/ee/authn/callbackauthn/samlcallbackauthn/authn.go index 1fc99d0744..89714f305c 100644 --- a/ee/authn/callbackauthn/samlcallbackauthn/authn.go +++ b/ee/authn/callbackauthn/samlcallbackauthn/authn.go @@ -99,6 +99,14 @@ func (a *AuthN) HandleCallback(ctx context.Context, formValues url.Values) (*aut return authtypes.NewCallbackIdentity("", email, authDomain.StorableAuthDomain().OrgID, state), nil } +func (a *AuthN) ProviderInfo(ctx context.Context, authDomain *authtypes.AuthDomain) *authtypes.AuthNProviderInfo { + state := authtypes.NewState(&url.URL{Path: "login"}, authDomain.StorableAuthDomain().ID).URL.String() + + return &authtypes.AuthNProviderInfo{ + RelayStatePath: &state, + } +} + func (a *AuthN) serviceProvider(siteURL *url.URL, authDomain *authtypes.AuthDomain) (*saml2.SAMLServiceProvider, error) { certStore, err := a.getCertificateStore(authDomain) if err != nil { diff --git a/pkg/authn/authn.go b/pkg/authn/authn.go index b9aca5989e..645f13b7f1 100644 --- a/pkg/authn/authn.go +++ b/pkg/authn/authn.go @@ -22,4 +22,7 @@ type CallbackAuthN interface { // Handle the callback from the provider. HandleCallback(context.Context, url.Values) (*authtypes.CallbackIdentity, error) + + // Get provider info such as `relay state` + ProviderInfo(context.Context, *authtypes.AuthDomain) *authtypes.AuthNProviderInfo } diff --git a/pkg/authn/callbackauthn/googlecallbackauthn/authn.go b/pkg/authn/callbackauthn/googlecallbackauthn/authn.go index 2c68b48c6e..ca48db7975 100644 --- a/pkg/authn/callbackauthn/googlecallbackauthn/authn.go +++ b/pkg/authn/callbackauthn/googlecallbackauthn/authn.go @@ -117,6 +117,12 @@ func (a *AuthN) HandleCallback(ctx context.Context, query url.Values) (*authtype } +func (a *AuthN) ProviderInfo(ctx context.Context, authDomain *authtypes.AuthDomain) *authtypes.AuthNProviderInfo { + return &authtypes.AuthNProviderInfo{ + RelayStatePath: nil, + } +} + func (a *AuthN) oauth2Config(siteURL *url.URL, authDomain *authtypes.AuthDomain, provider *oidc.Provider) *oauth2.Config { return &oauth2.Config{ ClientID: authDomain.AuthDomainConfig().Google.ClientID, diff --git a/pkg/modules/authdomain/authdomain.go b/pkg/modules/authdomain/authdomain.go index ebebfbe9ca..010175da67 100644 --- a/pkg/modules/authdomain/authdomain.go +++ b/pkg/modules/authdomain/authdomain.go @@ -29,6 +29,9 @@ type Module interface { // Delete an existing auth domain by id. Delete(context.Context, valuer.UUID, valuer.UUID) error + + // Get the IDP info of the domain provided. + GetAuthNProviderInfo(context.Context, *authtypes.AuthDomain) (*authtypes.AuthNProviderInfo) } type Handler interface { diff --git a/pkg/modules/authdomain/implauthdomain/handler.go b/pkg/modules/authdomain/implauthdomain/handler.go index cd271dff78..4b0ceab48d 100644 --- a/pkg/modules/authdomain/implauthdomain/handler.go +++ b/pkg/modules/authdomain/implauthdomain/handler.go @@ -95,7 +95,7 @@ func (handler *handler) List(rw http.ResponseWriter, r *http.Request) { authDomains := make([]*authtypes.GettableAuthDomain, len(domains)) for i, domain := range domains { - authDomains[i] = authtypes.NewGettableAuthDomainFromAuthDomain(domain) + authDomains[i] = authtypes.NewGettableAuthDomainFromAuthDomain(domain, handler.module.GetAuthNProviderInfo(ctx, domain)) } render.Success(rw, http.StatusOK, authDomains) diff --git a/pkg/modules/authdomain/implauthdomain/module.go b/pkg/modules/authdomain/implauthdomain/module.go index 2532ac0f5b..08d2486670 100644 --- a/pkg/modules/authdomain/implauthdomain/module.go +++ b/pkg/modules/authdomain/implauthdomain/module.go @@ -3,17 +3,19 @@ package implauthdomain import ( "context" + "github.com/SigNoz/signoz/pkg/authn" "github.com/SigNoz/signoz/pkg/modules/authdomain" "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/valuer" ) type module struct { - store authtypes.AuthDomainStore + store authtypes.AuthDomainStore + authNs map[authtypes.AuthNProvider]authn.AuthN } -func NewModule(store authtypes.AuthDomainStore) authdomain.Module { - return &module{store: store} +func NewModule(store authtypes.AuthDomainStore, authNs map[authtypes.AuthNProvider]authn.AuthN) authdomain.Module { + return &module{store: store, authNs: authNs} } func (module *module) Create(ctx context.Context, domain *authtypes.AuthDomain) error { @@ -24,6 +26,13 @@ func (module *module) Get(ctx context.Context, id valuer.UUID) (*authtypes.AuthD return module.store.Get(ctx, id) } +func (module *module) GetAuthNProviderInfo(ctx context.Context, domain *authtypes.AuthDomain) *authtypes.AuthNProviderInfo { + if callbackAuthN, ok := module.authNs[domain.AuthDomainConfig().AuthNProvider].(authn.CallbackAuthN); ok { + return callbackAuthN.ProviderInfo(ctx, domain) + } + return &authtypes.AuthNProviderInfo{} +} + func (module *module) GetByOrgIDAndID(ctx context.Context, orgID valuer.UUID, id valuer.UUID) (*authtypes.AuthDomain, error) { return module.store.GetByOrgIDAndID(ctx, orgID, id) } diff --git a/pkg/signoz/module.go b/pkg/signoz/module.go index 77d518d413..2f4f8ce6d3 100644 --- a/pkg/signoz/module.go +++ b/pkg/signoz/module.go @@ -90,8 +90,8 @@ func NewModules( QuickFilter: quickfilter, TraceFunnel: impltracefunnel.NewModule(impltracefunnel.NewStore(sqlstore)), RawDataExport: implrawdataexport.NewModule(querier), - AuthDomain: implauthdomain.NewModule(implauthdomain.NewStore(sqlstore)), - Session: implsession.NewModule(providerSettings, authNs, user, userGetter, implauthdomain.NewModule(implauthdomain.NewStore(sqlstore)), tokenizer, orgGetter), + AuthDomain: implauthdomain.NewModule(implauthdomain.NewStore(sqlstore), authNs), + Session: implsession.NewModule(providerSettings, authNs, user, userGetter, implauthdomain.NewModule(implauthdomain.NewStore(sqlstore), authNs), tokenizer, orgGetter), SpanPercentile: implspanpercentile.NewModule(querier, providerSettings), Services: implservices.NewModule(querier, telemetryStore), } diff --git a/pkg/types/authtypes/domain.go b/pkg/types/authtypes/domain.go index 2caad1e197..a33ad4c6fb 100644 --- a/pkg/types/authtypes/domain.go +++ b/pkg/types/authtypes/domain.go @@ -31,6 +31,11 @@ var ( type GettableAuthDomain struct { *StorableAuthDomain *AuthDomainConfig + AuthNProviderInfo *AuthNProviderInfo `json:"authNProviderInfo"` +} + +type AuthNProviderInfo struct { + RelayStatePath *string `json:"relayStatePath"` } type PostableAuthDomain struct { @@ -103,10 +108,11 @@ func NewAuthDomainFromStorableAuthDomain(storableAuthDomain *StorableAuthDomain) }, nil } -func NewGettableAuthDomainFromAuthDomain(authDomain *AuthDomain) *GettableAuthDomain { +func NewGettableAuthDomainFromAuthDomain(authDomain *AuthDomain, authNProviderInfo *AuthNProviderInfo) *GettableAuthDomain { return &GettableAuthDomain{ StorableAuthDomain: authDomain.StorableAuthDomain(), AuthDomainConfig: authDomain.AuthDomainConfig(), + AuthNProviderInfo: authNProviderInfo, } } diff --git a/tests/integration/fixtures/idputils.py b/tests/integration/fixtures/idputils.py index 45ad84428c..3849061dcd 100644 --- a/tests/integration/fixtures/idputils.py +++ b/tests/integration/fixtures/idputils.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Dict, Any from urllib.parse import urljoin from xml.etree import ElementTree @@ -121,6 +121,31 @@ def create_saml_client( return _create_saml_client +@pytest.fixture(name="update_saml_client_attributes", scope="function") +def update_saml_client_attributes( + idp: types.TestContainerIDP +) -> Callable[[str, Dict[str, Any]], None]: + def _update_saml_client_attributes(client_id: str, attributes: Dict[str, Any]) -> None: + client = KeycloakAdmin( + server_url=idp.container.host_configs["6060"].base(), + username=IDP_ROOT_USERNAME, + password=IDP_ROOT_PASSWORD, + realm_name="master", + ) + + kc_client_id = client.get_client_id(client_id=client_id) + print("kc_client_id: " + kc_client_id) + + payload = client.get_client(client_id=kc_client_id) + + for attr_key, attr_value in attributes.items(): + payload["attributes"][attr_key] = attr_value + + client.update_client(client_id=kc_client_id, payload=payload) + + return _update_saml_client_attributes + + @pytest.fixture(name="create_oidc_client", scope="function") def create_oidc_client( idp: types.TestContainerIDP, signoz: types.SigNoz diff --git a/tests/integration/src/callbackauthn/b_saml.py b/tests/integration/src/callbackauthn/b_saml.py index b3c056a796..cd22710cf7 100644 --- a/tests/integration/src/callbackauthn/b_saml.py +++ b/tests/integration/src/callbackauthn/b_saml.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Callable, List +from typing import Callable, List, Dict, Any import requests from selenium import webdriver @@ -26,6 +26,7 @@ def test_create_auth_domain( signoz: SigNoz, idp: TestContainerIDP, # pylint: disable=unused-argument create_saml_client: Callable[[str, str], None], + update_saml_client_attributes: Callable[[str, Dict[str, Any]], None], get_saml_settings: Callable[[], dict], create_user_admin: Callable[[], None], # pylint: disable=unused-argument get_token: Callable[[str, str], str], @@ -59,6 +60,43 @@ def test_create_auth_domain( assert response.status_code == HTTPStatus.CREATED + # Get the domains from signoz + response = requests.get( + signoz.self.host_configs["8080"].get("/api/v1/domains"), + headers={"Authorization": f"Bearer {admin_token}"}, + timeout=2, + ) + + assert response.status_code == HTTPStatus.OK + + found_domain = None + + if len(response.json()["data"]) > 0: + found_domain = next( + ( + domain + for domain in response.json()["data"] + if domain["name"] == "saml.integration.test" + ), + None, + ) + + relay_state_path = found_domain["authNProviderInfo"]["relayStatePath"] + + assert relay_state_path is not None + + # Get the relay state url from domains API + relay_state_url = signoz.self.host_configs["8080"].base() + "/" + relay_state_path + + # Update the saml client with new attributes + update_saml_client_attributes( + f"{signoz.self.host_configs['8080'].address}:{signoz.self.host_configs['8080'].port}", + { + "saml_idp_initiated_sso_url_name": "idp-initiated-saml-test", + "saml_idp_initiated_sso_relay_state": relay_state_url, + "saml_assertion_consumer_url_post": signoz.self.host_configs["8080"].get("/api/v1/complete/saml") + } + ) def test_saml_authn( signoz: SigNoz, @@ -106,3 +144,51 @@ def test_saml_authn( assert found_user is not None assert found_user["role"] == "VIEWER" + + +def test_idp_initiated_saml_authn( + signoz: SigNoz, + idp: TestContainerIDP, # pylint: disable=unused-argument + driver: webdriver.Chrome, + create_user_idp: Callable[[str, str], None], + idp_login: Callable[[str, str], None], + get_token: Callable[[str, str], str], + get_session_context: Callable[[str], str], +) -> None: + # Create a user in the idp. + create_user_idp("viewer.idp.initiated@saml.integration.test", "password", True) + + # Get the session context from signoz which will give the SAML login URL. + session_context = get_session_context("viewer.idp.initiated@saml.integration.test") + + assert len(session_context["orgs"]) == 1 + assert len(session_context["orgs"][0]["authNSupport"]["callback"]) == 1 + + idp_initiated_login_url = idp.container.host_configs["6060"].base() + "/realms/master/protocol/saml/clients/idp-initiated-saml-test" + + driver.get(idp_initiated_login_url) + idp_login("viewer.idp.initiated@saml.integration.test", "password") + + admin_token = get_token(USER_ADMIN_EMAIL, USER_ADMIN_PASSWORD) + + # Assert that the user was created in signoz. + response = requests.get( + signoz.self.host_configs["8080"].get("/api/v1/user"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + + user_response = response.json()["data"] + found_user = next( + ( + user + for user in user_response + if user["email"] == "viewer.idp.initiated@saml.integration.test" + ), + None, + ) + + assert found_user is not None + assert found_user["role"] == "VIEWER"