feat(users): Add email domain based restriction for dashboard entry APIs (#6940)

This commit is contained in:
Mani Chandra
2024-12-30 12:39:16 +05:30
committed by GitHub
parent 3eb2eb1cf5
commit 227c274ece
14 changed files with 322 additions and 52 deletions

View File

@ -305,18 +305,26 @@ pub struct CreateUserAuthenticationMethodRequest {
pub owner_type: common_enums::Owner, pub owner_type: common_enums::Owner,
pub auth_method: AuthConfig, pub auth_method: AuthConfig,
pub allow_signup: bool, pub allow_signup: bool,
pub email_domain: Option<String>,
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct UpdateUserAuthenticationMethodRequest { #[serde(rename_all = "snake_case")]
pub id: String, pub enum UpdateUserAuthenticationMethodRequest {
// TODO: When adding more fields make config and new fields option AuthMethod {
pub auth_method: AuthConfig, id: String,
auth_config: AuthConfig,
},
EmailDomain {
owner_id: String,
email_domain: String,
},
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct GetUserAuthenticationMethodsRequest { pub struct GetUserAuthenticationMethodsRequest {
pub auth_id: String, pub auth_id: Option<String>,
pub email_domain: Option<String>,
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, serde::Deserialize, serde::Serialize)]

View File

@ -64,4 +64,18 @@ impl UserAuthenticationMethod {
) )
.await .await
} }
pub async fn list_user_authentication_methods_for_email_domain(
conn: &PgPooledConn,
email_domain: &str,
) -> StorageResult<Vec<Self>> {
generics::generic_filter::<<Self as HasTable>::Table, _, _, _>(
conn,
dsl::email_domain.eq(email_domain.to_owned()),
None,
None,
Some(dsl::last_modified_at.asc()),
)
.await
}
} }

View File

@ -1405,6 +1405,8 @@ diesel::table! {
allow_signup -> Bool, allow_signup -> Bool,
created_at -> Timestamp, created_at -> Timestamp,
last_modified_at -> Timestamp, last_modified_at -> Timestamp,
#[max_length = 64]
email_domain -> Varchar,
} }
} }

View File

@ -1352,6 +1352,8 @@ diesel::table! {
allow_signup -> Bool, allow_signup -> Bool,
created_at -> Timestamp, created_at -> Timestamp,
last_modified_at -> Timestamp, last_modified_at -> Timestamp,
#[max_length = 64]
email_domain -> Varchar,
} }
} }

View File

@ -17,6 +17,7 @@ pub struct UserAuthenticationMethod {
pub allow_signup: bool, pub allow_signup: bool,
pub created_at: PrimitiveDateTime, pub created_at: PrimitiveDateTime,
pub last_modified_at: PrimitiveDateTime, pub last_modified_at: PrimitiveDateTime,
pub email_domain: String,
} }
#[derive(router_derive::Setter, Clone, Debug, Insertable, router_derive::DebugAsDisplay)] #[derive(router_derive::Setter, Clone, Debug, Insertable, router_derive::DebugAsDisplay)]
@ -32,6 +33,7 @@ pub struct UserAuthenticationMethodNew {
pub allow_signup: bool, pub allow_signup: bool,
pub created_at: PrimitiveDateTime, pub created_at: PrimitiveDateTime,
pub last_modified_at: PrimitiveDateTime, pub last_modified_at: PrimitiveDateTime,
pub email_domain: String,
} }
#[derive(Clone, Debug, AsChangeset, router_derive::DebugAsDisplay)] #[derive(Clone, Debug, AsChangeset, router_derive::DebugAsDisplay)]
@ -40,6 +42,7 @@ pub struct OrgAuthenticationMethodUpdateInternal {
pub private_config: Option<Encryption>, pub private_config: Option<Encryption>,
pub public_config: Option<serde_json::Value>, pub public_config: Option<serde_json::Value>,
pub last_modified_at: PrimitiveDateTime, pub last_modified_at: PrimitiveDateTime,
pub email_domain: Option<String>,
} }
pub enum UserAuthenticationMethodUpdate { pub enum UserAuthenticationMethodUpdate {
@ -47,6 +50,9 @@ pub enum UserAuthenticationMethodUpdate {
private_config: Option<Encryption>, private_config: Option<Encryption>,
public_config: Option<serde_json::Value>, public_config: Option<serde_json::Value>,
}, },
EmailDomain {
email_domain: String,
},
} }
impl From<UserAuthenticationMethodUpdate> for OrgAuthenticationMethodUpdateInternal { impl From<UserAuthenticationMethodUpdate> for OrgAuthenticationMethodUpdateInternal {
@ -60,6 +66,13 @@ impl From<UserAuthenticationMethodUpdate> for OrgAuthenticationMethodUpdateInter
private_config, private_config,
public_config, public_config,
last_modified_at, last_modified_at,
email_domain: None,
},
UserAuthenticationMethodUpdate::EmailDomain { email_domain } => Self {
private_config: None,
public_config: None,
last_modified_at,
email_domain: Some(email_domain),
}, },
} }
} }

