feat(core): use profile_id passed from auth layer within core functions (#5553)

This commit is contained in:
Hrithikesh
2024-08-09 12:28:47 +05:30
committed by GitHub
parent fc581e08ff
commit 9fa631d2b9
12 changed files with 233 additions and 32 deletions

View File

@ -1,4 +1,4 @@
use std::{marker::PhantomData, str::FromStr};
use std::{collections::HashSet, marker::PhantomData, str::FromStr};
use api_models::enums::{DisputeStage, DisputeStatus};
#[cfg(feature = "payouts")]
@ -8,7 +8,10 @@ use common_enums::{IntentStatus, RequestIncrementalAuthorization};
use common_utils::{crypto::Encryptable, pii::Email};
use common_utils::{errors::CustomResult, ext_traits::AsyncExt, types::MinorUnit};
use error_stack::{report, ResultExt};
use hyperswitch_domain_models::{payment_address::PaymentAddress, router_data::ErrorResponse};
use hyperswitch_domain_models::{
merchant_connector_account::MerchantConnectorAccount, payment_address::PaymentAddress,
router_data::ErrorResponse,
};
#[cfg(feature = "payouts")]
use masking::{ExposeInterface, PeekInterface};
use maud::{html, PreEscaped};
@ -450,6 +453,64 @@ mod tests {
let generated_id = generate_id(consts::ID_LENGTH, "ref");
assert_eq!(generated_id.len(), consts::ID_LENGTH + 4)
}
#[test]
fn test_filter_objects_based_on_profile_id_list() {
#[derive(PartialEq, Debug, Clone)]
struct Object {
profile_id: Option<String>,
}
impl Object {
pub fn new(profile_id: &str) -> Self {
Self {
profile_id: Some(profile_id.to_string()),
}
}
}
impl GetProfileId for Object {
fn get_profile_id(&self) -> Option<&String> {
self.profile_id.as_ref()
}
}
// non empty object_list and profile_id_list
let object_list = vec![
Object::new("p1"),
Object::new("p2"),
Object::new("p2"),
Object::new("p4"),
Object::new("p5"),
];
let profile_id_list = vec!["p1".to_string(), "p2".to_string(), "p3".to_string()];
let filtered_list =
filter_objects_based_on_profile_id_list(Some(profile_id_list), object_list.clone());
let expected_result = vec![Object::new("p1"), Object::new("p2"), Object::new("p2")];
assert_eq!(filtered_list, expected_result);
// non empty object_list and empty profile_id_list
let empty_profile_id_list = vec![];
let filtered_list = filter_objects_based_on_profile_id_list(
Some(empty_profile_id_list),
object_list.clone(),
);
let expected_result = vec![];
assert_eq!(filtered_list, expected_result);
// non empty object_list and None profile_id_list
let profile_id_list_as_none = None;
let filtered_list =
filter_objects_based_on_profile_id_list(profile_id_list_as_none, object_list);
let expected_result = vec![
Object::new("p1"),
Object::new("p2"),
Object::new("p2"),
Object::new("p4"),
Object::new("p5"),
];
assert_eq!(filtered_list, expected_result);
}
}
// Dispute Stage can move linearly from PreDispute -> Dispute -> PreArbitration
@ -1267,3 +1328,100 @@ pub fn get_incremental_authorization_allowed_value(
incremental_authorization_allowed
}
}
pub(super) trait GetProfileId {
fn get_profile_id(&self) -> Option<&String>;
}
impl GetProfileId for MerchantConnectorAccount {
fn get_profile_id(&self) -> Option<&String> {
self.profile_id.as_ref()
}
}
impl GetProfileId for storage::PaymentIntent {
fn get_profile_id(&self) -> Option<&String> {
self.profile_id.as_ref()
}
}
impl<A> GetProfileId for (storage::PaymentIntent, A) {
fn get_profile_id(&self) -> Option<&String> {
self.0.get_profile_id()
}
}
impl GetProfileId for diesel_models::Dispute {
fn get_profile_id(&self) -> Option<&String> {
self.profile_id.as_ref()
}
}
impl GetProfileId for diesel_models::Refund {
fn get_profile_id(&self) -> Option<&String> {
self.profile_id.as_ref()
}
}
#[cfg(feature = "payouts")]
impl GetProfileId for storage::Payouts {
fn get_profile_id(&self) -> Option<&String> {
Some(&self.profile_id)
}
}
#[cfg(feature = "payouts")]
impl<T, F> GetProfileId for (storage::Payouts, T, F) {
fn get_profile_id(&self) -> Option<&String> {
self.0.get_profile_id()
}
}
/// Filter Objects based on profile ids
pub(super) fn filter_objects_based_on_profile_id_list<T: GetProfileId>(
profile_id_list_auth_layer: Option<Vec<String>>,
object_list: Vec<T>,
) -> Vec<T> {
if let Some(profile_id_list) = profile_id_list_auth_layer {
let profile_ids_to_filter: HashSet<_> = profile_id_list.iter().collect();
object_list
.into_iter()
.filter_map(|item| {
if item
.get_profile_id()
.is_some_and(|profile_id| profile_ids_to_filter.contains(profile_id))
{
Some(item)
} else {
None
}
})
.collect()
} else {
object_list
}
}
pub(super) fn validate_profile_id_from_auth_layer<T: GetProfileId + std::fmt::Debug>(
profile_id_auth_layer: Option<String>,
object: &T,
) -> RouterResult<()> {
match (profile_id_auth_layer, object.get_profile_id()) {
(Some(auth_profile_id), Some(object_profile_id)) => {
auth_profile_id.eq(object_profile_id).then_some(()).ok_or(
errors::ApiErrorResponse::PreconditionFailed {
message: "Profile id authentication failed. Please use the correct JWT token"
.to_string(),
}
.into(),
)
}
(Some(_auth_profile_id), None) => RouterResult::Err(
errors::ApiErrorResponse::PreconditionFailed {
message: "Couldn't find profile_id in record for authentication".to_string(),
}
.into(),
)
.attach_printable(format!("Couldn't find profile_id in entity {:?}", object)),
(None, None) | (None, Some(_)) => Ok(()),
}
}