diff --git a/crates/api_models/src/user.rs b/crates/api_models/src/user.rs index 7b5911cf1a..ea979bfe73 100644 --- a/crates/api_models/src/user.rs +++ b/crates/api_models/src/user.rs @@ -305,18 +305,26 @@ pub struct CreateUserAuthenticationMethodRequest { pub owner_type: common_enums::Owner, pub auth_method: AuthConfig, pub allow_signup: bool, + pub email_domain: Option, } #[derive(Debug, serde::Deserialize, serde::Serialize)] -pub struct UpdateUserAuthenticationMethodRequest { - pub id: String, - // TODO: When adding more fields make config and new fields option - pub auth_method: AuthConfig, +#[serde(rename_all = "snake_case")] +pub enum UpdateUserAuthenticationMethodRequest { + AuthMethod { + id: String, + auth_config: AuthConfig, + }, + EmailDomain { + owner_id: String, + email_domain: String, + }, } #[derive(Debug, serde::Deserialize, serde::Serialize)] pub struct GetUserAuthenticationMethodsRequest { - pub auth_id: String, + pub auth_id: Option, + pub email_domain: Option, } #[derive(Debug, serde::Deserialize, serde::Serialize)] diff --git a/crates/diesel_models/src/query/user_authentication_method.rs b/crates/diesel_models/src/query/user_authentication_method.rs index 14a28269ce..b9ed95e37c 100644 --- a/crates/diesel_models/src/query/user_authentication_method.rs +++ b/crates/diesel_models/src/query/user_authentication_method.rs @@ -64,4 +64,18 @@ impl UserAuthenticationMethod { ) .await } + + pub async fn list_user_authentication_methods_for_email_domain( + conn: &PgPooledConn, + email_domain: &str, + ) -> StorageResult> { + generics::generic_filter::<::Table, _, _, _>( + conn, + dsl::email_domain.eq(email_domain.to_owned()), + None, + None, + Some(dsl::last_modified_at.asc()), + ) + .await + } } diff --git a/crates/diesel_models/src/schema.rs b/crates/diesel_models/src/schema.rs index 95bb714cb7..37a39cb731 100644 --- a/crates/diesel_models/src/schema.rs +++ b/crates/diesel_models/src/schema.rs @@ -1405,6 +1405,8 @@ diesel::table! { allow_signup -> Bool, created_at -> Timestamp, last_modified_at -> Timestamp, + #[max_length = 64] + email_domain -> Varchar, } } diff --git a/crates/diesel_models/src/schema_v2.rs b/crates/diesel_models/src/schema_v2.rs index 8bbb4baf9d..3e27f29427 100644 --- a/crates/diesel_models/src/schema_v2.rs +++ b/crates/diesel_models/src/schema_v2.rs @@ -1352,6 +1352,8 @@ diesel::table! { allow_signup -> Bool, created_at -> Timestamp, last_modified_at -> Timestamp, + #[max_length = 64] + email_domain -> Varchar, } } diff --git a/crates/diesel_models/src/user_authentication_method.rs b/crates/diesel_models/src/user_authentication_method.rs index 76e1abe757..7f3799aa7b 100644 --- a/crates/diesel_models/src/user_authentication_method.rs +++ b/crates/diesel_models/src/user_authentication_method.rs @@ -17,6 +17,7 @@ pub struct UserAuthenticationMethod { pub allow_signup: bool, pub created_at: PrimitiveDateTime, pub last_modified_at: PrimitiveDateTime, + pub email_domain: String, } #[derive(router_derive::Setter, Clone, Debug, Insertable, router_derive::DebugAsDisplay)] @@ -32,6 +33,7 @@ pub struct UserAuthenticationMethodNew { pub allow_signup: bool, pub created_at: PrimitiveDateTime, pub last_modified_at: PrimitiveDateTime, + pub email_domain: String, } #[derive(Clone, Debug, AsChangeset, router_derive::DebugAsDisplay)] @@ -40,6 +42,7 @@ pub struct OrgAuthenticationMethodUpdateInternal { pub private_config: Option, pub public_config: Option, pub last_modified_at: PrimitiveDateTime, + pub email_domain: Option, } pub enum UserAuthenticationMethodUpdate { @@ -47,6 +50,9 @@ pub enum UserAuthenticationMethodUpdate { private_config: Option, public_config: Option, }, + EmailDomain { + email_domain: String, + }, } impl From for OrgAuthenticationMethodUpdateInternal { @@ -60,6 +66,13 @@ impl From for OrgAuthenticationMethodUpdateInter private_config, public_config, last_modified_at, + email_domain: None, + }, + UserAuthenticationMethodUpdate::EmailDomain { email_domain } => Self { + private_config: None, + public_config: None, + last_modified_at, + email_domain: Some(email_domain), }, } } diff --git a/crates/router/src/core/errors/user.rs b/crates/router/src/core/errors/user.rs index fa5f185ab8..6af269d916 100644 --- a/crates/router/src/core/errors/user.rs +++ b/crates/router/src/core/errors/user.rs @@ -108,6 +108,8 @@ pub enum UserErrors { InvalidThemeLineage(String), #[error("Missing required field: email_config")] MissingEmailConfig, + #[error("Invalid Auth Method Operation: {0}")] + InvalidAuthMethodOperationWithMessage(String), } impl common_utils::errors::ErrorSwitch for UserErrors { @@ -280,6 +282,9 @@ impl common_utils::errors::ErrorSwitch { AER::BadRequest(ApiError::new(sub_code, 56, self.get_error_message(), None)) } + Self::InvalidAuthMethodOperationWithMessage(_) => { + AER::BadRequest(ApiError::new(sub_code, 57, self.get_error_message(), None)) + } } } } @@ -347,6 +352,9 @@ impl UserErrors { format!("Invalid field: {} in lineage", field_name) } Self::MissingEmailConfig => "Missing required field: email_config".to_string(), + Self::InvalidAuthMethodOperationWithMessage(operation) => { + format!("Invalid Auth Method Operation: {}", operation) + } } } } diff --git a/crates/router/src/core/user.rs b/crates/router/src/core/user.rs index 19f257ef4f..a7c60f33ff 100644 --- a/crates/router/src/core/user.rs +++ b/crates/router/src/core/user.rs @@ -7,7 +7,7 @@ use api_models::{ payments::RedirectionResponse, user::{self as user_api, InviteMultipleUserResponse, NameIdUnit}, }; -use common_enums::EntityType; +use common_enums::{EntityType, UserAuthType}; use common_utils::{type_name, types::keymanager::Identifier}; #[cfg(feature = "email")] use diesel_models::user_role::UserRoleUpdate; @@ -22,6 +22,7 @@ use masking::{ExposeInterface, PeekInterface, Secret}; #[cfg(feature = "email")] use router_env::env; use router_env::logger; +use storage_impl::errors::StorageError; #[cfg(not(feature = "email"))] use user_api::dashboard_metadata::SetMetaDataRequest; @@ -152,6 +153,14 @@ pub async fn signup_token_only_flow( state: SessionState, request: user_api::SignUpRequest, ) -> UserResponse { + let user_email = domain::UserEmail::from_pii_email(request.email.clone())?; + utils::user::validate_email_domain_auth_type_using_db( + &state, + &user_email, + UserAuthType::Password, + ) + .await?; + let new_user = domain::NewUser::try_from(request)?; new_user .get_new_merchant() @@ -187,9 +196,18 @@ pub async fn signin_token_only_flow( state: SessionState, request: user_api::SignInRequest, ) -> UserResponse { + let user_email = domain::UserEmail::from_pii_email(request.email)?; + + utils::user::validate_email_domain_auth_type_using_db( + &state, + &user_email, + UserAuthType::Password, + ) + .await?; + let user_from_db: domain::UserFromStorage = state .global_store - .find_user_by_email(&domain::UserEmail::from_pii_email(request.email)?) + .find_user_by_email(&user_email) .await .to_not_found_response(UserErrors::InvalidCredentials)? .into(); @@ -215,10 +233,16 @@ pub async fn connect_account( auth_id: Option, theme_id: Option, ) -> UserResponse { - let find_user = state - .global_store - .find_user_by_email(&domain::UserEmail::from_pii_email(request.email.clone())?) - .await; + let user_email = domain::UserEmail::from_pii_email(request.email.clone())?; + + utils::user::validate_email_domain_auth_type_using_db( + &state, + &user_email, + UserAuthType::MagicLink, + ) + .await?; + + let find_user = state.global_store.find_user_by_email(&user_email).await; if let Ok(found_user) = find_user { let user_from_db: domain::UserFromStorage = found_user.into(); @@ -412,6 +436,13 @@ pub async fn forgot_password( ) -> UserResponse<()> { let user_email = domain::UserEmail::from_pii_email(request.email)?; + utils::user::validate_email_domain_auth_type_using_db( + &state, + &user_email, + UserAuthType::Password, + ) + .await?; + let user_from_db = state .global_store .find_user_by_email(&user_email) @@ -1757,7 +1788,15 @@ pub async fn send_verification_mail( auth_id: Option, theme_id: Option, ) -> UserResponse<()> { - let user_email = domain::UserEmail::try_from(req.email)?; + let user_email = domain::UserEmail::from_pii_email(req.email)?; + + utils::user::validate_email_domain_auth_type_using_db( + &state, + &user_email, + UserAuthType::MagicLink, + ) + .await?; + let user = state .global_store .find_user_by_email(&user_email) @@ -2317,10 +2356,30 @@ pub async fn create_user_authentication_method( .change_context(UserErrors::InternalServerError) .attach_printable("Failed to get list of auth methods for the owner id")?; - let auth_id = auth_methods - .first() - .map(|auth_method| auth_method.auth_id.clone()) - .unwrap_or(uuid::Uuid::new_v4().to_string()); + let (auth_id, email_domain) = if let Some(auth_method) = auth_methods.first() { + let email_domain = match req.email_domain { + Some(email_domain) => { + if email_domain != auth_method.email_domain { + return Err(report!(UserErrors::InvalidAuthMethodOperationWithMessage( + "Email domain mismatch".to_string() + ))); + } + + email_domain + } + None => auth_method.email_domain.clone(), + }; + + (auth_method.auth_id.clone(), email_domain) + } else { + let email_domain = + req.email_domain + .ok_or(UserErrors::InvalidAuthMethodOperationWithMessage( + "Email domain not found".to_string(), + ))?; + + (uuid::Uuid::new_v4().to_string(), email_domain) + }; for db_auth_method in auth_methods { let is_type_same = db_auth_method.auth_type == (&req.auth_method).foreign_into(); @@ -2360,6 +2419,7 @@ pub async fn create_user_authentication_method( allow_signup: req.allow_signup, created_at: now, last_modified_at: now, + email_domain, }) .await .to_duplicate_response(UserErrors::UserAuthMethodAlreadyExists)?; @@ -2383,25 +2443,71 @@ pub async fn update_user_authentication_method( .change_context(UserErrors::InternalServerError) .attach_printable("Failed to decode DEK")?; - let (private_config, public_config) = utils::user::construct_public_and_private_db_configs( - &state, - &req.auth_method, - &user_auth_encryption_key, - req.id.clone(), - ) - .await?; + match req { + user_api::UpdateUserAuthenticationMethodRequest::AuthMethod { + id, + auth_config: auth_method, + } => { + let (private_config, public_config) = + utils::user::construct_public_and_private_db_configs( + &state, + &auth_method, + &user_auth_encryption_key, + id.clone(), + ) + .await?; + + state + .store + .update_user_authentication_method( + &id, + UserAuthenticationMethodUpdate::UpdateConfig { + private_config, + public_config, + }, + ) + .await + .map_err(|error| { + let user_error = match error.current_context() { + StorageError::ValueNotFound(_) => { + UserErrors::InvalidAuthMethodOperationWithMessage( + "Auth method not found".to_string(), + ) + } + StorageError::DuplicateValue { .. } => { + UserErrors::UserAuthMethodAlreadyExists + } + _ => UserErrors::InternalServerError, + }; + error.change_context(user_error) + })?; + } + user_api::UpdateUserAuthenticationMethodRequest::EmailDomain { + owner_id, + email_domain, + } => { + let auth_methods = state + .store + .list_user_authentication_methods_for_owner_id(&owner_id) + .await + .change_context(UserErrors::InternalServerError)?; + + futures::future::try_join_all(auth_methods.iter().map(|auth_method| async { + state + .store + .update_user_authentication_method( + &auth_method.id, + UserAuthenticationMethodUpdate::EmailDomain { + email_domain: email_domain.clone(), + }, + ) + .await + .to_duplicate_response(UserErrors::UserAuthMethodAlreadyExists) + })) + .await?; + } + } - state - .store - .update_user_authentication_method( - &req.id, - UserAuthenticationMethodUpdate::UpdateConfig { - private_config, - public_config, - }, - ) - .await - .change_context(UserErrors::InvalidUserAuthMethodOperation)?; Ok(ApplicationResponse::StatusOk) } @@ -2409,18 +2515,28 @@ pub async fn list_user_authentication_methods( state: SessionState, req: user_api::GetUserAuthenticationMethodsRequest, ) -> UserResponse> { - let user_authentication_methods = state - .store - .list_user_authentication_methods_for_auth_id(&req.auth_id) - .await - .change_context(UserErrors::InternalServerError)?; + let user_authentication_methods = match (req.auth_id, req.email_domain) { + (Some(auth_id), None) => state + .store + .list_user_authentication_methods_for_auth_id(&auth_id) + .await + .change_context(UserErrors::InternalServerError)?, + (None, Some(email_domain)) => state + .store + .list_user_authentication_methods_for_email_domain(&email_domain) + .await + .change_context(UserErrors::InternalServerError)?, + (Some(_), Some(_)) | (None, None) => { + return Err(UserErrors::InvalidUserAuthMethodOperation.into()); + } + }; Ok(ApplicationResponse::Json( user_authentication_methods .into_iter() .map(|auth_method| { let auth_name = match (auth_method.auth_type, auth_method.public_config) { - (common_enums::UserAuthType::OpenIdConnect, config) => { + (UserAuthType::OpenIdConnect, config) => { let open_id_public_config: Option = config .map(|config| { @@ -2546,6 +2662,13 @@ pub async fn sso_sign( ) .await?; + utils::user::validate_email_domain_auth_type_using_db( + &state, + &email, + UserAuthType::OpenIdConnect, + ) + .await?; + // TODO: Use config to handle not found error let user_from_db: domain::UserFromStorage = state .global_store @@ -2594,14 +2717,20 @@ pub async fn terminate_auth_select( .change_context(UserErrors::InternalServerError)? .into(); - let user_authentication_method = if let Some(id) = &req.id { - state - .store - .get_user_authentication_method_by_id(id) - .await - .to_not_found_response(UserErrors::InvalidUserAuthMethodOperation)? - } else { - DEFAULT_USER_AUTH_METHOD.clone() + let user_email = domain::UserEmail::from_pii_email(user_from_db.get_email())?; + let auth_methods = state + .store + .list_user_authentication_methods_for_email_domain(user_email.extract_domain()?) + .await + .change_context(UserErrors::InternalServerError)?; + + let user_authentication_method = match (req.id, auth_methods.is_empty()) { + (Some(id), _) => auth_methods + .into_iter() + .find(|auth_method| auth_method.id == id) + .ok_or(UserErrors::InvalidUserAuthMethodOperation)?, + (None, true) => DEFAULT_USER_AUTH_METHOD.clone(), + (None, false) => return Err(UserErrors::InvalidUserAuthMethodOperation.into()), }; let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::AuthSelect.into())?; diff --git a/crates/router/src/db/kafka_store.rs b/crates/router/src/db/kafka_store.rs index 525c5f12dc..8eec7f0416 100644 --- a/crates/router/src/db/kafka_store.rs +++ b/crates/router/src/db/kafka_store.rs @@ -3816,6 +3816,18 @@ impl UserAuthenticationMethodInterface for KafkaStore { .update_user_authentication_method(id, user_authentication_method_update) .await } + + async fn list_user_authentication_methods_for_email_domain( + &self, + email_domain: &str, + ) -> CustomResult< + Vec, + errors::StorageError, + > { + self.diesel_store + .list_user_authentication_methods_for_email_domain(email_domain) + .await + } } #[async_trait::async_trait] diff --git a/crates/router/src/db/user_authentication_method.rs b/crates/router/src/db/user_authentication_method.rs index a02e7bdb11..cd918fe206 100644 --- a/crates/router/src/db/user_authentication_method.rs +++ b/crates/router/src/db/user_authentication_method.rs @@ -36,6 +36,11 @@ pub trait UserAuthenticationMethodInterface { id: &str, user_authentication_method_update: storage::UserAuthenticationMethodUpdate, ) -> CustomResult; + + async fn list_user_authentication_methods_for_email_domain( + &self, + email_domain: &str, + ) -> CustomResult, errors::StorageError>; } #[async_trait::async_trait] @@ -57,7 +62,7 @@ impl UserAuthenticationMethodInterface for Store { &self, id: &str, ) -> CustomResult { - let conn = connection::pg_connection_write(self).await?; + let conn = connection::pg_connection_read(self).await?; storage::UserAuthenticationMethod::get_user_authentication_method_by_id(&conn, id) .await .map_err(|error| report!(errors::StorageError::from(error))) @@ -68,7 +73,7 @@ impl UserAuthenticationMethodInterface for Store { &self, auth_id: &str, ) -> CustomResult, errors::StorageError> { - let conn = connection::pg_connection_write(self).await?; + let conn = connection::pg_connection_read(self).await?; storage::UserAuthenticationMethod::list_user_authentication_methods_for_auth_id( &conn, auth_id, ) @@ -81,7 +86,7 @@ impl UserAuthenticationMethodInterface for Store { &self, owner_id: &str, ) -> CustomResult, errors::StorageError> { - let conn = connection::pg_connection_write(self).await?; + let conn = connection::pg_connection_read(self).await?; storage::UserAuthenticationMethod::list_user_authentication_methods_for_owner_id( &conn, owner_id, ) @@ -104,6 +109,20 @@ impl UserAuthenticationMethodInterface for Store { .await .map_err(|error| report!(errors::StorageError::from(error))) } + + #[instrument(skip_all)] + async fn list_user_authentication_methods_for_email_domain( + &self, + email_domain: &str, + ) -> CustomResult, errors::StorageError> { + let conn = connection::pg_connection_read(self).await?; + storage::UserAuthenticationMethod::list_user_authentication_methods_for_email_domain( + &conn, + email_domain, + ) + .await + .map_err(|error| report!(errors::StorageError::from(error))) + } } #[async_trait::async_trait] @@ -130,6 +149,7 @@ impl UserAuthenticationMethodInterface for MockDb { allow_signup: user_authentication_method.allow_signup, created_at: user_authentication_method.created_at, last_modified_at: user_authentication_method.last_modified_at, + email_domain: user_authentication_method.email_domain, }; user_authentication_methods.push(user_authentication_method.clone()); @@ -222,6 +242,13 @@ impl UserAuthenticationMethodInterface for MockDb { last_modified_at: common_utils::date_time::now(), ..auth_method_inner.to_owned() }, + storage::UserAuthenticationMethodUpdate::EmailDomain { email_domain } => { + storage::UserAuthenticationMethod { + email_domain: email_domain.to_owned(), + last_modified_at: common_utils::date_time::now(), + ..auth_method_inner.to_owned() + } + } }; auth_method_inner.to_owned() }) @@ -232,4 +259,20 @@ impl UserAuthenticationMethodInterface for MockDb { .into(), ) } + + #[instrument(skip_all)] + async fn list_user_authentication_methods_for_email_domain( + &self, + email_domain: &str, + ) -> CustomResult, errors::StorageError> { + let user_authentication_methods = self.user_authentication_methods.lock().await; + + let user_authentication_methods_list: Vec<_> = user_authentication_methods + .iter() + .filter(|auth_method_inner| auth_method_inner.email_domain == email_domain) + .cloned() + .collect(); + + Ok(user_authentication_methods_list) + } } diff --git a/crates/router/src/types/domain/user.rs b/crates/router/src/types/domain/user.rs index 6efaac7bfe..569e7f9a99 100644 --- a/crates/router/src/types/domain/user.rs +++ b/crates/router/src/types/domain/user.rs @@ -138,6 +138,15 @@ impl UserEmail { pub fn get_secret(self) -> Secret { (*self.0).clone() } + + pub fn extract_domain(&self) -> UserResult<&str> { + let (_username, domain) = self + .peek() + .split_once('@') + .ok_or(UserErrors::InternalServerError)?; + + Ok(domain) + } } impl TryFrom for UserEmail { diff --git a/crates/router/src/types/domain/user/user_authentication_method.rs b/crates/router/src/types/domain/user/user_authentication_method.rs index 570e144961..29c588f15e 100644 --- a/crates/router/src/types/domain/user/user_authentication_method.rs +++ b/crates/router/src/types/domain/user/user_authentication_method.rs @@ -14,4 +14,5 @@ pub static DEFAULT_USER_AUTH_METHOD: Lazy = allow_signup: true, created_at: common_utils::date_time::now(), last_modified_at: common_utils::date_time::now(), + email_domain: String::from("hyperswitch"), }); diff --git a/crates/router/src/utils/user.rs b/crates/router/src/utils/user.rs index abd5968424..b8ffcf836d 100644 --- a/crates/router/src/utils/user.rs +++ b/crates/router/src/utils/user.rs @@ -5,7 +5,7 @@ use common_enums::UserAuthType; use common_utils::{ encryption::Encryption, errors::CustomResult, id_type, type_name, types::keymanager::Identifier, }; -use diesel_models::{organization, organization::OrganizationBridge}; +use diesel_models::organization::{self, OrganizationBridge}; use error_stack::ResultExt; use masking::{ExposeInterface, Secret}; use redis_interface::RedisConnectionPool; @@ -312,3 +312,23 @@ pub fn create_merchant_account_request_for_org( pm_collect_link_config: None, }) } + +pub async fn validate_email_domain_auth_type_using_db( + state: &SessionState, + email: &domain::UserEmail, + required_auth_type: UserAuthType, +) -> UserResult<()> { + let domain = email.extract_domain()?; + let user_auth_methods = state + .store + .list_user_authentication_methods_for_email_domain(domain) + .await + .change_context(UserErrors::InternalServerError)?; + + (user_auth_methods.is_empty() + || user_auth_methods + .iter() + .any(|auth_method| auth_method.auth_type == required_auth_type)) + .then_some(()) + .ok_or(UserErrors::InvalidUserAuthMethodOperation.into()) +} diff --git a/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/down.sql b/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/down.sql new file mode 100644 index 0000000000..9f3560069c --- /dev/null +++ b/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP INDEX email_domain_index; +ALTER TABLE user_authentication_methods DROP COLUMN email_domain; diff --git a/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/up.sql b/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/up.sql new file mode 100644 index 0000000000..831f10162c --- /dev/null +++ b/migrations/2024-12-11-092624_add-email-domain-in-auth-methods/up.sql @@ -0,0 +1,6 @@ +-- Your SQL goes here +ALTER TABLE user_authentication_methods ADD COLUMN email_domain VARCHAR(64); +UPDATE user_authentication_methods SET email_domain = auth_id WHERE email_domain IS NULL; +ALTER TABLE user_authentication_methods ALTER COLUMN email_domain SET NOT NULL; + +CREATE INDEX email_domain_index ON user_authentication_methods (email_domain);