View File

@ -108,6 +108,8 @@ pub enum UserErrors {
InvalidThemeLineage(String), InvalidThemeLineage(String),
#[error("Missing required field: email_config")] #[error("Missing required field: email_config")]
MissingEmailConfig, MissingEmailConfig,
#[error("Invalid Auth Method Operation: {0}")]
InvalidAuthMethodOperationWithMessage(String),
} }
impl common_utils::errors::ErrorSwitch<api_models::errors::types::ApiErrorResponse> for UserErrors { impl common_utils::errors::ErrorSwitch<api_models::errors::types::ApiErrorResponse> for UserErrors {
@ -280,6 +282,9 @@ impl common_utils::errors::ErrorSwitch<api_models::errors::types::ApiErrorRespon
Self::MissingEmailConfig => { Self::MissingEmailConfig => {
AER::BadRequest(ApiError::new(sub_code, 56, self.get_error_message(), None)) AER::BadRequest(ApiError::new(sub_code, 56, self.get_error_message(), None))
} }
Self::InvalidAuthMethodOperationWithMessage(_) => {
AER::BadRequest(ApiError::new(sub_code, 57, self.get_error_message(), None))
}
} }
} }
} }
@ -347,6 +352,9 @@ impl UserErrors {
format!("Invalid field: {} in lineage", field_name) format!("Invalid field: {} in lineage", field_name)
} }
Self::MissingEmailConfig => "Missing required field: email_config".to_string(), Self::MissingEmailConfig => "Missing required field: email_config".to_string(),
Self::InvalidAuthMethodOperationWithMessage(operation) => {
format!("Invalid Auth Method Operation: {}", operation)
}
} }
} }
} }

View File

