diff --git a/crates/router/src/db.rs b/crates/router/src/db.rs index b0d86839a5..4120edb957 100644 --- a/crates/router/src/db.rs +++ b/crates/router/src/db.rs @@ -155,8 +155,10 @@ pub trait GlobalStorageInterface: + user_role::UserRoleInterface + user_key_store::UserKeyStoreInterface + role::RoleInterface + + RedisConnInterface + 'static { + fn get_cache_store(&self) -> Box<(dyn RedisConnInterface + Send + Sync + 'static)>; } #[async_trait::async_trait] @@ -217,7 +219,11 @@ impl StorageInterface for Store { } #[async_trait::async_trait] -impl GlobalStorageInterface for Store {} +impl GlobalStorageInterface for Store { + fn get_cache_store(&self) -> Box<(dyn RedisConnInterface + Send + Sync + 'static)> { + Box::new(self.clone()) + } +} impl AccountsStorageInterface for Store {} @@ -233,7 +239,11 @@ impl StorageInterface for MockDb { } #[async_trait::async_trait] -impl GlobalStorageInterface for MockDb {} +impl GlobalStorageInterface for MockDb { + fn get_cache_store(&self) -> Box<(dyn RedisConnInterface + Send + Sync + 'static)> { + Box::new(self.clone()) + } +} impl AccountsStorageInterface for MockDb {} diff --git a/crates/router/src/db/kafka_store.rs b/crates/router/src/db/kafka_store.rs index 0e39cbf841..6e59e0c95d 100644 --- a/crates/router/src/db/kafka_store.rs +++ b/crates/router/src/db/kafka_store.rs @@ -3269,7 +3269,11 @@ impl StorageInterface for KafkaStore { } } -impl GlobalStorageInterface for KafkaStore {} +impl GlobalStorageInterface for KafkaStore { + fn get_cache_store(&self) -> Box<(dyn RedisConnInterface + Send + Sync + 'static)> { + Box::new(self.clone()) + } +} impl AccountsStorageInterface for KafkaStore {} impl PaymentMethodsStorageInterface for KafkaStore {} diff --git a/crates/router/src/routes/app.rs b/crates/router/src/routes/app.rs index 9dc67be1d2..87307441d7 100644 --- a/crates/router/src/routes/app.rs +++ b/crates/router/src/routes/app.rs @@ -153,6 +153,7 @@ pub trait SessionStateInfo { #[cfg(feature = "partial-auth")] fn get_detached_auth(&self) -> RouterResult<(Blake3, &[u8])>; fn session_state(&self) -> SessionState; + fn global_store(&self) -> Box; } impl SessionStateInfo for SessionState { @@ -209,6 +210,9 @@ impl SessionStateInfo for SessionState { fn session_state(&self) -> SessionState { self.clone() } + fn global_store(&self) -> Box<(dyn GlobalStorageInterface)> { + self.global_store.to_owned() + } } #[derive(Clone)] pub struct AppState { diff --git a/crates/router/src/services/authentication/blacklist.rs b/crates/router/src/services/authentication/blacklist.rs index 6be6cc467b..77dd9861a0 100644 --- a/crates/router/src/services/authentication/blacklist.rs +++ b/crates/router/src/services/authentication/blacklist.rs @@ -27,7 +27,8 @@ pub async fn insert_user_in_blacklist(state: &SessionState, user_id: &str) -> Us let user_blacklist_key = format!("{}{}", USER_BLACKLIST_PREFIX, user_id); let expiry = expiry_to_i64(JWT_TOKEN_TIME_IN_SECS).change_context(UserErrors::InternalServerError)?; - let redis_conn = get_redis_connection(state).change_context(UserErrors::InternalServerError)?; + let redis_conn = get_redis_connection_for_global_tenant(state) + .change_context(UserErrors::InternalServerError)?; redis_conn .set_key_with_expiry( &user_blacklist_key.as_str().into(), @@ -43,7 +44,8 @@ pub async fn insert_role_in_blacklist(state: &SessionState, role_id: &str) -> Us let role_blacklist_key = format!("{}{}", ROLE_BLACKLIST_PREFIX, role_id); let expiry = expiry_to_i64(JWT_TOKEN_TIME_IN_SECS).change_context(UserErrors::InternalServerError)?; - let redis_conn = get_redis_connection(state).change_context(UserErrors::InternalServerError)?; + let redis_conn = get_redis_connection_for_global_tenant(state) + .change_context(UserErrors::InternalServerError)?; redis_conn .set_key_with_expiry( &role_blacklist_key.as_str().into(), @@ -59,7 +61,7 @@ pub async fn insert_role_in_blacklist(state: &SessionState, role_id: &str) -> Us #[cfg(feature = "olap")] async fn invalidate_role_cache(state: &SessionState, role_id: &str) -> RouterResult<()> { - let redis_conn = get_redis_connection(state)?; + let redis_conn = get_redis_connection_for_global_tenant(state)?; redis_conn .delete_key(&authz::get_cache_key_from_role_id(role_id).as_str().into()) .await @@ -74,7 +76,7 @@ pub async fn check_user_in_blacklist( ) -> RouterResult { let token = format!("{}{}", USER_BLACKLIST_PREFIX, user_id); let token_issued_at = expiry_to_i64(token_expiry - JWT_TOKEN_TIME_IN_SECS)?; - let redis_conn = get_redis_connection(state)?; + let redis_conn = get_redis_connection_for_global_tenant(state)?; redis_conn .get_key::>(&token.as_str().into()) .await @@ -89,7 +91,7 @@ pub async fn check_role_in_blacklist( ) -> RouterResult { let token = format!("{}{}", ROLE_BLACKLIST_PREFIX, role_id); let token_issued_at = expiry_to_i64(token_expiry - JWT_TOKEN_TIME_IN_SECS)?; - let redis_conn = get_redis_connection(state)?; + let redis_conn = get_redis_connection_for_global_tenant(state)?; redis_conn .get_key::>(&token.as_str().into()) .await @@ -99,7 +101,8 @@ pub async fn check_role_in_blacklist( #[cfg(feature = "email")] pub async fn insert_email_token_in_blacklist(state: &SessionState, token: &str) -> UserResult<()> { - let redis_conn = get_redis_connection(state).change_context(UserErrors::InternalServerError)?; + let redis_conn = get_redis_connection_for_global_tenant(state) + .change_context(UserErrors::InternalServerError)?; let blacklist_key = format!("{}{token}", EMAIL_TOKEN_BLACKLIST_PREFIX); let expiry = expiry_to_i64(EMAIL_TOKEN_TIME_IN_SECS).change_context(UserErrors::InternalServerError)?; @@ -111,7 +114,8 @@ pub async fn insert_email_token_in_blacklist(state: &SessionState, token: &str) #[cfg(feature = "email")] pub async fn check_email_token_in_blacklist(state: &SessionState, token: &str) -> UserResult<()> { - let redis_conn = get_redis_connection(state).change_context(UserErrors::InternalServerError)?; + let redis_conn = get_redis_connection_for_global_tenant(state) + .change_context(UserErrors::InternalServerError)?; let blacklist_key = format!("{}{token}", EMAIL_TOKEN_BLACKLIST_PREFIX); let key_exists = redis_conn .exists::<()>(&blacklist_key.as_str().into()) @@ -124,9 +128,11 @@ pub async fn check_email_token_in_blacklist(state: &SessionState, token: &str) - Ok(()) } -fn get_redis_connection(state: &A) -> RouterResult> { +fn get_redis_connection_for_global_tenant( + state: &A, +) -> RouterResult> { state - .store() + .global_store() .get_redis_conn() .change_context(ApiErrorResponse::InternalServerError) .attach_printable("Failed to get redis connection") diff --git a/crates/router/src/services/authorization.rs b/crates/router/src/services/authorization.rs index 0e66654b7b..947e0860ae 100644 --- a/crates/router/src/services/authorization.rs +++ b/crates/router/src/services/authorization.rs @@ -61,7 +61,7 @@ async fn get_role_info_from_cache(state: &A, role_id: &str) -> RouterResult( where A: SessionStateInfo + Sync, { - let redis_conn = get_redis_connection(state)?; + let redis_conn = get_redis_connection_for_global_tenant(state)?; redis_conn .serialize_and_set_key_with_expiry( @@ -143,9 +143,11 @@ pub fn check_tenant( Ok(()) } -fn get_redis_connection(state: &A) -> RouterResult> { +fn get_redis_connection_for_global_tenant( + state: &A, +) -> RouterResult> { state - .store() + .global_store() .get_redis_conn() .change_context(ApiErrorResponse::InternalServerError) .attach_printable("Failed to get redis connection") diff --git a/crates/router/src/services/openidconnect.rs b/crates/router/src/services/openidconnect.rs index 44f95b4260..c4dd0df6e1 100644 --- a/crates/router/src/services/openidconnect.rs +++ b/crates/router/src/services/openidconnect.rs @@ -34,7 +34,7 @@ pub async fn get_authorization_url( // Save csrf & nonce as key value respectively let key = get_oidc_redis_key(csrf_token.secret()); - get_redis_connection(&state)? + get_redis_connection_for_global_tenant(&state)? .set_key_with_expiry(&key.into(), nonce.secret(), consts::user::REDIS_SSO_TTL) .await .change_context(UserErrors::InternalServerError) @@ -138,7 +138,7 @@ async fn get_nonce_from_redis( state: &SessionState, redirect_state: &Secret, ) -> UserResult { - let redis_connection = get_redis_connection(state)?; + let redis_connection = get_redis_connection_for_global_tenant(state)?; let redirect_state = redirect_state.clone().expose(); let key = get_oidc_redis_key(&redirect_state); redis_connection @@ -188,9 +188,11 @@ fn get_oidc_redis_key(csrf: &str) -> String { format!("{}OIDC_{}", consts::user::REDIS_SSO_PREFIX, csrf) } -fn get_redis_connection(state: &SessionState) -> UserResult> { +fn get_redis_connection_for_global_tenant( + state: &SessionState, +) -> UserResult> { state - .store + .global_store .get_redis_conn() .change_context(UserErrors::InternalServerError) .attach_printable("Failed to get redis connection") diff --git a/crates/router/src/utils/user.rs b/crates/router/src/utils/user.rs index 13190125d7..bd5237c38c 100644 --- a/crates/router/src/utils/user.rs +++ b/crates/router/src/utils/user.rs @@ -135,33 +135,14 @@ pub async fn get_user_from_db_by_email( .map(UserFromStorage::from) } -pub fn get_redis_connection(state: &SessionState) -> UserResult> { - state - .store - .get_redis_conn() - .change_context(UserErrors::InternalServerError) - .attach_printable("Failed to get redis connection") -} - pub fn get_redis_connection_for_global_tenant( state: &SessionState, ) -> UserResult> { - let redis_connection_pool = state - .store + state + .global_store .get_redis_conn() .change_context(UserErrors::InternalServerError) - .attach_printable("Failed to get redis connection")?; - - let global_tenant_prefix = &state.conf.multitenancy.global_tenant.redis_key_prefix; - - Ok(Arc::new(RedisConnectionPool { - pool: Arc::clone(&redis_connection_pool.pool), - key_prefix: global_tenant_prefix.to_string(), - config: Arc::clone(&redis_connection_pool.config), - subscriber: Arc::clone(&redis_connection_pool.subscriber), - publisher: Arc::clone(&redis_connection_pool.publisher), - is_redis_available: Arc::clone(&redis_connection_pool.is_redis_available), - })) + .attach_printable("Failed to get redis connection") } impl ForeignFrom<&user_api::AuthConfig> for UserAuthType { @@ -267,7 +248,7 @@ pub async fn set_sso_id_in_redis( oidc_state: Secret, sso_id: String, ) -> UserResult<()> { - let connection = get_redis_connection(state)?; + let connection = get_redis_connection_for_global_tenant(state)?; let key = get_oidc_key(&oidc_state.expose()); connection .set_key_with_expiry(&key.into(), sso_id, REDIS_SSO_TTL) @@ -280,7 +261,7 @@ pub async fn get_sso_id_from_redis( state: &SessionState, oidc_state: Secret, ) -> UserResult { - let connection = get_redis_connection(state)?; + let connection = get_redis_connection_for_global_tenant(state)?; let key = get_oidc_key(&oidc_state.expose()); connection .get_key::>(&key.into()) diff --git a/crates/router/src/utils/user/two_factor_auth.rs b/crates/router/src/utils/user/two_factor_auth.rs index b4ced70d73..f539562614 100644 --- a/crates/router/src/utils/user/two_factor_auth.rs +++ b/crates/router/src/utils/user/two_factor_auth.rs @@ -33,7 +33,7 @@ pub fn generate_default_totp( } pub async fn check_totp_in_redis(state: &SessionState, user_id: &str) -> UserResult { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_TOTP_PREFIX, user_id); redis_conn .exists::<()>(&key.into()) @@ -42,7 +42,7 @@ pub async fn check_totp_in_redis(state: &SessionState, user_id: &str) -> UserRes } pub async fn check_recovery_code_in_redis(state: &SessionState, user_id: &str) -> UserResult { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_RECOVERY_CODE_PREFIX, user_id); redis_conn .exists::<()>(&key.into()) @@ -51,7 +51,7 @@ pub async fn check_recovery_code_in_redis(state: &SessionState, user_id: &str) - } pub async fn insert_totp_in_redis(state: &SessionState, user_id: &str) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_TOTP_PREFIX, user_id); redis_conn .set_key_with_expiry( @@ -68,7 +68,7 @@ pub async fn insert_totp_secret_in_redis( user_id: &str, secret: &masking::Secret, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .set_key_with_expiry( &get_totp_secret_key(user_id).into(), @@ -83,7 +83,7 @@ pub async fn get_totp_secret_from_redis( state: &SessionState, user_id: &str, ) -> UserResult>> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .get_key::>(&get_totp_secret_key(user_id).into()) .await @@ -92,7 +92,7 @@ pub async fn get_totp_secret_from_redis( } pub async fn delete_totp_secret_from_redis(state: &SessionState, user_id: &str) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .delete_key(&get_totp_secret_key(user_id).into()) .await @@ -105,7 +105,7 @@ fn get_totp_secret_key(user_id: &str) -> String { } pub async fn insert_recovery_code_in_redis(state: &SessionState, user_id: &str) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_RECOVERY_CODE_PREFIX, user_id); redis_conn .set_key_with_expiry( @@ -118,7 +118,7 @@ pub async fn insert_recovery_code_in_redis(state: &SessionState, user_id: &str) } pub async fn delete_totp_from_redis(state: &SessionState, user_id: &str) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_TOTP_PREFIX, user_id); redis_conn .delete_key(&key.into()) @@ -131,7 +131,7 @@ pub async fn delete_recovery_code_from_redis( state: &SessionState, user_id: &str, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; let key = format!("{}{}", consts::user::REDIS_RECOVERY_CODE_PREFIX, user_id); redis_conn .delete_key(&key.into()) @@ -156,7 +156,7 @@ pub async fn insert_totp_attempts_in_redis( user_id: &str, user_totp_attempts: u8, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .set_key_with_expiry( &get_totp_attempts_key(user_id).into(), @@ -167,7 +167,7 @@ pub async fn insert_totp_attempts_in_redis( .change_context(UserErrors::InternalServerError) } pub async fn get_totp_attempts_from_redis(state: &SessionState, user_id: &str) -> UserResult { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .get_key::>(&get_totp_attempts_key(user_id).into()) .await @@ -180,7 +180,7 @@ pub async fn insert_recovery_code_attempts_in_redis( user_id: &str, user_recovery_code_attempts: u8, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .set_key_with_expiry( &get_recovery_code_attempts_key(user_id).into(), @@ -195,7 +195,7 @@ pub async fn get_recovery_code_attempts_from_redis( state: &SessionState, user_id: &str, ) -> UserResult { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .get_key::>(&get_recovery_code_attempts_key(user_id).into()) .await @@ -207,7 +207,7 @@ pub async fn delete_totp_attempts_from_redis( state: &SessionState, user_id: &str, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .delete_key(&get_totp_attempts_key(user_id).into()) .await @@ -219,7 +219,7 @@ pub async fn delete_recovery_code_attempts_from_redis( state: &SessionState, user_id: &str, ) -> UserResult<()> { - let redis_conn = super::get_redis_connection(state)?; + let redis_conn = super::get_redis_connection_for_global_tenant(state)?; redis_conn .delete_key(&get_recovery_code_attempts_key(user_id).into()) .await