@ -7,7 +7,7 @@ use api_models::{
payments::RedirectionResponse, payments::RedirectionResponse,
user::{self as user_api, InviteMultipleUserResponse, NameIdUnit}, user::{self as user_api, InviteMultipleUserResponse, NameIdUnit},
}; };
use common_enums::EntityType; use common_enums::{EntityType, UserAuthType};
use common_utils::{type_name, types::keymanager::Identifier}; use common_utils::{type_name, types::keymanager::Identifier};
#[cfg(feature = "email")] #[cfg(feature = "email")]
use diesel_models::user_role::UserRoleUpdate; use diesel_models::user_role::UserRoleUpdate;
@ -22,6 +22,7 @@ use masking::{ExposeInterface, PeekInterface, Secret};
#[cfg(feature = "email")] #[cfg(feature = "email")]
use router_env::env; use router_env::env;
use router_env::logger; use router_env::logger;
use storage_impl::errors::StorageError;
#[cfg(not(feature = "email"))] #[cfg(not(feature = "email"))]
use user_api::dashboard_metadata::SetMetaDataRequest; use user_api::dashboard_metadata::SetMetaDataRequest;
@ -152,6 +153,14 @@ pub async fn signup_token_only_flow(
state: SessionState, state: SessionState,
request: user_api::SignUpRequest, request: user_api::SignUpRequest,
) -> UserResponse<user_api::TokenResponse> { ) -> UserResponse<user_api::TokenResponse> {
let user_email = domain::UserEmail::from_pii_email(request.email.clone())?;
utils::user::validate_email_domain_auth_type_using_db(
&state,
&user_email,
UserAuthType::Password,
)
.await?;
let new_user = domain::NewUser::try_from(request)?; let new_user = domain::NewUser::try_from(request)?;
new_user new_user
.get_new_merchant() .get_new_merchant()
@ -187,9 +196,18 @@ pub async fn signin_token_only_flow(
state: SessionState, state: SessionState,
request: user_api::SignInRequest, request: user_api::SignInRequest,
) -> UserResponse<user_api::TokenResponse> { ) -> UserResponse<user_api::TokenResponse> {
let user_email = domain::UserEmail::from_pii_email(request.email)?;
utils::user::validate_email_domain_auth_type_using_db(
&state,
&user_email,
UserAuthType::Password,
)
.await?;
let user_from_db: domain::UserFromStorage = state let user_from_db: domain::UserFromStorage = state
.global_store .global_store
.find_user_by_email(&domain::UserEmail::from_pii_email(request.email)?) .find_user_by_email(&user_email)
.await .await
.to_not_found_response(UserErrors::InvalidCredentials)? .to_not_found_response(UserErrors::InvalidCredentials)?
.into(); .into();
@ -215,10 +233,16 @@ pub async fn connect_account(
auth_id: Option<String>, auth_id: Option<String>,
theme_id: Option<String>, theme_id: Option<String>,
) -> UserResponse<user_api::ConnectAccountResponse> { ) -> UserResponse<user_api::ConnectAccountResponse> {
let find_user = state let user_email = domain::UserEmail::from_pii_email(request.email.clone())?;
.global_store
.find_user_by_email(&domain::UserEmail::from_pii_email(request.email.clone())?) utils::user::validate_email_domain_auth_type_using_db(
.await; &state,
&user_email,
UserAuthType::MagicLink,
)
.await?;
let find_user = state.global_store.find_user_by_email(&user_email).await;
if let Ok(found_user) = find_user { if let Ok(found_user) = find_user {
let user_from_db: domain::UserFromStorage = found_user.into(); let user_from_db: domain::UserFromStorage = found_user.into();
@ -412,6 +436,13 @@ pub async fn forgot_password(
) -> UserResponse<()> { ) -> UserResponse<()> {
let user_email = domain::UserEmail::from_pii_email(request.email)?; let user_email = domain::UserEmail::from_pii_email(request.email)?;
utils::user::validate_email_domain_auth_type_using_db(
&state,
&user_email,
UserAuthType::Password,
)
.await?;
let user_from_db = state let user_from_db = state
.global_store .global_store
.find_user_by_email(&user_email) .find_user_by_email(&user_email)
@ -1757,7 +1788,15 @@ pub async fn send_verification_mail(
auth_id: Option<String>, auth_id: Option<String>,
theme_id: Option<String>, theme_id: Option<String>,
) -> UserResponse<()> { ) -> UserResponse<()> {
let user_email = domain::UserEmail::try_from(req.email)?; let user_email = domain::UserEmail::from_pii_email(req.email)?;
utils::user::validate_email_domain_auth_type_using_db(
&state,
&user_email,
UserAuthType::MagicLink,
)
.await?;
let user = state let user = state
.global_store .global_store
.find_user_by_email(&user_email) .find_user_by_email(&user_email)
@ -2317,10 +2356,30 @@ pub async fn create_user_authentication_method(
.change_context(UserErrors::InternalServerError) .change_context(UserErrors::InternalServerError)
.attach_printable("Failed to get list of auth methods for the owner id")?; .attach_printable("Failed to get list of auth methods for the owner id")?;
let auth_id = auth_methods let (auth_id, email_domain) = if let Some(auth_method) = auth_methods.first() {
.first() let email_domain = match req.email_domain {
.map(|auth_method| auth_method.auth_id.clone()) Some(email_domain) => {
.unwrap_or(uuid::Uuid::new_v4().to_string()); if email_domain != auth_method.email_domain {
return Err(report!(UserErrors::InvalidAuthMethodOperationWithMessage(
"Email domain mismatch".to_string()
)));
}
email_domain
}
None => auth_method.email_domain.clone(),
};
(auth_method.auth_id.clone(), email_domain)
} else {
let email_domain =
req.email_domain
.ok_or(UserErrors::InvalidAuthMethodOperationWithMessage(
"Email domain not found".to_string(),
))?;
(uuid::Uuid::new_v4().to_string(), email_domain)
};
for db_auth_method in auth_methods { for db_auth_method in auth_methods {
let is_type_same = db_auth_method.auth_type == (&req.auth_method).foreign_into(); let is_type_same = db_auth_method.auth_type == (&req.auth_method).foreign_into();
@ -2360,6 +2419,7 @@ pub async fn create_user_authentication_method(
allow_signup: req.allow_signup, allow_signup: req.allow_signup,
created_at: now, created_at: now,
last_modified_at: now, last_modified_at: now,
email_domain,
}) })
.await .await
.to_duplicate_response(UserErrors::UserAuthMethodAlreadyExists)?; .to_duplicate_response(UserErrors::UserAuthMethodAlreadyExists)?;
@ -2383,25 +2443,71 @@ pub async fn update_user_authentication_method(
.change_context(UserErrors::InternalServerError) .change_context(UserErrors::InternalServerError)
.attach_printable("Failed to decode DEK")?; .attach_printable("Failed to decode DEK")?;
let (private_config, public_config) = utils::user::construct_public_and_private_db_configs( match req {
user_api::UpdateUserAuthenticationMethodRequest::AuthMethod {
id,
auth_config: auth_method,
} => {
let (private_config, public_config) =
utils::user::construct_public_and_private_db_configs(
&state, &state,
&req.auth_method, &auth_method,
&user_auth_encryption_key, &user_auth_encryption_key,
req.id.clone(), id.clone(),
) )
.await?; .await?;
state state
.store .store
.update_user_authentication_method( .update_user_authentication_method(
&req.id, &id,
UserAuthenticationMethodUpdate::UpdateConfig { UserAuthenticationMethodUpdate::UpdateConfig {
private_config, private_config,
public_config, public_config,
}, },
) )
.await .await
.change_context(UserErrors::InvalidUserAuthMethodOperation)?; .map_err(|error| {
let user_error = match error.current_context() {
StorageError::ValueNotFound(_) => {
UserErrors::InvalidAuthMethodOperationWithMessage(
"Auth method not found".to_string(),
)
}
StorageError::DuplicateValue { .. } => {
UserErrors::UserAuthMethodAlreadyExists
}
_ => UserErrors::InternalServerError,
};
error.change_context(user_error)
})?;
}
user_api::UpdateUserAuthenticationMethodRequest::EmailDomain {
owner_id,
email_domain,
} => {
let auth_methods = state
.store
.list_user_authentication_methods_for_owner_id(&owner_id)
.await
.change_context(UserErrors::InternalServerError)?;
futures::future::try_join_all(auth_methods.iter().map(|auth_method| async {
state
.store
.update_user_authentication_method(
&auth_method.id,
UserAuthenticationMethodUpdate::EmailDomain {
email_domain: email_domain.clone(),
},
)
.await
.to_duplicate_response(UserErrors::UserAuthMethodAlreadyExists)
}))
.await?;
}
}
Ok(ApplicationResponse::StatusOk) Ok(ApplicationResponse::StatusOk)
} }
@ -2409,18 +2515,28 @@ pub async fn list_user_authentication_methods(
state: SessionState, state: SessionState,
req: user_api::GetUserAuthenticationMethodsRequest, req: user_api::GetUserAuthenticationMethodsRequest,
) -> UserResponse<Vec<user_api::UserAuthenticationMethodResponse>> { ) -> UserResponse<Vec<user_api::UserAuthenticationMethodResponse>> {
let user_authentication_methods = state let user_authentication_methods = match (req.auth_id, req.email_domain) {
(Some(auth_id), None) => state
.store .store
.list_user_authentication_methods_for_auth_id(&req.auth_id) .list_user_authentication_methods_for_auth_id(&auth_id)
.await .await
.change_context(UserErrors::InternalServerError)?; .change_context(UserErrors::InternalServerError)?,
(None, Some(email_domain)) => state
.store
.list_user_authentication_methods_for_email_domain(&email_domain)
.await
.change_context(UserErrors::InternalServerError)?,
(Some(_), Some(_)) | (None, None) => {
return Err(UserErrors::InvalidUserAuthMethodOperation.into());
}
};
Ok(ApplicationResponse::Json( Ok(ApplicationResponse::Json(
user_authentication_methods user_authentication_methods
.into_iter() .into_iter()
.map(|auth_method| { .map(|auth_method| {
let auth_name = match (auth_method.auth_type, auth_method.public_config) { let auth_name = match (auth_method.auth_type, auth_method.public_config) {
(common_enums::UserAuthType::OpenIdConnect, config) => { (UserAuthType::OpenIdConnect, config) => {
let open_id_public_config: Option<user_api::OpenIdConnectPublicConfig> = let open_id_public_config: Option<user_api::OpenIdConnectPublicConfig> =
config config
.map(|config| { .map(|config| {
@ -2546,6 +2662,13 @@ pub async fn sso_sign(
) )
.await?; .await?;
utils::user::validate_email_domain_auth_type_using_db(
&state,
&email,
UserAuthType::OpenIdConnect,
)
.await?;
// TODO: Use config to handle not found error // TODO: Use config to handle not found error
let user_from_db: domain::UserFromStorage = state let user_from_db: domain::UserFromStorage = state
.global_store .global_store
@ -2594,14 +2717,20 @@ pub async fn terminate_auth_select(
.change_context(UserErrors::InternalServerError)? .change_context(UserErrors::InternalServerError)?
.into(); .into();
let user_authentication_method = if let Some(id) = &req.id { let user_email = domain::UserEmail::from_pii_email(user_from_db.get_email())?;
state let auth_methods = state
.store .store
.get_user_authentication_method_by_id(id) .list_user_authentication_methods_for_email_domain(user_email.extract_domain()?)
.await .await
.to_not_found_response(UserErrors::InvalidUserAuthMethodOperation)? .change_context(UserErrors::InternalServerError)?;
} else {
DEFAULT_USER_AUTH_METHOD.clone() let user_authentication_method = match (req.id, auth_methods.is_empty()) {
(Some(id), _) => auth_methods
.into_iter()
.find(|auth_method| auth_method.id == id)
.ok_or(UserErrors::InvalidUserAuthMethodOperation)?,
(None, true) => DEFAULT_USER_AUTH_METHOD.clone(),
(None, false) => return Err(UserErrors::InvalidUserAuthMethodOperation.into()),
}; };
let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::AuthSelect.into())?; let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::AuthSelect.into())?;

View File

@ -3816,6 +3816,18 @@ impl UserAuthenticationMethodInterface for KafkaStore {
.update_user_authentication_method(id, user_authentication_method_update) .update_user_authentication_method(id, user_authentication_method_update)
.await .await
} }
async fn list_user_authentication_methods_for_email_domain(
&self,
email_domain: &str,
) -> CustomResult<
Vec<diesel_models::user_authentication_method::UserAuthenticationMethod>,
errors::StorageError,
> {
self.diesel_store
.list_user_authentication_methods_for_email_domain(email_domain)
.await
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]

View File

@ -36,6 +36,11 @@ pub trait UserAuthenticationMethodInterface {
id: &str, id: &str,
user_authentication_method_update: storage::UserAuthenticationMethodUpdate, user_authentication_method_update: storage::UserAuthenticationMethodUpdate,
) -> CustomResult<storage::UserAuthenticationMethod, errors::StorageError>; ) -> CustomResult<storage::UserAuthenticationMethod, errors::StorageError>;
async fn list_user_authentication_methods_for_email_domain(
&self,
email_domain: &str,
) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError>;
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -57,7 +62,7 @@ impl UserAuthenticationMethodInterface for Store {
&self, &self,
id: &str, id: &str,
) -> CustomResult<storage::UserAuthenticationMethod, errors::StorageError> { ) -> CustomResult<storage::UserAuthenticationMethod, errors::StorageError> {
let conn = connection::pg_connection_write(self).await?; let conn = connection::pg_connection_read(self).await?;
storage::UserAuthenticationMethod::get_user_authentication_method_by_id(&conn, id) storage::UserAuthenticationMethod::get_user_authentication_method_by_id(&conn, id)
.await .await
.map_err(|error| report!(errors::StorageError::from(error))) .map_err(|error| report!(errors::StorageError::from(error)))
@ -68,7 +73,7 @@ impl UserAuthenticationMethodInterface for Store {
&self, &self,
auth_id: &str, auth_id: &str,
) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> { ) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> {
let conn = connection::pg_connection_write(self).await?; let conn = connection::pg_connection_read(self).await?;
storage::UserAuthenticationMethod::list_user_authentication_methods_for_auth_id( storage::UserAuthenticationMethod::list_user_authentication_methods_for_auth_id(
&conn, auth_id, &conn, auth_id,
) )
@ -81,7 +86,7 @@ impl UserAuthenticationMethodInterface for Store {
&self, &self,
owner_id: &str, owner_id: &str,
) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> { ) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> {
let conn = connection::pg_connection_write(self).await?; let conn = connection::pg_connection_read(self).await?;
storage::UserAuthenticationMethod::list_user_authentication_methods_for_owner_id( storage::UserAuthenticationMethod::list_user_authentication_methods_for_owner_id(
&conn, owner_id, &conn, owner_id,
) )
@ -104,6 +109,20 @@ impl UserAuthenticationMethodInterface for Store {
.await .await
.map_err(|error| report!(errors::StorageError::from(error))) .map_err(|error| report!(errors::StorageError::from(error)))
} }
#[instrument(skip_all)]
async fn list_user_authentication_methods_for_email_domain(
&self,
email_domain: &str,
) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> {
let conn = connection::pg_connection_read(self).await?;
storage::UserAuthenticationMethod::list_user_authentication_methods_for_email_domain(
&conn,
email_domain,
)
.await
.map_err(|error| report!(errors::StorageError::from(error)))
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -130,6 +149,7 @@ impl UserAuthenticationMethodInterface for MockDb {
allow_signup: user_authentication_method.allow_signup, allow_signup: user_authentication_method.allow_signup,
created_at: user_authentication_method.created_at, created_at: user_authentication_method.created_at,
last_modified_at: user_authentication_method.last_modified_at, last_modified_at: user_authentication_method.last_modified_at,
email_domain: user_authentication_method.email_domain,
}; };
user_authentication_methods.push(user_authentication_method.clone()); user_authentication_methods.push(user_authentication_method.clone());
@ -222,6 +242,13 @@ impl UserAuthenticationMethodInterface for MockDb {
last_modified_at: common_utils::date_time::now(), last_modified_at: common_utils::date_time::now(),
..auth_method_inner.to_owned() ..auth_method_inner.to_owned()
}, },
storage::UserAuthenticationMethodUpdate::EmailDomain { email_domain } => {
storage::UserAuthenticationMethod {
email_domain: email_domain.to_owned(),
last_modified_at: common_utils::date_time::now(),
..auth_method_inner.to_owned()
}
}
}; };
auth_method_inner.to_owned() auth_method_inner.to_owned()
}) })
@ -232,4 +259,20 @@ impl UserAuthenticationMethodInterface for MockDb {
.into(), .into(),
) )
} }
#[instrument(skip_all)]
async fn list_user_authentication_methods_for_email_domain(
&self,
email_domain: &str,
) -> CustomResult<Vec<storage::UserAuthenticationMethod>, errors::StorageError> {
let user_authentication_methods = self.user_authentication_methods.lock().await;
let user_authentication_methods_list: Vec<_> = user_authentication_methods
.iter()
.filter(|auth_method_inner| auth_method_inner.email_domain == email_domain)
.cloned()
.collect();
Ok(user_authentication_methods_list)
}
} }

View File

@ -138,6 +138,15 @@ impl UserEmail {
pub fn get_secret(self) -> Secret<String, pii::EmailStrategy> { pub fn get_secret(self) -> Secret<String, pii::EmailStrategy> {
(*self.0).clone() (*self.0).clone()
} }
pub fn extract_domain(&self) -> UserResult<&str> {
let (_username, domain) = self
.peek()
.split_once('@')
.ok_or(UserErrors::InternalServerError)?;
Ok(domain)
}
} }
impl TryFrom<pii::Email> for UserEmail { impl TryFrom<pii::Email> for UserEmail {

View File

@ -14,4 +14,5 @@ pub static DEFAULT_USER_AUTH_METHOD: Lazy<UserAuthenticationMethod> =
allow_signup: true, allow_signup: true,
created_at: common_utils::date_time::now(), created_at: common_utils::date_time::now(),
last_modified_at: common_utils::date_time::now(), last_modified_at: common_utils::date_time::now(),
email_domain: String::from("hyperswitch"),
}); });

View File

@ -5,7 +5,7 @@ use common_enums::UserAuthType;
use common_utils::{ use common_utils::{
encryption::Encryption, errors::CustomResult, id_type, type_name, types::keymanager::Identifier, encryption::Encryption, errors::CustomResult, id_type, type_name, types::keymanager::Identifier,
}; };
use diesel_models::{organization, organization::OrganizationBridge}; use diesel_models::organization::{self, OrganizationBridge};
use error_stack::ResultExt; use error_stack::ResultExt;
use masking::{ExposeInterface, Secret}; use masking::{ExposeInterface, Secret};
use redis_interface::RedisConnectionPool; use redis_interface::RedisConnectionPool;
@ -312,3 +312,23 @@ pub fn create_merchant_account_request_for_org(
pm_collect_link_config: None, pm_collect_link_config: None,
}) })
} }
pub async fn validate_email_domain_auth_type_using_db(
state: &SessionState,
email: &domain::UserEmail,
required_auth_type: UserAuthType,
) -> UserResult<()> {
let domain = email.extract_domain()?;
let user_auth_methods = state
.store
.list_user_authentication_methods_for_email_domain(domain)
.await
.change_context(UserErrors::InternalServerError)?;
(user_auth_methods.is_empty()
|| user_auth_methods
.iter()
.any(|auth_method| auth_method.auth_type == required_auth_type))
.then_some(())
.ok_or(UserErrors::InvalidUserAuthMethodOperation.into())
}

View File

@ -0,0 +1,3 @@
-- This file should undo anything in `up.sql`
DROP INDEX email_domain_index;
ALTER TABLE user_authentication_methods DROP COLUMN email_domain;

View File

@ -0,0 +1,6 @@
-- Your SQL goes here
ALTER TABLE user_authentication_methods ADD COLUMN email_domain VARCHAR(64);
UPDATE user_authentication_methods SET email_domain = auth_id WHERE email_domain IS NULL;
ALTER TABLE user_authentication_methods ALTER COLUMN email_domain SET NOT NULL;
CREATE INDEX email_domain_index ON user_authentication_methods (email_domain);