From a23a365cdf3fc2a24f4e2a08996a5683dc4da89a Mon Sep 17 00:00:00 2001 From: Shanks Date: Mon, 6 May 2024 18:38:44 +0530 Subject: [PATCH] feat(constraint_graph): make the constraint graph framework generic and move it into a separate crate (#3071) --- .github/CODEOWNERS | 1 + Cargo.lock | 53 +- crates/euclid/Cargo.toml | 1 + crates/euclid/src/dssa/analyzer.rs | 37 +- crates/euclid/src/dssa/graph.rs | 1499 ++++++----------- crates/euclid/src/dssa/truth.rs | 41 +- crates/euclid/src/dssa/types.rs | 5 +- crates/euclid/src/lib.rs | 1 - crates/euclid/src/utils.rs | 3 - crates/euclid_macros/src/inner/knowledge.rs | 38 +- crates/euclid_wasm/Cargo.toml | 1 + crates/euclid_wasm/src/lib.rs | 31 +- .../hyperswitch_constraint_graph/Cargo.toml | 16 + .../src/builder.rs | 283 ++++ .../src}/dense_map.rs | 18 + .../hyperswitch_constraint_graph/src/error.rs | 77 + .../hyperswitch_constraint_graph/src/graph.rs | 587 +++++++ .../hyperswitch_constraint_graph/src/lib.rs | 13 + .../hyperswitch_constraint_graph/src/types.rs | 249 +++ crates/kgraph_utils/Cargo.toml | 1 + crates/kgraph_utils/benches/evaluation.rs | 12 +- crates/kgraph_utils/src/error.rs | 4 +- crates/kgraph_utils/src/mca.rs | 220 +-- crates/router/Cargo.toml | 1 + crates/router/src/core/payments/routing.rs | 18 +- 25 files changed, 2060 insertions(+), 1150 deletions(-) delete mode 100644 crates/euclid/src/utils.rs create mode 100644 crates/hyperswitch_constraint_graph/Cargo.toml create mode 100644 crates/hyperswitch_constraint_graph/src/builder.rs rename crates/{euclid/src/utils => hyperswitch_constraint_graph/src}/dense_map.rs (92%) create mode 100644 crates/hyperswitch_constraint_graph/src/error.rs create mode 100644 crates/hyperswitch_constraint_graph/src/graph.rs create mode 100644 crates/hyperswitch_constraint_graph/src/lib.rs create mode 100644 crates/hyperswitch_constraint_graph/src/types.rs diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8f1326625c..933fe96526 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -45,6 +45,7 @@ crates/router/src/compatibility/ @juspay/hyperswitch-compatibility crates/router/src/core/ @juspay/hyperswitch-core crates/api_models/src/routing.rs @juspay/hyperswitch-routing +crates/hyperswitch_constraint_graph @juspay/hyperswitch-routing crates/euclid @juspay/hyperswitch-routing crates/euclid_macros @juspay/hyperswitch-routing crates/euclid_wasm @juspay/hyperswitch-routing diff --git a/Cargo.lock b/Cargo.lock index 00c491d710..7fce1e7f53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2677,6 +2677,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "erased-serde" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c138974f9d5e7fe373eb04df7cae98833802ae4b11c24ac7039a21d5af4b26c" +dependencies = [ + "serde", +] + [[package]] name = "erased-serde" version = "0.4.4" @@ -2733,10 +2742,11 @@ version = "0.1.0" dependencies = [ "common_enums", "criterion", - "erased-serde", + "erased-serde 0.4.4", "euclid_macros", "frunk", "frunk_core", + "hyperswitch_constraint_graph", "nom", "once_cell", "rustc-hash", @@ -2768,6 +2778,7 @@ dependencies = [ "currency_conversion", "euclid", "getrandom", + "hyperswitch_constraint_graph", "kgraph_utils", "once_cell", "ron-parser", @@ -3602,6 +3613,18 @@ dependencies = [ "tokio 1.37.0", ] +[[package]] +name = "hyperswitch_constraint_graph" +version = "0.1.0" +dependencies = [ + "erased-serde 0.3.31", + "rustc-hash", + "serde", + "serde_json", + "strum 0.25.0", + "thiserror", +] + [[package]] name = "hyperswitch_domain_models" version = "0.1.0" @@ -3911,6 +3934,7 @@ dependencies = [ "common_enums", "criterion", "euclid", + "hyperswitch_constraint_graph", "masking", "serde", "serde_json", @@ -4067,7 +4091,7 @@ version = "0.1.0" dependencies = [ "bytes 1.6.0", "diesel", - "erased-serde", + "erased-serde 0.4.4", "serde", "serde_json", "subtle", @@ -5599,7 +5623,7 @@ dependencies = [ "digest", "dyn-clone", "encoding_rs", - "erased-serde", + "erased-serde 0.4.4", "error-stack", "euclid", "events", @@ -5608,6 +5632,7 @@ dependencies = [ "hex", "http 0.2.12", "hyper 0.14.28", + "hyperswitch_constraint_graph", "hyperswitch_domain_models", "hyperswitch_interfaces", "image", @@ -6698,6 +6723,15 @@ dependencies = [ "strum_macros 0.24.3", ] +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros 0.25.3", +] + [[package]] name = "strum" version = "0.26.2" @@ -6720,6 +6754,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.57", +] + [[package]] name = "strum_macros" version = "0.26.2" diff --git a/crates/euclid/Cargo.toml b/crates/euclid/Cargo.toml index 7de2764552..3341746ab7 100644 --- a/crates/euclid/Cargo.toml +++ b/crates/euclid/Cargo.toml @@ -21,6 +21,7 @@ utoipa = { version = "4.2.0", features = ["preserve_order", "preserve_path_order # First party dependencies common_enums = { version = "0.1.0", path = "../common_enums" } +hyperswitch_constraint_graph = { version = "0.1.0", path = "../hyperswitch_constraint_graph" } euclid_macros = { version = "0.1.0", path = "../euclid_macros" } [features] diff --git a/crates/euclid/src/dssa/analyzer.rs b/crates/euclid/src/dssa/analyzer.rs index 4c615e7849..dc1da99a83 100644 --- a/crates/euclid/src/dssa/analyzer.rs +++ b/crates/euclid/src/dssa/analyzer.rs @@ -4,11 +4,15 @@ //! in the Euclid Rule DSL. These include standard control flow analyses like testing //! conflicting assertions, to Domain Specific Analyses making use of the //! [`Knowledge Graph Framework`](crate::dssa::graph). +use hyperswitch_constraint_graph::{ConstraintGraph, Memoization}; use rustc_hash::{FxHashMap, FxHashSet}; -use super::{graph::Memoization, types::EuclidAnalysable}; use crate::{ - dssa::{graph, state_machine, truth, types}, + dssa::{ + graph::CgraphExt, + state_machine, truth, + types::{self, EuclidAnalysable}, + }, frontend::{ ast, dir::{self, EuclidDirFilter}, @@ -203,12 +207,12 @@ fn perform_condition_analyses( fn perform_context_analyses( context: &types::ConjunctiveContext<'_>, - knowledge_graph: &graph::KnowledgeGraph<'_>, + knowledge_graph: &ConstraintGraph<'_, dir::DirValue>, ) -> Result<(), types::AnalysisError> { perform_condition_analyses(context)?; let mut memo = Memoization::new(); knowledge_graph - .perform_context_analysis(context, &mut memo) + .perform_context_analysis(context, &mut memo, None) .map_err(|err| types::AnalysisError { error_type: types::AnalysisErrorType::GraphAnalysis(err, memo), metadata: Default::default(), @@ -218,7 +222,7 @@ fn perform_context_analyses( pub fn analyze( program: ast::Program, - knowledge_graph: Option<&graph::KnowledgeGraph<'_>>, + knowledge_graph: Option<&ConstraintGraph<'_, dir::DirValue>>, ) -> Result, types::AnalysisError> { let dir_program = ast::lowering::lower_program(program)?; @@ -241,9 +245,14 @@ mod tests { use std::{ops::Deref, sync::Weak}; use euclid_macros::knowledge; + use hyperswitch_constraint_graph as cgraph; use super::*; - use crate::{dirval, types::DummyOutput}; + use crate::{ + dirval, + dssa::graph::{self, euclid_graph_prelude}, + types::DummyOutput, + }; #[test] fn test_conflicting_assertion_detection() { @@ -368,7 +377,7 @@ mod tests { #[test] fn test_negation_graph_analysis() { - let graph = knowledge! {crate + let graph = knowledge! { CaptureMethod(Automatic) ->> PaymentMethod(Card); }; @@ -410,18 +419,18 @@ mod tests { .deref() .clone() { - graph::AnalysisTrace::Value { predecessors, .. } => { - let _value = graph::NodeValue::Value(dir::DirValue::PaymentMethod( + cgraph::AnalysisTrace::Value { predecessors, .. } => { + let _value = cgraph::NodeValue::Value(dir::DirValue::PaymentMethod( dir::enums::PaymentMethod::Card, )); - let _relation = graph::Relation::Positive; + let _relation = cgraph::Relation::Positive; predecessors } _ => panic!("Expected Negation Trace for payment method = card"), }; let pred = match predecessor { - Some(graph::ValueTracePredecessor::Mandatory(predecessor)) => predecessor, + Some(cgraph::error::ValueTracePredecessor::Mandatory(predecessor)) => predecessor, _ => panic!("No predecessor found"), }; assert_eq!( @@ -433,11 +442,11 @@ mod tests { *Weak::upgrade(&pred) .expect("Expected Arc not found") .deref(), - graph::AnalysisTrace::Value { - value: graph::NodeValue::Value(dir::DirValue::CaptureMethod( + cgraph::AnalysisTrace::Value { + value: cgraph::NodeValue::Value(dir::DirValue::CaptureMethod( dir::enums::CaptureMethod::Automatic )), - relation: graph::Relation::Positive, + relation: cgraph::Relation::Positive, info: None, metadata: None, predecessors: None, diff --git a/crates/euclid/src/dssa/graph.rs b/crates/euclid/src/dssa/graph.rs index cb72cca044..526248a381 100644 --- a/crates/euclid/src/dssa/graph.rs +++ b/crates/euclid/src/dssa/graph.rs @@ -1,272 +1,53 @@ -use std::{ - fmt::Debug, - hash::Hash, - ops::{Deref, DerefMut}, - sync::{Arc, Weak}, -}; +use std::{fmt::Debug, sync::Weak}; -use erased_serde::{self, Serialize as ErasedSerialize}; +use hyperswitch_constraint_graph as cgraph; use rustc_hash::{FxHashMap, FxHashSet}; -use serde::Serialize; use crate::{ dssa::types, frontend::dir, types::{DataType, Metadata}, - utils, }; -#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash, strum::Display)] -pub enum Strength { - Weak, - Normal, - Strong, +pub mod euclid_graph_prelude { + pub use hyperswitch_constraint_graph as cgraph; + pub use rustc_hash::{FxHashMap, FxHashSet}; + + pub use crate::{ + dssa::graph::*, + frontend::dir::{enums::*, DirKey, DirKeyKind, DirValue}, + types::*, + }; } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::Display, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum Relation { - Positive, - Negative, -} +impl cgraph::KeyNode for dir::DirKey {} -impl From for bool { - fn from(value: Relation) -> Self { - matches!(value, Relation::Positive) +impl cgraph::ValueNode for dir::DirValue { + type Key = dir::DirKey; + + fn get_key(&self) -> Self::Key { + Self::get_key(self) } } -#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)] -pub struct NodeId(usize); - -impl utils::EntityId for NodeId { - #[inline] - fn get_id(&self) -> usize { - self.0 - } - - #[inline] - fn with_id(id: usize) -> Self { - Self(id) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct DomainInfo<'a> { - pub domain_identifier: DomainIdentifier<'a>, - pub domain_description: String, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct DomainIdentifier<'a>(&'a str); - -impl<'a> DomainIdentifier<'a> { - pub fn new(domain_identifier: &'a str) -> Self { - Self(domain_identifier) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct DomainId(usize); - -impl utils::EntityId for DomainId { - #[inline] - fn get_id(&self) -> usize { - self.0 - } - - #[inline] - fn with_id(id: usize) -> Self { - Self(id) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct EdgeId(usize); - -impl utils::EntityId for EdgeId { - #[inline] - fn get_id(&self) -> usize { - self.0 - } - - #[inline] - fn with_id(id: usize) -> Self { - Self(id) - } -} - -#[derive(Debug, Clone, Serialize)] -pub struct Memoization(FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>); - -impl Memoization { - pub fn new() -> Self { - Self(FxHashMap::default()) - } -} - -impl Default for Memoization { - #[inline] - fn default() -> Self { - Self::new() - } -} - -impl Deref for Memoization { - type Target = FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>; - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl DerefMut for Memoization { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} -#[derive(Debug, Clone)] -pub struct Edge { - pub strength: Strength, - pub relation: Relation, - pub pred: NodeId, - pub succ: NodeId, -} - -#[derive(Debug)] -pub struct Node { - pub node_type: NodeType, - pub preds: Vec, - pub succs: Vec, - pub domain_ids: Vec, -} - -impl Node { - fn new(node_type: NodeType, domain_ids: Vec) -> Self { - Self { - node_type, - preds: Vec::new(), - succs: Vec::new(), - domain_ids, - } - } -} - -pub trait KgraphMetadata: ErasedSerialize + std::any::Any + Sync + Send + Debug {} -erased_serde::serialize_trait_object!(KgraphMetadata); - -impl KgraphMetadata for M where M: ErasedSerialize + std::any::Any + Sync + Send + Debug {} - -#[derive(Debug)] -pub struct KnowledgeGraph<'a> { - domain: utils::DenseMap>, - nodes: utils::DenseMap, - edges: utils::DenseMap, - value_map: FxHashMap, - node_info: utils::DenseMap>, - node_metadata: utils::DenseMap>>, -} - -pub struct KnowledgeGraphBuilder<'a> { - domain: utils::DenseMap>, - nodes: utils::DenseMap, - edges: utils::DenseMap, - domain_identifier_map: FxHashMap, DomainId>, - value_map: FxHashMap, - edges_map: FxHashMap<(NodeId, NodeId), EdgeId>, - node_info: utils::DenseMap>, - node_metadata: utils::DenseMap>>, -} - -impl<'a> Default for KnowledgeGraphBuilder<'a> { - #[inline] - fn default() -> Self { - Self::new() - } -} - -#[derive(Debug, PartialEq, Eq)] -pub enum NodeType { - AllAggregator, - AnyAggregator, - InAggregator(FxHashSet), - Value(NodeValue), -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)] -#[serde(tag = "type", content = "value", rename_all = "snake_case")] -pub enum NodeValue { - Key(dir::DirKey), - Value(dir::DirValue), -} - -impl From for NodeValue { - fn from(value: dir::DirValue) -> Self { - Self::Value(value) - } -} - -impl From for NodeValue { - fn from(key: dir::DirKey) -> Self { - Self::Key(key) - } -} - -#[derive(Debug, Clone, serde::Serialize)] -#[serde(tag = "type", content = "predecessor", rename_all = "snake_case")] -pub enum ValueTracePredecessor { - Mandatory(Box>), - OneOf(Vec>), -} - -#[derive(Debug, Clone, serde::Serialize)] -#[serde(tag = "type", content = "trace", rename_all = "snake_case")] -pub enum AnalysisTrace { - Value { - value: NodeValue, - relation: Relation, - predecessors: Option, - info: Option<&'static str>, - metadata: Option>, - }, - - AllAggregation { - unsatisfied: Vec>, - info: Option<&'static str>, - metadata: Option>, - }, - - AnyAggregation { - unsatisfied: Vec>, - info: Option<&'static str>, - metadata: Option>, - }, - - InAggregation { - expected: Vec, - found: Option, - relation: Relation, - info: Option<&'static str>, - metadata: Option>, - }, -} - #[derive(Debug, Clone, serde::Serialize)] #[serde(tag = "type", content = "details", rename_all = "snake_case")] -pub enum AnalysisError { - Graph(GraphError), +pub enum AnalysisError { + Graph(cgraph::GraphError), AssertionTrace { - trace: Weak, + trace: Weak>, metadata: Metadata, }, NegationTrace { - trace: Weak, + trace: Weak>, metadata: Vec, }, } -impl AnalysisError { - fn assertion_from_graph_error(metadata: &Metadata, graph_error: GraphError) -> Self { +impl AnalysisError { + fn assertion_from_graph_error(metadata: &Metadata, graph_error: cgraph::GraphError) -> Self { match graph_error { - GraphError::AnalysisError(trace) => Self::AssertionTrace { + cgraph::GraphError::AnalysisError(trace) => Self::AssertionTrace { trace, metadata: metadata.clone(), }, @@ -275,9 +56,12 @@ impl AnalysisError { } } - fn negation_from_graph_error(metadata: Vec<&Metadata>, graph_error: GraphError) -> Self { + fn negation_from_graph_error( + metadata: Vec<&Metadata>, + graph_error: cgraph::GraphError, + ) -> Self { match graph_error { - GraphError::AnalysisError(trace) => Self::NegationTrace { + cgraph::GraphError::AnalysisError(trace) => Self::NegationTrace { trace, metadata: metadata.iter().map(|m| (*m).clone()).collect(), }, @@ -287,56 +71,6 @@ impl AnalysisError { } } -#[derive(Debug, Clone, serde::Serialize, thiserror::Error)] -#[serde(tag = "type", content = "info", rename_all = "snake_case")] -pub enum GraphError { - #[error("An edge was not found in the graph")] - EdgeNotFound, - #[error("Attempted to create a conflicting edge between two nodes")] - ConflictingEdgeCreated, - #[error("Cycle detected in graph")] - CycleDetected, - #[error("Domain wasn't found in the Graph")] - DomainNotFound, - #[error("Malformed Graph: {reason}")] - MalformedGraph { reason: String }, - #[error("A node was not found in the graph")] - NodeNotFound, - #[error("A value node was not found: {0:#?}")] - ValueNodeNotFound(dir::DirValue), - #[error("No values provided for an 'in' aggregator node")] - NoInAggregatorValues, - #[error("Error during analysis: {0:#?}")] - AnalysisError(Weak), -} - -impl GraphError { - fn get_analysis_trace(self) -> Result, Self> { - match self { - Self::AnalysisError(trace) => Ok(trace), - _ => Err(self), - } - } -} - -impl PartialEq for NodeValue { - fn eq(&self, other: &dir::DirValue) -> bool { - match self { - Self::Key(dir_key) => *dir_key == other.get_key(), - Self::Value(dir_value) if dir_value.get_key() == other.get_key() => { - if let (Some(left), Some(right)) = - (dir_value.get_num_value(), other.get_num_value()) - { - left.fits(&right) - } else { - dir::DirValue::check_equality(dir_value, other) - } - } - Self::Value(_) => false, - } - } -} - pub struct AnalysisContext { keywise_values: FxHashMap>, } @@ -355,33 +89,6 @@ impl AnalysisContext { Self { keywise_values } } - fn check_presence(&self, value: &NodeValue, weak: bool) -> bool { - match value { - NodeValue::Key(k) => self.keywise_values.contains_key(k) || weak, - NodeValue::Value(val) => { - let key = val.get_key(); - let value_set = if let Some(set) = self.keywise_values.get(&key) { - set - } else { - return weak; - }; - - match key.kind.get_type() { - DataType::EnumVariant | DataType::StrValue | DataType::MetadataValue => { - value_set.contains(val) - } - DataType::Number => val.get_num_value().map_or(false, |num_val| { - value_set.iter().any(|ctx_val| { - ctx_val - .get_num_value() - .map_or(false, |ctx_num_val| num_val.fits(&ctx_num_val)) - }) - }), - } - } - } - } - pub fn insert(&mut self, value: dir::DirValue) { self.keywise_values .entry(value.get_key()) @@ -400,477 +107,153 @@ impl AnalysisContext { } } -impl<'a> KnowledgeGraphBuilder<'a> { - pub fn new() -> Self { - Self { - domain: utils::DenseMap::new(), - nodes: utils::DenseMap::new(), - edges: utils::DenseMap::new(), - domain_identifier_map: FxHashMap::default(), - value_map: FxHashMap::default(), - edges_map: FxHashMap::default(), - node_info: utils::DenseMap::new(), - node_metadata: utils::DenseMap::new(), +impl cgraph::CheckingContext for AnalysisContext { + type Value = dir::DirValue; + + fn from_node_values(vals: impl IntoIterator) -> Self + where + L: Into, + { + let mut keywise_values: FxHashMap> = + FxHashMap::default(); + + for dir_val in vals.into_iter().map(L::into) { + let key = dir_val.get_key(); + let set = keywise_values.entry(key).or_default(); + set.insert(dir_val); } + + Self { keywise_values } } - pub fn build(self) -> KnowledgeGraph<'a> { - KnowledgeGraph { - domain: self.domain, - nodes: self.nodes, - edges: self.edges, - value_map: self.value_map, - node_info: self.node_info, - node_metadata: self.node_metadata, - } - } - - pub fn make_domain( - &mut self, - domain_identifier: DomainIdentifier<'a>, - domain_description: String, - ) -> Result { - Ok(self - .domain_identifier_map - .clone() - .get(&domain_identifier) - .map_or_else( - || { - let domain_id = self.domain.push(DomainInfo { - domain_identifier: domain_identifier.clone(), - domain_description, - }); - self.domain_identifier_map - .insert(domain_identifier.clone(), domain_id); - domain_id - }, - |domain_id| *domain_id, - )) - } - - pub fn make_value_node( - &mut self, - value: NodeValue, - info: Option<&'static str>, - domain_identifiers: Vec>, - metadata: Option, - ) -> Result { - match self.value_map.get(&value).copied() { - Some(node_id) => Ok(node_id), - None => { - let mut domain_ids: Vec = Vec::new(); - domain_identifiers - .iter() - .try_for_each(|ident| { - self.domain_identifier_map - .get(ident) - .map(|id| domain_ids.push(*id)) - }) - .ok_or(GraphError::DomainNotFound)?; - - let node_id = self - .nodes - .push(Node::new(NodeType::Value(value.clone()), domain_ids)); - let _node_info_id = self.node_info.push(info); - - let _node_metadata_id = self - .node_metadata - .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); - - self.value_map.insert(value, node_id); - Ok(node_id) + fn check_presence( + &self, + value: &cgraph::NodeValue, + strength: cgraph::Strength, + ) -> bool { + match value { + cgraph::NodeValue::Key(k) => { + self.keywise_values.contains_key(k) || matches!(strength, cgraph::Strength::Weak) } - } - } - pub fn make_edge( - &mut self, - pred_id: NodeId, - succ_id: NodeId, - strength: Strength, - relation: Relation, - ) -> Result { - self.ensure_node_exists(pred_id)?; - self.ensure_node_exists(succ_id)?; - self.edges_map - .get(&(pred_id, succ_id)) - .copied() - .and_then(|edge_id| self.edges.get(edge_id).cloned().map(|edge| (edge_id, edge))) - .map_or_else( - || { - let edge_id = self.edges.push(Edge { - strength, - relation, - pred: pred_id, - succ: succ_id, - }); - self.edges_map.insert((pred_id, succ_id), edge_id); + cgraph::NodeValue::Value(val) => { + let key = val.get_key(); + let value_set = if let Some(set) = self.keywise_values.get(&key) { + set + } else { + return matches!(strength, cgraph::Strength::Weak); + }; - let pred = self - .nodes - .get_mut(pred_id) - .ok_or(GraphError::NodeNotFound)?; - pred.succs.push(edge_id); - - let succ = self - .nodes - .get_mut(succ_id) - .ok_or(GraphError::NodeNotFound)?; - succ.preds.push(edge_id); - - Ok(edge_id) - }, - |(edge_id, edge)| { - if edge.strength == strength && edge.relation == relation { - Ok(edge_id) - } else { - Err(GraphError::ConflictingEdgeCreated) + match key.kind.get_type() { + DataType::EnumVariant | DataType::StrValue | DataType::MetadataValue => { + value_set.contains(val) } - }, - ) - } - - pub fn make_all_aggregator( - &mut self, - nodes: &[(NodeId, Relation, Strength)], - info: Option<&'static str>, - metadata: Option, - domain: Vec>, - ) -> Result { - nodes - .iter() - .try_for_each(|(node_id, _, _)| self.ensure_node_exists(*node_id))?; - - let mut domain_ids: Vec = Vec::new(); - domain - .iter() - .try_for_each(|ident| { - self.domain_identifier_map - .get(ident) - .map(|id| domain_ids.push(*id)) - }) - .ok_or(GraphError::DomainNotFound)?; - - let aggregator_id = self - .nodes - .push(Node::new(NodeType::AllAggregator, domain_ids)); - let _aggregator_info_id = self.node_info.push(info); - - let _node_metadata_id = self - .node_metadata - .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); - - for (node_id, relation, strength) in nodes { - self.make_edge(*node_id, aggregator_id, *strength, *relation)?; - } - - Ok(aggregator_id) - } - - pub fn make_any_aggregator( - &mut self, - nodes: &[(NodeId, Relation)], - info: Option<&'static str>, - metadata: Option, - domain: Vec>, - ) -> Result { - nodes - .iter() - .try_for_each(|(node_id, _)| self.ensure_node_exists(*node_id))?; - - let mut domain_ids: Vec = Vec::new(); - domain - .iter() - .try_for_each(|ident| { - self.domain_identifier_map - .get(ident) - .map(|id| domain_ids.push(*id)) - }) - .ok_or(GraphError::DomainNotFound)?; - - let aggregator_id = self - .nodes - .push(Node::new(NodeType::AnyAggregator, domain_ids)); - let _aggregator_info_id = self.node_info.push(info); - - let _node_metadata_id = self - .node_metadata - .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); - - for (node_id, relation) in nodes { - self.make_edge(*node_id, aggregator_id, Strength::Strong, *relation)?; - } - - Ok(aggregator_id) - } - - pub fn make_in_aggregator( - &mut self, - values: Vec, - info: Option<&'static str>, - metadata: Option, - domain: Vec>, - ) -> Result { - let key = values - .first() - .ok_or(GraphError::NoInAggregatorValues)? - .get_key(); - - for val in &values { - if val.get_key() != key { - Err(GraphError::MalformedGraph { - reason: "Values for 'In' aggregator not of same key".to_string(), - })?; + DataType::Number => val.get_num_value().map_or(false, |num_val| { + value_set.iter().any(|ctx_val| { + ctx_val + .get_num_value() + .map_or(false, |ctx_num_val| num_val.fits(&ctx_num_val)) + }) + }), + } } } - - let mut domain_ids: Vec = Vec::new(); - domain - .iter() - .try_for_each(|ident| { - self.domain_identifier_map - .get(ident) - .map(|id| domain_ids.push(*id)) - }) - .ok_or(GraphError::DomainNotFound)?; - - let node_id = self.nodes.push(Node::new( - NodeType::InAggregator(FxHashSet::from_iter(values)), - domain_ids, - )); - let _aggregator_info_id = self.node_info.push(info); - - let _node_metadata_id = self - .node_metadata - .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); - - Ok(node_id) } - fn ensure_node_exists(&self, id: NodeId) -> Result<(), GraphError> { - if self.nodes.contains_key(id) { - Ok(()) - } else { - Err(GraphError::NodeNotFound) - } + fn get_values_by_key( + &self, + key: &::Key, + ) -> Option> { + self.keywise_values + .get(key) + .map(|set| set.iter().cloned().collect()) } } -impl<'a> KnowledgeGraph<'a> { - fn check_node( - &self, - ctx: &AnalysisContext, - node_id: NodeId, - relation: Relation, - strength: Strength, - memo: &mut Memoization, - ) -> Result<(), GraphError> { - let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; - if let Some(already_memo) = memo.get(&(node_id, relation, strength)) { - already_memo - .clone() - .map_err(|err| GraphError::AnalysisError(Arc::downgrade(&err))) - } else { - match &node.node_type { - NodeType::AllAggregator => { - let mut unsatisfied = Vec::>::new(); - - for edge_id in node.preds.iter().copied() { - let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; - - if let Err(e) = - self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo) - { - unsatisfied.push(e.get_analysis_trace()?); - } - } - - if !unsatisfied.is_empty() { - let err = Arc::new(AnalysisTrace::AllAggregation { - unsatisfied, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err))) - } else { - memo.insert((node_id, relation, strength), Ok(())); - Ok(()) - } - } - - NodeType::AnyAggregator => { - let mut unsatisfied = Vec::>::new(); - let mut matched_one = false; - - for edge_id in node.preds.iter().copied() { - let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; - - if let Err(e) = - self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo) - { - unsatisfied.push(e.get_analysis_trace()?); - } else { - matched_one = true; - } - } - - if matched_one || node.preds.is_empty() { - memo.insert((node_id, relation, strength), Ok(())); - Ok(()) - } else { - let err = Arc::new(AnalysisTrace::AnyAggregation { - unsatisfied: unsatisfied.clone(), - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err))) - } - } - - NodeType::InAggregator(expected) => { - let the_key = expected - .iter() - .next() - .ok_or_else(|| GraphError::MalformedGraph { - reason: - "An OnlyIn aggregator node must have at least one expected value" - .to_string(), - })? - .get_key(); - - let ctx_vals = if let Some(vals) = ctx.keywise_values.get(&the_key) { - vals - } else { - return if let Strength::Weak = strength { - memo.insert((node_id, relation, strength), Ok(())); - Ok(()) - } else { - let err = Arc::new(AnalysisTrace::InAggregation { - expected: expected.iter().cloned().collect(), - found: None, - relation, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err))) - }; - }; - - let relation_bool: bool = relation.into(); - for ctx_value in ctx_vals { - if expected.contains(ctx_value) != relation_bool { - let err = Arc::new(AnalysisTrace::InAggregation { - expected: expected.iter().cloned().collect(), - found: Some(ctx_value.clone()), - relation, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; - } - } - - memo.insert((node_id, relation, strength), Ok(())); - Ok(()) - } - - NodeType::Value(val) => { - let in_context = ctx.check_presence(val, matches!(strength, Strength::Weak)); - let relation_bool: bool = relation.into(); - - if in_context != relation_bool { - let err = Arc::new(AnalysisTrace::Value { - value: val.clone(), - relation, - predecessors: None, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; - } - - if !relation_bool { - memo.insert((node_id, relation, strength), Ok(())); - return Ok(()); - } - - let mut errors = Vec::>::new(); - let mut matched_one = false; - - for edge_id in node.preds.iter().copied() { - let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; - let result = - self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo); - - match (edge.strength, result) { - (Strength::Strong, Err(trace)) => { - let err = Arc::new(AnalysisTrace::Value { - value: val.clone(), - relation, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - predecessors: Some(ValueTracePredecessor::Mandatory(Box::new( - trace.get_analysis_trace()?, - ))), - }); - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; - } - - (Strength::Strong, Ok(_)) => { - matched_one = true; - } - - (Strength::Normal | Strength::Weak, Err(trace)) => { - errors.push(trace.get_analysis_trace()?); - } - - (Strength::Normal | Strength::Weak, Ok(_)) => { - matched_one = true; - } - } - } - - if matched_one || node.preds.is_empty() { - memo.insert((node_id, relation, strength), Ok(())); - Ok(()) - } else { - let err = Arc::new(AnalysisTrace::Value { - value: val.clone(), - relation, - info: self.node_info.get(node_id).cloned().flatten(), - metadata: self.node_metadata.get(node_id).cloned().flatten(), - predecessors: Some(ValueTracePredecessor::OneOf(errors.clone())), - }); - - memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); - Err(GraphError::AnalysisError(Arc::downgrade(&err))) - } - } - } - } - } - +pub trait CgraphExt { fn key_analysis( &self, key: dir::DirKey, ctx: &AnalysisContext, - memo: &mut Memoization, - ) -> Result<(), GraphError> { + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError>; + + fn value_analysis( + &self, + val: dir::DirValue, + ctx: &AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError>; + + fn check_value_validity( + &self, + val: dir::DirValue, + analysis_ctx: &AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result>; + + fn key_value_analysis( + &self, + val: dir::DirValue, + ctx: &AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError>; + + fn assertion_analysis( + &self, + positive_ctx: &[(&dir::DirValue, &Metadata)], + analysis_ctx: &AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError>; + + fn negation_analysis( + &self, + negative_ctx: &[(&[dir::DirValue], &Metadata)], + analysis_ctx: &mut AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError>; + + fn perform_context_analysis( + &self, + ctx: &types::ConjunctiveContext<'_>, + memo: &mut cgraph::Memoization, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError>; +} + +impl CgraphExt for cgraph::ConstraintGraph<'_, dir::DirValue> { + fn key_analysis( + &self, + key: dir::DirKey, + ctx: &AnalysisContext, + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError> { self.value_map - .get(&NodeValue::Key(key)) + .get(&cgraph::NodeValue::Key(key)) .map_or(Ok(()), |node_id| { - self.check_node(ctx, *node_id, Relation::Positive, Strength::Strong, memo) + self.check_node( + ctx, + *node_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + memo, + cycle_map, + domains, + ) }) } @@ -878,22 +261,34 @@ impl<'a> KnowledgeGraph<'a> { &self, val: dir::DirValue, ctx: &AnalysisContext, - memo: &mut Memoization, - ) -> Result<(), GraphError> { + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError> { self.value_map - .get(&NodeValue::Value(val)) + .get(&cgraph::NodeValue::Value(val)) .map_or(Ok(()), |node_id| { - self.check_node(ctx, *node_id, Relation::Positive, Strength::Strong, memo) + self.check_node( + ctx, + *node_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + memo, + cycle_map, + domains, + ) }) } - pub fn check_value_validity( + fn check_value_validity( &self, val: dir::DirValue, analysis_ctx: &AnalysisContext, - memo: &mut Memoization, - ) -> Result { - let maybe_node_id = self.value_map.get(&NodeValue::Value(val)); + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result> { + let maybe_node_id = self.value_map.get(&cgraph::NodeValue::Value(val)); let node_id = if let Some(nid) = maybe_node_id { nid @@ -904,9 +299,11 @@ impl<'a> KnowledgeGraph<'a> { let result = self.check_node( analysis_ctx, *node_id, - Relation::Positive, - Strength::Weak, + cgraph::Relation::Positive, + cgraph::Strength::Weak, memo, + cycle_map, + domains, ); match result { @@ -918,24 +315,28 @@ impl<'a> KnowledgeGraph<'a> { } } - pub fn key_value_analysis( + fn key_value_analysis( &self, val: dir::DirValue, ctx: &AnalysisContext, - memo: &mut Memoization, - ) -> Result<(), GraphError> { - self.key_analysis(val.get_key(), ctx, memo) - .and_then(|_| self.value_analysis(val, ctx, memo)) + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), cgraph::GraphError> { + self.key_analysis(val.get_key(), ctx, memo, cycle_map, domains) + .and_then(|_| self.value_analysis(val, ctx, memo, cycle_map, domains)) } fn assertion_analysis( &self, positive_ctx: &[(&dir::DirValue, &Metadata)], analysis_ctx: &AnalysisContext, - memo: &mut Memoization, - ) -> Result<(), AnalysisError> { + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError> { positive_ctx.iter().try_for_each(|(value, metadata)| { - self.key_value_analysis((*value).clone(), analysis_ctx, memo) + self.key_value_analysis((*value).clone(), analysis_ctx, memo, cycle_map, domains) .map_err(|e| AnalysisError::assertion_from_graph_error(metadata, e)) }) } @@ -944,8 +345,10 @@ impl<'a> KnowledgeGraph<'a> { &self, negative_ctx: &[(&[dir::DirValue], &Metadata)], analysis_ctx: &mut AnalysisContext, - memo: &mut Memoization, - ) -> Result<(), AnalysisError> { + memo: &mut cgraph::Memoization, + cycle_map: &mut cgraph::CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError> { let mut keywise_metadata: FxHashMap> = FxHashMap::default(); let mut keywise_negation: FxHashMap> = FxHashMap::default(); @@ -974,7 +377,7 @@ impl<'a> KnowledgeGraph<'a> { let all_metadata = keywise_metadata.remove(&key).unwrap_or_default(); let first_metadata = all_metadata.first().cloned().cloned().unwrap_or_default(); - self.key_analysis(key.clone(), analysis_ctx, memo) + self.key_analysis(key.clone(), analysis_ctx, memo, cycle_map, domains) .map_err(|e| AnalysisError::assertion_from_graph_error(&first_metadata, e))?; let mut value_set = if let Some(set) = key.kind.get_value_set() { @@ -987,7 +390,7 @@ impl<'a> KnowledgeGraph<'a> { for value in value_set { analysis_ctx.insert(value.clone()); - self.value_analysis(value.clone(), analysis_ctx, memo) + self.value_analysis(value.clone(), analysis_ctx, memo, cycle_map, domains) .map_err(|e| { AnalysisError::negation_from_graph_error(all_metadata.clone(), e) })?; @@ -998,11 +401,12 @@ impl<'a> KnowledgeGraph<'a> { Ok(()) } - pub fn perform_context_analysis( + fn perform_context_analysis( &self, ctx: &types::ConjunctiveContext<'_>, - memo: &mut Memoization, - ) -> Result<(), AnalysisError> { + memo: &mut cgraph::Memoization, + domains: Option<&[&str]>, + ) -> Result<(), AnalysisError> { let mut analysis_ctx = AnalysisContext::from_dir_values( ctx.iter() .filter_map(|ctx_val| ctx_val.value.get_assertion().cloned()), @@ -1017,7 +421,13 @@ impl<'a> KnowledgeGraph<'a> { .map(|val| (val, ctx_val.metadata)) }) .collect::>(); - self.assertion_analysis(&positive_ctx, &analysis_ctx, memo)?; + self.assertion_analysis( + &positive_ctx, + &analysis_ctx, + memo, + &mut cgraph::CycleCheck::new(), + domains, + )?; let negative_ctx = ctx .iter() @@ -1028,127 +438,38 @@ impl<'a> KnowledgeGraph<'a> { .map(|vals| (vals, ctx_val.metadata)) }) .collect::>(); - self.negation_analysis(&negative_ctx, &mut analysis_ctx, memo)?; + self.negation_analysis( + &negative_ctx, + &mut analysis_ctx, + memo, + &mut cgraph::CycleCheck::new(), + domains, + )?; Ok(()) } - - pub fn combine<'b>(g1: &'b Self, g2: &'b Self) -> Result { - let mut node_builder = KnowledgeGraphBuilder::new(); - let mut g1_old2new_id = utils::DenseMap::::new(); - let mut g2_old2new_id = utils::DenseMap::::new(); - let mut g1_old2new_domain_id = utils::DenseMap::::new(); - let mut g2_old2new_domain_id = utils::DenseMap::::new(); - - let add_domain = |node_builder: &mut KnowledgeGraphBuilder<'a>, - domain: DomainInfo<'a>| - -> Result { - node_builder.make_domain(domain.domain_identifier, domain.domain_description) - }; - - let add_node = |node_builder: &mut KnowledgeGraphBuilder<'a>, - node: &Node, - domains: Vec>| - -> Result { - match &node.node_type { - NodeType::Value(node_value) => { - node_builder.make_value_node(node_value.clone(), None, domains, None::<()>) - } - - NodeType::AllAggregator => { - Ok(node_builder.make_all_aggregator(&[], None, None::<()>, domains)?) - } - - NodeType::AnyAggregator => { - Ok(node_builder.make_any_aggregator(&[], None, None::<()>, Vec::new())?) - } - - NodeType::InAggregator(expected) => Ok(node_builder.make_in_aggregator( - expected.iter().cloned().collect(), - None, - None::<()>, - Vec::new(), - )?), - } - }; - - for (_old_domain_id, domain) in g1.domain.iter() { - let new_domain_id = add_domain(&mut node_builder, domain.clone())?; - g1_old2new_domain_id.push(new_domain_id); - } - - for (_old_domain_id, domain) in g2.domain.iter() { - let new_domain_id = add_domain(&mut node_builder, domain.clone())?; - g2_old2new_domain_id.push(new_domain_id); - } - - for (_old_node_id, node) in g1.nodes.iter() { - let mut domain_identifiers: Vec> = Vec::new(); - for domain_id in &node.domain_ids { - match g1.domain.get(*domain_id) { - Some(domain) => domain_identifiers.push(domain.domain_identifier.clone()), - None => return Err(GraphError::DomainNotFound), - } - } - let new_node_id = add_node(&mut node_builder, node, domain_identifiers.clone())?; - g1_old2new_id.push(new_node_id); - } - - for (_old_node_id, node) in g2.nodes.iter() { - let mut domain_identifiers: Vec> = Vec::new(); - for domain_id in &node.domain_ids { - match g2.domain.get(*domain_id) { - Some(domain) => domain_identifiers.push(domain.domain_identifier.clone()), - None => return Err(GraphError::DomainNotFound), - } - } - let new_node_id = add_node(&mut node_builder, node, domain_identifiers.clone())?; - g2_old2new_id.push(new_node_id); - } - - for edge in g1.edges.values() { - let new_pred_id = g1_old2new_id - .get(edge.pred) - .ok_or(GraphError::NodeNotFound)?; - let new_succ_id = g1_old2new_id - .get(edge.succ) - .ok_or(GraphError::NodeNotFound)?; - - node_builder.make_edge(*new_pred_id, *new_succ_id, edge.strength, edge.relation)?; - } - - for edge in g2.edges.values() { - let new_pred_id = g2_old2new_id - .get(edge.pred) - .ok_or(GraphError::NodeNotFound)?; - let new_succ_id = g2_old2new_id - .get(edge.succ) - .ok_or(GraphError::NodeNotFound)?; - - node_builder.make_edge(*new_pred_id, *new_succ_id, edge.strength, edge.relation)?; - } - - Ok(node_builder.build()) - } } #[cfg(test)] mod test { #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] + use std::ops::Deref; + use euclid_macros::knowledge; + use hyperswitch_constraint_graph::CycleCheck; use super::*; use crate::{dirval, frontend::dir::enums}; #[test] fn test_strong_positive_relation_success() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) ->> CaptureMethod(Automatic); PaymentMethod(not Wallet) & PaymentMethod(not PayLater) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1156,6 +477,8 @@ mod test { dirval!(PaymentMethod = Card), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -1163,15 +486,17 @@ mod test { #[test] fn test_strong_positive_relation_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) ->> CaptureMethod(Automatic); PaymentMethod(not Wallet) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([dirval!(CaptureMethod = Automatic)]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -1179,11 +504,11 @@ mod test { #[test] fn test_strong_negative_relation_success() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) -> CaptureMethod(Automatic); PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1191,6 +516,8 @@ mod test { dirval!(PaymentMethod = Card), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -1198,11 +525,11 @@ mod test { #[test] fn test_strong_negative_relation_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) -> CaptureMethod(Automatic); PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1210,6 +537,8 @@ mod test { dirval!(PaymentMethod = Wallet), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -1217,11 +546,11 @@ mod test { #[test] fn test_normal_one_of_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) -> CaptureMethod(Automatic); PaymentMethod(Wallet) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1229,12 +558,14 @@ mod test { dirval!(PaymentMethod = PayLater), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(matches!( *Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) .expect("Expected Arc"), - AnalysisTrace::Value { - predecessors: Some(ValueTracePredecessor::OneOf(_)), + cgraph::AnalysisTrace::Value { + predecessors: Some(cgraph::error::ValueTracePredecessor::OneOf(_)), .. } )); @@ -1242,10 +573,10 @@ mod test { #[test] fn test_all_aggregator_success() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) & PaymentMethod(not Wallet) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1253,6 +584,8 @@ mod test { dirval!(CaptureMethod = Automatic), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -1260,10 +593,10 @@ mod test { #[test] fn test_all_aggregator_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) & PaymentMethod(not Wallet) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1271,6 +604,8 @@ mod test { dirval!(PaymentMethod = PayLater), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -1278,10 +613,10 @@ mod test { #[test] fn test_all_aggregator_mandatory_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(Card) & PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); }; - let mut memo = Memoization::new(); + let mut memo = cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1289,13 +624,15 @@ mod test { dirval!(PaymentMethod = PayLater), ]), &mut memo, + &mut CycleCheck::new(), + None, ); assert!(matches!( *Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) .expect("Expected Arc"), - AnalysisTrace::Value { - predecessors: Some(ValueTracePredecessor::Mandatory(_)), + cgraph::AnalysisTrace::Value { + predecessors: Some(cgraph::error::ValueTracePredecessor::Mandatory(_)), .. } )); @@ -1303,10 +640,10 @@ mod test { #[test] fn test_in_aggregator_success() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(in [Card, Wallet]) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1315,6 +652,8 @@ mod test { dirval!(PaymentMethod = Wallet), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -1322,10 +661,10 @@ mod test { #[test] fn test_in_aggregator_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(in [Card, Wallet]) -> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1335,6 +674,8 @@ mod test { dirval!(PaymentMethod = PayLater), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -1342,10 +683,10 @@ mod test { #[test] fn test_not_in_aggregator_success() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(not in [Card, Wallet]) ->> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1354,6 +695,8 @@ mod test { dirval!(PaymentMethod = BankRedirect), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -1361,10 +704,10 @@ mod test { #[test] fn test_not_in_aggregator_failure() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(not in [Card, Wallet]) ->> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1374,6 +717,8 @@ mod test { dirval!(PaymentMethod = Card), ]), memo, + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -1381,10 +726,10 @@ mod test { #[test] fn test_in_aggregator_failure_trace() { - let graph = knowledge! {crate + let graph = knowledge! { PaymentMethod(in [Card, Wallet]) ->> CaptureMethod(Automatic); }; - let memo = &mut Memoization::new(); + let memo = &mut cgraph::Memoization::new(); let result = graph.key_value_analysis( dirval!(CaptureMethod = Automatic), &AnalysisContext::from_dir_values([ @@ -1394,10 +739,12 @@ mod test { dirval!(PaymentMethod = PayLater), ]), memo, + &mut CycleCheck::new(), + None, ); - if let AnalysisTrace::Value { - predecessors: Some(ValueTracePredecessor::Mandatory(agg_error)), + if let cgraph::AnalysisTrace::Value { + predecessors: Some(cgraph::error::ValueTracePredecessor::Mandatory(agg_error)), .. } = Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) .expect("Expected arc") @@ -1405,7 +752,7 @@ mod test { { assert!(matches!( *Weak::upgrade(agg_error.deref()).expect("Expected Arc"), - AnalysisTrace::InAggregation { + cgraph::AnalysisTrace::InAggregation { found: Some(dir::DirValue::PaymentMethod(enums::PaymentMethod::PayLater)), .. } @@ -1416,43 +763,43 @@ mod test { } #[test] - fn _test_memoization_in_kgraph() { - let mut builder = KnowledgeGraphBuilder::new(); + fn test_memoization_in_kgraph() { + let mut builder = cgraph::ConstraintGraphBuilder::new(); let _node_1 = builder.make_value_node( - NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), None, - Vec::new(), None::<()>, ); let _node_2 = builder.make_value_node( - NodeValue::Value(dir::DirValue::BillingCountry(enums::BillingCountry::India)), + cgraph::NodeValue::Value(dir::DirValue::BillingCountry(enums::BillingCountry::India)), None, - Vec::new(), None::<()>, ); let _node_3 = builder.make_value_node( - NodeValue::Value(dir::DirValue::BusinessCountry( + cgraph::NodeValue::Value(dir::DirValue::BusinessCountry( enums::BusinessCountry::UnitedStatesOfAmerica, )), None, - Vec::new(), None::<()>, ); - let mut memo = Memoization::new(); + let mut memo = cgraph::Memoization::new(); + let mut cycle_map = CycleCheck::new(); let _edge_1 = builder .make_edge( - _node_1.expect("node1 constructtion failed"), - _node_2.clone().expect("node2 construction failed"), - Strength::Strong, - Relation::Positive, + _node_1, + _node_2, + cgraph::Strength::Strong, + cgraph::Relation::Positive, + None::, ) .expect("Failed to make an edge"); let _edge_2 = builder .make_edge( - _node_2.expect("node2 construction failed"), - _node_3.clone().expect("node3 construction failed"), - Strength::Strong, - Relation::Positive, + _node_2, + _node_3, + cgraph::Strength::Strong, + cgraph::Relation::Positive, + None::, ) .expect("Failed to an edge"); let graph = builder.build(); @@ -1464,15 +811,263 @@ mod test { dirval!(BusinessCountry = UnitedStatesOfAmerica), ]), &mut memo, + &mut cycle_map, + None, ); let _answer = memo - .0 .get(&( - _node_3.expect("node3 construction failed"), - Relation::Positive, - Strength::Strong, + _node_3, + cgraph::Relation::Positive, + cgraph::Strength::Strong, )) .expect("Memoization not workng"); matches!(_answer, Ok(())); } + + #[test] + fn test_cycle_resolution_in_graph() { + let mut builder = cgraph::ConstraintGraphBuilder::new(); + let _node_1 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), + None, + None::<()>, + ); + let _node_2 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Card)), + None, + None::<()>, + ); + let mut memo = cgraph::Memoization::new(); + let mut cycle_map = cgraph::CycleCheck::new(); + let _edge_1 = builder + .make_edge( + _node_1, + _node_2, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_2 = builder + .make_edge( + _node_2, + _node_1, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to an edge"); + let graph = builder.build(); + let _result = graph.key_value_analysis( + dirval!(PaymentMethod = Wallet), + &AnalysisContext::from_dir_values([ + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = Card), + ]), + &mut memo, + &mut cycle_map, + None, + ); + + assert!(_result.is_ok()); + } + + #[test] + fn test_cycle_resolution_in_graph1() { + let mut builder = cgraph::ConstraintGraphBuilder::new(); + let _node_1 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::CaptureMethod( + enums::CaptureMethod::Automatic, + )), + None, + None::<()>, + ); + + let _node_2 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Card)), + None, + None::<()>, + ); + let _node_3 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), + None, + None::<()>, + ); + let mut memo = cgraph::Memoization::new(); + let mut cycle_map = cgraph::CycleCheck::new(); + + let _edge_1 = builder + .make_edge( + _node_1, + _node_2, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_2 = builder + .make_edge( + _node_1, + _node_3, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_3 = builder + .make_edge( + _node_2, + _node_1, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_4 = builder + .make_edge( + _node_3, + _node_1, + cgraph::Strength::Strong, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + + let graph = builder.build(); + let _result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Wallet), + dirval!(CaptureMethod = Automatic), + ]), + &mut memo, + &mut cycle_map, + None, + ); + + assert!(_result.is_ok()); + } + + #[test] + fn test_cycle_resolution_in_graph2() { + let mut builder = cgraph::ConstraintGraphBuilder::new(); + let _node_0 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::BillingCountry( + enums::BillingCountry::Afghanistan, + )), + None, + None::<()>, + ); + + let _node_1 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::CaptureMethod( + enums::CaptureMethod::Automatic, + )), + None, + None::<()>, + ); + + let _node_2 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Card)), + None, + None::<()>, + ); + let _node_3 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), + None, + None::<()>, + ); + + let _node_4 = builder.make_value_node( + cgraph::NodeValue::Value(dir::DirValue::PaymentCurrency(enums::PaymentCurrency::USD)), + None, + None::<()>, + ); + + let mut memo = cgraph::Memoization::new(); + let mut cycle_map = cgraph::CycleCheck::new(); + + let _edge_1 = builder + .make_edge( + _node_0, + _node_1, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_2 = builder + .make_edge( + _node_1, + _node_2, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_3 = builder + .make_edge( + _node_1, + _node_3, + cgraph::Strength::Weak, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_4 = builder + .make_edge( + _node_3, + _node_4, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_5 = builder + .make_edge( + _node_2, + _node_4, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + + let _edge_6 = builder + .make_edge( + _node_4, + _node_1, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + let _edge_7 = builder + .make_edge( + _node_4, + _node_0, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, + ) + .expect("Failed to make an edge"); + + let graph = builder.build(); + let _result = graph.key_value_analysis( + dirval!(BillingCountry = Afghanistan), + &AnalysisContext::from_dir_values([ + dirval!(PaymentCurrency = USD), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Wallet), + dirval!(CaptureMethod = Automatic), + dirval!(BillingCountry = Afghanistan), + ]), + &mut memo, + &mut cycle_map, + None, + ); + + assert!(_result.is_ok()); + } } diff --git a/crates/euclid/src/dssa/truth.rs b/crates/euclid/src/dssa/truth.rs index 17e6e728e6..388180eed5 100644 --- a/crates/euclid/src/dssa/truth.rs +++ b/crates/euclid/src/dssa/truth.rs @@ -1,29 +1,30 @@ use euclid_macros::knowledge; use once_cell::sync::Lazy; -use crate::dssa::graph; +use crate::{dssa::graph::euclid_graph_prelude, frontend::dir}; -pub static ANALYSIS_GRAPH: Lazy> = Lazy::new(|| { - knowledge! {crate - // Payment Method should be `Card` for a CardType to be present - PaymentMethod(Card) ->> CardType(any); +pub static ANALYSIS_GRAPH: Lazy> = + Lazy::new(|| { + knowledge! { + // Payment Method should be `Card` for a CardType to be present + PaymentMethod(Card) ->> CardType(any); - // Payment Method should be `PayLater` for a PayLaterType to be present - PaymentMethod(PayLater) ->> PayLaterType(any); + // Payment Method should be `PayLater` for a PayLaterType to be present + PaymentMethod(PayLater) ->> PayLaterType(any); - // Payment Method should be `Wallet` for a WalletType to be present - PaymentMethod(Wallet) ->> WalletType(any); + // Payment Method should be `Wallet` for a WalletType to be present + PaymentMethod(Wallet) ->> WalletType(any); - // Payment Method should be `BankRedirect` for a BankRedirectType to - // be present - PaymentMethod(BankRedirect) ->> BankRedirectType(any); + // Payment Method should be `BankRedirect` for a BankRedirectType to + // be present + PaymentMethod(BankRedirect) ->> BankRedirectType(any); - // Payment Method should be `BankTransfer` for a BankTransferType to - // be present - PaymentMethod(BankTransfer) ->> BankTransferType(any); + // Payment Method should be `BankTransfer` for a BankTransferType to + // be present + PaymentMethod(BankTransfer) ->> BankTransferType(any); - // Payment Method should be `GiftCard` for a GiftCardType to - // be present - PaymentMethod(GiftCard) ->> GiftCardType(any); - } -}); + // Payment Method should be `GiftCard` for a GiftCardType to + // be present + PaymentMethod(GiftCard) ->> GiftCardType(any); + } + }); diff --git a/crates/euclid/src/dssa/types.rs b/crates/euclid/src/dssa/types.rs index 4070e0825e..df54de2dd9 100644 --- a/crates/euclid/src/dssa/types.rs +++ b/crates/euclid/src/dssa/types.rs @@ -140,7 +140,10 @@ pub enum AnalysisErrorType { negation_metadata: Metadata, }, #[error("Graph analysis error: {0:#?}")] - GraphAnalysis(graph::AnalysisError, graph::Memoization), + GraphAnalysis( + graph::AnalysisError, + hyperswitch_constraint_graph::Memoization, + ), #[error("State machine error")] StateMachine(dssa::state_machine::StateMachineError), #[error("Unsupported program key '{0}'")] diff --git a/crates/euclid/src/lib.rs b/crates/euclid/src/lib.rs index d64297437a..261b3dc02d 100644 --- a/crates/euclid/src/lib.rs +++ b/crates/euclid/src/lib.rs @@ -4,4 +4,3 @@ pub mod dssa; pub mod enums; pub mod frontend; pub mod types; -pub mod utils; diff --git a/crates/euclid/src/utils.rs b/crates/euclid/src/utils.rs deleted file mode 100644 index e8cb7901f0..0000000000 --- a/crates/euclid/src/utils.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod dense_map; - -pub use dense_map::{DenseMap, EntityId}; diff --git a/crates/euclid_macros/src/inner/knowledge.rs b/crates/euclid_macros/src/inner/knowledge.rs index 9f33a6871c..a9c453b42c 100644 --- a/crates/euclid_macros/src/inner/knowledge.rs +++ b/crates/euclid_macros/src/inner/knowledge.rs @@ -329,19 +329,17 @@ impl ToString for Scope { #[derive(Clone)] struct Program { rules: Vec>, - scope: Scope, } impl Parse for Program { fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { - let scope: Scope = input.parse()?; let mut rules: Vec> = Vec::new(); while !input.is_empty() { rules.push(Rc::new(input.parse::()?)); } - Ok(Self { rules, scope }) + Ok(Self { rules }) } } @@ -502,12 +500,12 @@ impl GenContext { let key = format_ident!("{}", &atom.key); let the_value = match &atom.value { ValueType::Any => quote! { - NodeValue::Key(DirKey::new(DirKeyKind::#key,None)) + cgraph::NodeValue::Key(DirKey::new(DirKeyKind::#key,None)) }, ValueType::EnumVariant(variant) => { let variant = format_ident!("{}", variant); quote! { - NodeValue::Value(DirValue::#key(#key::#variant)) + cgraph::NodeValue::Value(DirValue::#key(#key::#variant)) } } ValueType::Number { number, comparison } => { @@ -530,7 +528,7 @@ impl GenContext { }; quote! { - NodeValue::Value(DirValue::#key(NumValue { + cgraph::NodeValue::Value(DirValue::#key(NumValue { number: #number, refinement: #comp_type, })) @@ -539,7 +537,7 @@ impl GenContext { }; let compiled = quote! { - let #identifier = graph.make_value_node(#the_value, None, Vec::new(), None::<()>).expect("NodeId derivation failed"); + let #identifier = graph.make_value_node(#the_value, None, None::<()>); }; tokens.extend(compiled); @@ -581,7 +579,6 @@ impl GenContext { Vec::from_iter([#(#values_tokens),*]), None, None::<()>, - Vec::new(), ).expect("Failed to make In aggregator"); }; @@ -606,7 +603,7 @@ impl GenContext { for (from_node, relation) in &node_details { let relation = format_ident!("{}", relation.to_string()); tokens.extend(quote! { - graph.make_edge(#from_node, #rhs_ident, Strength::#strength, Relation::#relation) + graph.make_edge(#from_node, #rhs_ident, cgraph::Strength::#strength, cgraph::Relation::#relation, None::) .expect("Failed to make edge"); }); } @@ -614,16 +611,18 @@ impl GenContext { let mut all_agg_nodes: Vec = Vec::with_capacity(node_details.len()); for (from_node, relation) in &node_details { let relation = format_ident!("{}", relation.to_string()); - all_agg_nodes.push(quote! { (#from_node, Relation::#relation, Strength::Strong) }); + all_agg_nodes.push( + quote! { (#from_node, cgraph::Relation::#relation, cgraph::Strength::Strong) }, + ); } let strength = format_ident!("{}", rule.strength.to_string()); let (agg_node_ident, _) = self.next_node_ident(); tokens.extend(quote! { - let #agg_node_ident = graph.make_all_aggregator(&[#(#all_agg_nodes),*], None, None::<()>, Vec::new()) + let #agg_node_ident = graph.make_all_aggregator(&[#(#all_agg_nodes),*], None, None::<()>, None) .expect("Failed to make all aggregator node"); - graph.make_edge(#agg_node_ident, #rhs_ident, Strength::#strength, Relation::Positive) + graph.make_edge(#agg_node_ident, #rhs_ident, cgraph::Strength::#strength, cgraph::Relation::Positive, None::) .expect("Failed to create all aggregator edge"); }); @@ -638,21 +637,10 @@ impl GenContext { self.compile_rule(rule, &mut tokens)?; } - let scope = match &program.scope { - Scope::Crate => quote! { crate }, - Scope::Extern => quote! { euclid }, - }; - let compiled = quote! {{ - use #scope::{ - dssa::graph::*, - types::*, - frontend::dir::{*, enums::*}, - }; + use euclid_graph_prelude::*; - use rustc_hash::{FxHashMap, FxHashSet}; - - let mut graph = KnowledgeGraphBuilder::new(); + let mut graph = cgraph::ConstraintGraphBuilder::new(); #tokens diff --git a/crates/euclid_wasm/Cargo.toml b/crates/euclid_wasm/Cargo.toml index 6f5d3ec9cc..293940f27c 100644 --- a/crates/euclid_wasm/Cargo.toml +++ b/crates/euclid_wasm/Cargo.toml @@ -23,6 +23,7 @@ payouts = ["api_models/payouts", "euclid/payouts"] [dependencies] api_models = { version = "0.1.0", path = "../api_models", package = "api_models" } +hyperswitch_constraint_graph = { version = "0.1.0", path = "../hyperswitch_constraint_graph" } currency_conversion = { version = "0.1.0", path = "../currency_conversion" } connector_configs = { version = "0.1.0", path = "../connector_configs" } euclid = { version = "0.1.0", path = "../euclid", features = [] } diff --git a/crates/euclid_wasm/src/lib.rs b/crates/euclid_wasm/src/lib.rs index 36af0dc2d2..4920243bcc 100644 --- a/crates/euclid_wasm/src/lib.rs +++ b/crates/euclid_wasm/src/lib.rs @@ -20,11 +20,7 @@ use currency_conversion::{ }; use euclid::{ backend::{inputs, interpreter::InterpreterBackend, EuclidBackend}, - dssa::{ - self, analyzer, - graph::{self, Memoization}, - state_machine, truth, - }, + dssa::{self, analyzer, graph::CgraphExt, state_machine, truth}, frontend::{ ast, dir::{self, enums as dir_enums, EuclidDirFilter}, @@ -38,7 +34,7 @@ use crate::utils::JsResultExt; type JsResult = Result; struct SeedData<'a> { - kgraph: graph::KnowledgeGraph<'a>, + cgraph: hyperswitch_constraint_graph::ConstraintGraph<'a, dir::DirValue>, connectors: Vec, } @@ -98,11 +94,12 @@ pub fn seed_knowledge_graph(mcas: JsValue) -> JsResult { let mca_graph = kgraph_utils::mca::make_mca_graph(mcas).err_to_js()?; let analysis_graph = - graph::KnowledgeGraph::combine(&mca_graph, &truth::ANALYSIS_GRAPH).err_to_js()?; + hyperswitch_constraint_graph::ConstraintGraph::combine(&mca_graph, &truth::ANALYSIS_GRAPH) + .err_to_js()?; SEED_DATA .set(SeedData { - kgraph: analysis_graph, + cgraph: analysis_graph, connectors, }) .map_err(|_| "Knowledge Graph has been already seeded".to_string()) @@ -138,8 +135,12 @@ pub fn get_valid_connectors_for_rule(rule: JsValue) -> JsResult { // Standalone conjunctive context analysis to ensure the context itself is valid before // checking it against merchant's connectors seed_data - .kgraph - .perform_context_analysis(ctx, &mut Memoization::new()) + .cgraph + .perform_context_analysis( + ctx, + &mut hyperswitch_constraint_graph::Memoization::new(), + None, + ) .err_to_js()?; // Update conjunctive context and run analysis on all of merchant's connectors. @@ -150,9 +151,11 @@ pub fn get_valid_connectors_for_rule(rule: JsValue) -> JsResult { let ctx_val = dssa::types::ContextValue::assertion(choice, &dummy_meta); ctx.push(ctx_val); - let analysis_result = seed_data - .kgraph - .perform_context_analysis(ctx, &mut Memoization::new()); + let analysis_result = seed_data.cgraph.perform_context_analysis( + ctx, + &mut hyperswitch_constraint_graph::Memoization::new(), + None, + ); if analysis_result.is_err() { invalid_connectors.insert(conn.clone()); } @@ -171,7 +174,7 @@ pub fn get_valid_connectors_for_rule(rule: JsValue) -> JsResult { #[wasm_bindgen(js_name = analyzeProgram)] pub fn analyze_program(js_program: JsValue) -> JsResult { let program: ast::Program = serde_wasm_bindgen::from_value(js_program)?; - analyzer::analyze(program, SEED_DATA.get().map(|sd| &sd.kgraph)).err_to_js()?; + analyzer::analyze(program, SEED_DATA.get().map(|sd| &sd.cgraph)).err_to_js()?; Ok(JsValue::NULL) } diff --git a/crates/hyperswitch_constraint_graph/Cargo.toml b/crates/hyperswitch_constraint_graph/Cargo.toml new file mode 100644 index 0000000000..425855a05b --- /dev/null +++ b/crates/hyperswitch_constraint_graph/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "hyperswitch_constraint_graph" +description = "Constraint Graph Framework for modeling Domain-Specific Constraints" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +erased-serde = "0.3.28" +rustc-hash = "1.1.0" +serde = { version = "1.0.163", features = ["derive", "rc"] } +serde_json = "1.0.96" +strum = { version = "0.25", features = ["derive"] } +thiserror = "1.0.43" diff --git a/crates/hyperswitch_constraint_graph/src/builder.rs b/crates/hyperswitch_constraint_graph/src/builder.rs new file mode 100644 index 0000000000..c1343eff88 --- /dev/null +++ b/crates/hyperswitch_constraint_graph/src/builder.rs @@ -0,0 +1,283 @@ +use std::sync::Arc; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{ + dense_map::DenseMap, + error::GraphError, + graph::ConstraintGraph, + types::{ + DomainId, DomainIdentifier, DomainInfo, Edge, EdgeId, Metadata, Node, NodeId, NodeType, + NodeValue, Relation, Strength, ValueNode, + }, +}; + +pub enum DomainIdOrIdentifier<'a> { + DomainId(DomainId), + DomainIdentifier(DomainIdentifier<'a>), +} + +impl<'a> From<&'a str> for DomainIdOrIdentifier<'a> { + fn from(value: &'a str) -> Self { + Self::DomainIdentifier(DomainIdentifier::new(value)) + } +} + +impl From for DomainIdOrIdentifier<'_> { + fn from(value: DomainId) -> Self { + Self::DomainId(value) + } +} + +pub struct ConstraintGraphBuilder<'a, V: ValueNode> { + domain: DenseMap>, + nodes: DenseMap>, + edges: DenseMap, + domain_identifier_map: FxHashMap, DomainId>, + value_map: FxHashMap, NodeId>, + edges_map: FxHashMap<(NodeId, NodeId, Option), EdgeId>, + node_info: DenseMap>, + node_metadata: DenseMap>>, +} + +#[allow(clippy::new_without_default)] +impl<'a, V> ConstraintGraphBuilder<'a, V> +where + V: ValueNode, +{ + pub fn new() -> Self { + Self { + domain: DenseMap::new(), + nodes: DenseMap::new(), + edges: DenseMap::new(), + domain_identifier_map: FxHashMap::default(), + value_map: FxHashMap::default(), + edges_map: FxHashMap::default(), + node_info: DenseMap::new(), + node_metadata: DenseMap::new(), + } + } + + pub fn build(self) -> ConstraintGraph<'a, V> { + ConstraintGraph { + domain: self.domain, + domain_identifier_map: self.domain_identifier_map, + nodes: self.nodes, + edges: self.edges, + value_map: self.value_map, + node_info: self.node_info, + node_metadata: self.node_metadata, + } + } + + fn retrieve_domain_from_identifier( + &self, + domain_ident: DomainIdentifier<'_>, + ) -> Result> { + self.domain_identifier_map + .get(&domain_ident) + .copied() + .ok_or(GraphError::DomainNotFound) + } + + pub fn make_domain( + &mut self, + domain_identifier: &'a str, + domain_description: &str, + ) -> Result> { + let domain_identifier = DomainIdentifier::new(domain_identifier); + Ok(self + .domain_identifier_map + .clone() + .get(&domain_identifier) + .map_or_else( + || { + let domain_id = self.domain.push(DomainInfo { + domain_identifier, + domain_description: domain_description.to_string(), + }); + self.domain_identifier_map + .insert(domain_identifier, domain_id); + domain_id + }, + |domain_id| *domain_id, + )) + } + + pub fn make_value_node( + &mut self, + value: NodeValue, + info: Option<&'static str>, + metadata: Option, + ) -> NodeId { + self.value_map.get(&value).copied().unwrap_or_else(|| { + let node_id = self.nodes.push(Node::new(NodeType::Value(value.clone()))); + let _node_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + self.value_map.insert(value, node_id); + node_id + }) + } + + pub fn make_edge<'short, T: Into>>( + &mut self, + pred_id: NodeId, + succ_id: NodeId, + strength: Strength, + relation: Relation, + domain: Option, + ) -> Result> { + self.ensure_node_exists(pred_id)?; + self.ensure_node_exists(succ_id)?; + let domain_id = domain + .map(|d| match d.into() { + DomainIdOrIdentifier::DomainIdentifier(ident) => { + self.retrieve_domain_from_identifier(ident) + } + DomainIdOrIdentifier::DomainId(domain_id) => { + self.ensure_domain_exists(domain_id).map(|_| domain_id) + } + }) + .transpose()?; + self.edges_map + .get(&(pred_id, succ_id, domain_id)) + .copied() + .and_then(|edge_id| self.edges.get(edge_id).cloned().map(|edge| (edge_id, edge))) + .map_or_else( + || { + let edge_id = self.edges.push(Edge { + strength, + relation, + pred: pred_id, + succ: succ_id, + domain: domain_id, + }); + self.edges_map + .insert((pred_id, succ_id, domain_id), edge_id); + + let pred = self + .nodes + .get_mut(pred_id) + .ok_or(GraphError::NodeNotFound)?; + pred.succs.push(edge_id); + + let succ = self + .nodes + .get_mut(succ_id) + .ok_or(GraphError::NodeNotFound)?; + succ.preds.push(edge_id); + + Ok(edge_id) + }, + |(edge_id, edge)| { + if edge.strength == strength && edge.relation == relation { + Ok(edge_id) + } else { + Err(GraphError::ConflictingEdgeCreated) + } + }, + ) + } + + pub fn make_all_aggregator( + &mut self, + nodes: &[(NodeId, Relation, Strength)], + info: Option<&'static str>, + metadata: Option, + domain: Option<&str>, + ) -> Result> { + nodes + .iter() + .try_for_each(|(node_id, _, _)| self.ensure_node_exists(*node_id))?; + + let aggregator_id = self.nodes.push(Node::new(NodeType::AllAggregator)); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + for (node_id, relation, strength) in nodes { + self.make_edge(*node_id, aggregator_id, *strength, *relation, domain)?; + } + + Ok(aggregator_id) + } + + pub fn make_any_aggregator( + &mut self, + nodes: &[(NodeId, Relation, Strength)], + info: Option<&'static str>, + metadata: Option, + domain: Option<&str>, + ) -> Result> { + nodes + .iter() + .try_for_each(|(node_id, _, _)| self.ensure_node_exists(*node_id))?; + + let aggregator_id = self.nodes.push(Node::new(NodeType::AnyAggregator)); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + for (node_id, relation, strength) in nodes { + self.make_edge(*node_id, aggregator_id, *strength, *relation, domain)?; + } + + Ok(aggregator_id) + } + + pub fn make_in_aggregator( + &mut self, + values: Vec, + info: Option<&'static str>, + metadata: Option, + ) -> Result> { + let key = values + .first() + .ok_or(GraphError::NoInAggregatorValues)? + .get_key(); + + for val in &values { + if val.get_key() != key { + Err(GraphError::MalformedGraph { + reason: "Values for 'In' aggregator not of same key".to_string(), + })?; + } + } + let node_id = self + .nodes + .push(Node::new(NodeType::InAggregator(FxHashSet::from_iter( + values, + )))); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + Ok(node_id) + } + + fn ensure_node_exists(&self, id: NodeId) -> Result<(), GraphError> { + if self.nodes.contains_key(id) { + Ok(()) + } else { + Err(GraphError::NodeNotFound) + } + } + + fn ensure_domain_exists(&self, id: DomainId) -> Result<(), GraphError> { + if self.domain.contains_key(id) { + Ok(()) + } else { + Err(GraphError::DomainNotFound) + } + } +} diff --git a/crates/euclid/src/utils/dense_map.rs b/crates/hyperswitch_constraint_graph/src/dense_map.rs similarity index 92% rename from crates/euclid/src/utils/dense_map.rs rename to crates/hyperswitch_constraint_graph/src/dense_map.rs index 8bd4487c77..682833d65e 100644 --- a/crates/euclid/src/utils/dense_map.rs +++ b/crates/hyperswitch_constraint_graph/src/dense_map.rs @@ -5,6 +5,24 @@ pub trait EntityId { fn with_id(id: usize) -> Self; } +macro_rules! impl_entity { + ($name:ident) => { + impl $crate::dense_map::EntityId for $name { + #[inline] + fn get_id(&self) -> usize { + self.0 + } + + #[inline] + fn with_id(id: usize) -> Self { + Self(id) + } + } + }; +} + +pub(crate) use impl_entity; + pub struct DenseMap { data: Vec, _marker: PhantomData, diff --git a/crates/hyperswitch_constraint_graph/src/error.rs b/crates/hyperswitch_constraint_graph/src/error.rs new file mode 100644 index 0000000000..cd2269de26 --- /dev/null +++ b/crates/hyperswitch_constraint_graph/src/error.rs @@ -0,0 +1,77 @@ +use std::sync::{Arc, Weak}; + +use crate::types::{Metadata, NodeValue, Relation, RelationResolution, ValueNode}; + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "predecessor", rename_all = "snake_case")] +pub enum ValueTracePredecessor { + Mandatory(Box>>), + OneOf(Vec>>), +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "trace", rename_all = "snake_case")] +pub enum AnalysisTrace { + Value { + value: NodeValue, + relation: Relation, + predecessors: Option>, + info: Option<&'static str>, + metadata: Option>, + }, + + AllAggregation { + unsatisfied: Vec>>, + info: Option<&'static str>, + metadata: Option>, + }, + + AnyAggregation { + unsatisfied: Vec>>, + info: Option<&'static str>, + metadata: Option>, + }, + + InAggregation { + expected: Vec, + found: Option, + relation: Relation, + info: Option<&'static str>, + metadata: Option>, + }, + Contradiction { + relation: RelationResolution, + }, +} + +#[derive(Debug, Clone, serde::Serialize, thiserror::Error)] +#[serde(tag = "type", content = "info", rename_all = "snake_case")] +pub enum GraphError { + #[error("An edge was not found in the graph")] + EdgeNotFound, + #[error("Attempted to create a conflicting edge between two nodes")] + ConflictingEdgeCreated, + #[error("Cycle detected in graph")] + CycleDetected, + #[error("Domain wasn't found in the Graph")] + DomainNotFound, + #[error("Malformed Graph: {reason}")] + MalformedGraph { reason: String }, + #[error("A node was not found in the graph")] + NodeNotFound, + #[error("A value node was not found: {0:#?}")] + ValueNodeNotFound(V), + #[error("No values provided for an 'in' aggregator node")] + NoInAggregatorValues, + #[error("Error during analysis: {0:#?}")] + AnalysisError(Weak>), +} + +impl GraphError { + pub fn get_analysis_trace(self) -> Result>, Self> { + match self { + Self::AnalysisError(trace) => Ok(trace), + _ => Err(self), + } + } +} diff --git a/crates/hyperswitch_constraint_graph/src/graph.rs b/crates/hyperswitch_constraint_graph/src/graph.rs new file mode 100644 index 0000000000..d0a98e1952 --- /dev/null +++ b/crates/hyperswitch_constraint_graph/src/graph.rs @@ -0,0 +1,587 @@ +use std::sync::{Arc, Weak}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{ + builder, + dense_map::DenseMap, + error::{self, AnalysisTrace, GraphError}, + types::{ + CheckingContext, CycleCheck, DomainId, DomainIdentifier, DomainInfo, Edge, EdgeId, + Memoization, Metadata, Node, NodeId, NodeType, NodeValue, Relation, RelationResolution, + Strength, ValueNode, + }, +}; + +struct CheckNodeContext<'a, V: ValueNode, C: CheckingContext> { + ctx: &'a C, + node: &'a Node, + node_id: NodeId, + relation: Relation, + strength: Strength, + memo: &'a mut Memoization, + cycle_map: &'a mut CycleCheck, + domains: Option<&'a [DomainId]>, +} + +pub struct ConstraintGraph<'a, V: ValueNode> { + pub domain: DenseMap>, + pub domain_identifier_map: FxHashMap, DomainId>, + pub nodes: DenseMap>, + pub edges: DenseMap, + pub value_map: FxHashMap, NodeId>, + pub node_info: DenseMap>, + pub node_metadata: DenseMap>>, +} + +impl<'a, V> ConstraintGraph<'a, V> +where + V: ValueNode, +{ + fn get_predecessor_edges_by_domain( + &self, + node_id: NodeId, + domains: Option<&[DomainId]>, + ) -> Result, GraphError> { + let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; + let mut final_list = Vec::new(); + for &pred in &node.preds { + let edge = self.edges.get(pred).ok_or(GraphError::EdgeNotFound)?; + if let Some((domain_id, domains)) = edge.domain.zip(domains) { + if domains.contains(&domain_id) { + final_list.push(edge); + } + } else if edge.domain.is_none() { + final_list.push(edge); + } + } + + Ok(final_list) + } + + #[allow(clippy::too_many_arguments)] + pub fn check_node( + &self, + ctx: &C, + node_id: NodeId, + relation: Relation, + strength: Strength, + memo: &mut Memoization, + cycle_map: &mut CycleCheck, + domains: Option<&[&str]>, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let domains = domains + .map(|domain_idents| { + domain_idents + .iter() + .map(|domain_ident| { + self.domain_identifier_map + .get(&DomainIdentifier::new(domain_ident)) + .copied() + .ok_or(GraphError::DomainNotFound) + }) + .collect::, _>>() + }) + .transpose()?; + + self.check_node_inner( + ctx, + node_id, + relation, + strength, + memo, + cycle_map, + domains.as_deref(), + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn check_node_inner( + &self, + ctx: &C, + node_id: NodeId, + relation: Relation, + strength: Strength, + memo: &mut Memoization, + cycle_map: &mut CycleCheck, + domains: Option<&[DomainId]>, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; + + if let Some(already_memo) = memo.get(&(node_id, relation, strength)) { + already_memo + .clone() + .map_err(|err| GraphError::AnalysisError(Arc::downgrade(&err))) + } else if let Some((initial_strength, initial_relation)) = cycle_map.get(&node_id).cloned() + { + let strength_relation = Strength::get_resolved_strength(initial_strength, strength); + let relation_resolve = + RelationResolution::get_resolved_relation(initial_relation, relation.into()); + cycle_map.entry(node_id).and_modify(|value| { + value.0 = strength_relation; + value.1 = relation_resolve + }); + Ok(()) + } else { + let check_node_context = CheckNodeContext { + node, + node_id, + relation, + strength, + memo, + cycle_map, + ctx, + domains, + }; + match &node.node_type { + NodeType::AllAggregator => self.validate_all_aggregator(check_node_context), + + NodeType::AnyAggregator => self.validate_any_aggregator(check_node_context), + + NodeType::InAggregator(expected) => { + self.validate_in_aggregator(check_node_context, expected) + } + NodeType::Value(val) => self.validate_value_node(check_node_context, val), + } + } + } + + fn validate_all_aggregator( + &self, + vald: CheckNodeContext<'_, V, C>, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let mut unsatisfied = Vec::>>::new(); + + for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { + vald.cycle_map + .insert(vald.node_id, (vald.strength, vald.relation.into())); + if let Err(e) = self.check_node_inner( + vald.ctx, + edge.pred, + edge.relation, + edge.strength, + vald.memo, + vald.cycle_map, + vald.domains, + ) { + unsatisfied.push(e.get_analysis_trace()?); + } + if let Some((_resolved_strength, resolved_relation)) = + vald.cycle_map.remove(&vald.node_id) + { + if resolved_relation == RelationResolution::Contradiction { + let err = Arc::new(AnalysisTrace::Contradiction { + relation: resolved_relation, + }); + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + return Err(GraphError::AnalysisError(Arc::downgrade(&err))); + } + } + } + + if !unsatisfied.is_empty() { + let err = Arc::new(AnalysisTrace::AllAggregation { + unsatisfied, + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + }); + + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } else { + vald.memo + .insert((vald.node_id, vald.relation, vald.strength), Ok(())); + Ok(()) + } + } + + fn validate_any_aggregator( + &self, + vald: CheckNodeContext<'_, V, C>, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let mut unsatisfied = Vec::>>::new(); + let mut matched_one = false; + + for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { + vald.cycle_map + .insert(vald.node_id, (vald.strength, vald.relation.into())); + if let Err(e) = self.check_node_inner( + vald.ctx, + edge.pred, + edge.relation, + edge.strength, + vald.memo, + vald.cycle_map, + vald.domains, + ) { + unsatisfied.push(e.get_analysis_trace()?); + } else { + matched_one = true; + } + if let Some((_resolved_strength, resolved_relation)) = + vald.cycle_map.remove(&vald.node_id) + { + if resolved_relation == RelationResolution::Contradiction { + let err = Arc::new(AnalysisTrace::Contradiction { + relation: resolved_relation, + }); + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + + return Err(GraphError::AnalysisError(Arc::downgrade(&err))); + } + } + } + + if matched_one || vald.node.preds.is_empty() { + vald.memo + .insert((vald.node_id, vald.relation, vald.strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::AnyAggregation { + unsatisfied: unsatisfied.clone(), + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + }); + + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } + } + + fn validate_in_aggregator( + &self, + vald: CheckNodeContext<'_, V, C>, + expected: &FxHashSet, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let the_key = expected + .iter() + .next() + .ok_or_else(|| GraphError::MalformedGraph { + reason: "An OnlyIn aggregator node must have at least one expected value" + .to_string(), + })? + .get_key(); + + let ctx_vals = if let Some(vals) = vald.ctx.get_values_by_key(&the_key) { + vals + } else { + return if let Strength::Weak = vald.strength { + vald.memo + .insert((vald.node_id, vald.relation, vald.strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::InAggregation { + expected: expected.iter().cloned().collect(), + found: None, + relation: vald.relation, + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + }); + + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + }; + }; + + let relation_bool: bool = vald.relation.into(); + for ctx_value in ctx_vals { + if expected.contains(&ctx_value) != relation_bool { + let err = Arc::new(AnalysisTrace::InAggregation { + expected: expected.iter().cloned().collect(), + found: Some(ctx_value.clone()), + relation: vald.relation, + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + }); + + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + } + + vald.memo + .insert((vald.node_id, vald.relation, vald.strength), Ok(())); + Ok(()) + } + + fn validate_value_node( + &self, + vald: CheckNodeContext<'_, V, C>, + val: &NodeValue, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let mut errors = Vec::>>::new(); + let mut matched_one = false; + + self.context_analysis( + vald.node_id, + vald.relation, + vald.strength, + vald.ctx, + val, + vald.memo, + )?; + + for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { + vald.cycle_map + .insert(vald.node_id, (vald.strength, vald.relation.into())); + let result = self.check_node_inner( + vald.ctx, + edge.pred, + edge.relation, + edge.strength, + vald.memo, + vald.cycle_map, + vald.domains, + ); + + if let Some((resolved_strength, resolved_relation)) = + vald.cycle_map.remove(&vald.node_id) + { + if resolved_relation == RelationResolution::Contradiction { + let err = Arc::new(AnalysisTrace::Contradiction { + relation: resolved_relation, + }); + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + return Err(GraphError::AnalysisError(Arc::downgrade(&err))); + } else if resolved_strength != vald.strength { + self.context_analysis( + vald.node_id, + vald.relation, + resolved_strength, + vald.ctx, + val, + vald.memo, + )? + } + } + match (edge.strength, result) { + (Strength::Strong, Err(trace)) => { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation: vald.relation, + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + predecessors: Some(error::ValueTracePredecessor::Mandatory(Box::new( + trace.get_analysis_trace()?, + ))), + }); + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + + (Strength::Strong, Ok(_)) => { + matched_one = true; + } + + (Strength::Normal | Strength::Weak, Err(trace)) => { + errors.push(trace.get_analysis_trace()?); + } + + (Strength::Normal | Strength::Weak, Ok(_)) => { + matched_one = true; + } + } + } + + if matched_one || vald.node.preds.is_empty() { + vald.memo + .insert((vald.node_id, vald.relation, vald.strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation: vald.relation, + info: self.node_info.get(vald.node_id).cloned().flatten(), + metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), + predecessors: Some(error::ValueTracePredecessor::OneOf(errors.clone())), + }); + + vald.memo.insert( + (vald.node_id, vald.relation, vald.strength), + Err(Arc::clone(&err)), + ); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } + } + + fn context_analysis( + &self, + node_id: NodeId, + relation: Relation, + strength: Strength, + ctx: &C, + val: &NodeValue, + memo: &mut Memoization, + ) -> Result<(), GraphError> + where + C: CheckingContext, + { + let in_context = ctx.check_presence(val, strength); + let relation_bool: bool = relation.into(); + if in_context != relation_bool { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation, + predecessors: None, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + if !relation_bool { + memo.insert((node_id, relation, strength), Ok(())); + return Ok(()); + } + Ok(()) + } + + pub fn combine<'b>(g1: &'b Self, g2: &'b Self) -> Result> { + let mut node_builder = builder::ConstraintGraphBuilder::new(); + let mut g1_old2new_id = DenseMap::::new(); + let mut g2_old2new_id = DenseMap::::new(); + let mut g1_old2new_domain_id = DenseMap::::new(); + let mut g2_old2new_domain_id = DenseMap::::new(); + + let add_domain = |node_builder: &mut builder::ConstraintGraphBuilder<'a, V>, + domain: DomainInfo<'a>| + -> Result> { + node_builder.make_domain( + domain.domain_identifier.into_inner(), + &domain.domain_description, + ) + }; + + let add_node = |node_builder: &mut builder::ConstraintGraphBuilder<'a, V>, + node: &Node| + -> Result> { + match &node.node_type { + NodeType::Value(node_value) => { + Ok(node_builder.make_value_node(node_value.clone(), None, None::<()>)) + } + + NodeType::AllAggregator => { + Ok(node_builder.make_all_aggregator(&[], None, None::<()>, None)?) + } + + NodeType::AnyAggregator => { + Ok(node_builder.make_any_aggregator(&[], None, None::<()>, None)?) + } + + NodeType::InAggregator(expected) => Ok(node_builder.make_in_aggregator( + expected.iter().cloned().collect(), + None, + None::<()>, + )?), + } + }; + + for (_old_domain_id, domain) in g1.domain.iter() { + let new_domain_id = add_domain(&mut node_builder, domain.clone())?; + g1_old2new_domain_id.push(new_domain_id); + } + + for (_old_domain_id, domain) in g2.domain.iter() { + let new_domain_id = add_domain(&mut node_builder, domain.clone())?; + g2_old2new_domain_id.push(new_domain_id); + } + + for (_old_node_id, node) in g1.nodes.iter() { + let new_node_id = add_node(&mut node_builder, node)?; + g1_old2new_id.push(new_node_id); + } + + for (_old_node_id, node) in g2.nodes.iter() { + let new_node_id = add_node(&mut node_builder, node)?; + g2_old2new_id.push(new_node_id); + } + + for edge in g1.edges.values() { + let new_pred_id = g1_old2new_id + .get(edge.pred) + .ok_or(GraphError::NodeNotFound)?; + let new_succ_id = g1_old2new_id + .get(edge.succ) + .ok_or(GraphError::NodeNotFound)?; + let domain_ident = edge + .domain + .map(|domain_id| g1.domain.get(domain_id).ok_or(GraphError::DomainNotFound)) + .transpose()? + .map(|domain| domain.domain_identifier); + + node_builder.make_edge( + *new_pred_id, + *new_succ_id, + edge.strength, + edge.relation, + domain_ident.as_deref(), + )?; + } + + for edge in g2.edges.values() { + let new_pred_id = g2_old2new_id + .get(edge.pred) + .ok_or(GraphError::NodeNotFound)?; + let new_succ_id = g2_old2new_id + .get(edge.succ) + .ok_or(GraphError::NodeNotFound)?; + let domain_ident = edge + .domain + .map(|domain_id| g2.domain.get(domain_id).ok_or(GraphError::DomainNotFound)) + .transpose()? + .map(|domain| domain.domain_identifier); + + node_builder.make_edge( + *new_pred_id, + *new_succ_id, + edge.strength, + edge.relation, + domain_ident.as_deref(), + )?; + } + + Ok(node_builder.build()) + } +} diff --git a/crates/hyperswitch_constraint_graph/src/lib.rs b/crates/hyperswitch_constraint_graph/src/lib.rs new file mode 100644 index 0000000000..ade9a64272 --- /dev/null +++ b/crates/hyperswitch_constraint_graph/src/lib.rs @@ -0,0 +1,13 @@ +pub mod builder; +mod dense_map; +pub mod error; +pub mod graph; +pub mod types; + +pub use builder::ConstraintGraphBuilder; +pub use error::{AnalysisTrace, GraphError}; +pub use graph::ConstraintGraph; +pub use types::{ + CheckingContext, CycleCheck, DomainId, DomainIdentifier, Edge, EdgeId, KeyNode, Memoization, + Node, NodeId, NodeValue, Relation, Strength, ValueNode, +}; diff --git a/crates/hyperswitch_constraint_graph/src/types.rs b/crates/hyperswitch_constraint_graph/src/types.rs new file mode 100644 index 0000000000..d1d14bd7e5 --- /dev/null +++ b/crates/hyperswitch_constraint_graph/src/types.rs @@ -0,0 +1,249 @@ +use std::{ + any::Any, + fmt, hash, + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{dense_map::impl_entity, error::AnalysisTrace}; + +pub trait KeyNode: fmt::Debug + Clone + hash::Hash + serde::Serialize + PartialEq + Eq {} + +pub trait ValueNode: fmt::Debug + Clone + hash::Hash + serde::Serialize + PartialEq + Eq { + type Key: KeyNode; + + fn get_key(&self) -> Self::Key; +} + +#[derive(Debug, Clone, Copy, serde::Serialize, PartialEq, Eq, Hash)] +#[serde(transparent)] +pub struct NodeId(usize); + +impl_entity!(NodeId); + +#[derive(Debug)] +pub struct Node { + pub node_type: NodeType, + pub preds: Vec, + pub succs: Vec, +} + +impl Node { + pub(crate) fn new(node_type: NodeType) -> Self { + Self { + node_type, + preds: Vec::new(), + succs: Vec::new(), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum NodeType { + AllAggregator, + AnyAggregator, + InAggregator(FxHashSet), + Value(NodeValue), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)] +#[serde(tag = "type", content = "value", rename_all = "snake_case")] +pub enum NodeValue { + Key(::Key), + Value(V), +} + +impl From for NodeValue { + fn from(value: V) -> Self { + Self::Value(value) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct EdgeId(usize); + +impl_entity!(EdgeId); + +#[derive( + Debug, Clone, Copy, serde::Serialize, PartialEq, Eq, Hash, strum::Display, PartialOrd, Ord, +)] +pub enum Strength { + Weak, + Normal, + Strong, +} + +impl Strength { + pub fn get_resolved_strength(prev_strength: Self, curr_strength: Self) -> Self { + std::cmp::max(prev_strength, curr_strength) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::Display, serde::Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Relation { + Positive, + Negative, +} + +impl From for bool { + fn from(value: Relation) -> Self { + matches!(value, Relation::Positive) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::Display, serde::Serialize)] +pub enum RelationResolution { + Positive, + Negative, + Contradiction, +} + +impl From for RelationResolution { + fn from(value: Relation) -> Self { + match value { + Relation::Positive => Self::Positive, + Relation::Negative => Self::Negative, + } + } +} + +impl RelationResolution { + pub fn get_resolved_relation(prev_relation: Self, curr_relation: Self) -> Self { + if prev_relation != curr_relation { + Self::Contradiction + } else { + curr_relation + } + } +} + +#[derive(Debug, Clone)] +pub struct Edge { + pub strength: Strength, + pub relation: Relation, + pub pred: NodeId, + pub succ: NodeId, + pub domain: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct DomainId(usize); + +impl_entity!(DomainId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct DomainIdentifier<'a>(&'a str); + +impl<'a> DomainIdentifier<'a> { + pub fn new(identifier: &'a str) -> Self { + Self(identifier) + } + + pub fn into_inner(&self) -> &'a str { + self.0 + } +} + +impl<'a> From<&'a str> for DomainIdentifier<'a> { + fn from(value: &'a str) -> Self { + Self(value) + } +} + +impl<'a> Deref for DomainIdentifier<'a> { + type Target = str; + + fn deref(&self) -> &'a Self::Target { + self.0 + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DomainInfo<'a> { + pub domain_identifier: DomainIdentifier<'a>, + pub domain_description: String, +} + +pub trait CheckingContext { + type Value: ValueNode; + + fn from_node_values(vals: impl IntoIterator) -> Self + where + L: Into; + + fn check_presence(&self, value: &NodeValue, strength: Strength) -> bool; + + fn get_values_by_key( + &self, + expected: &::Key, + ) -> Option>; +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct Memoization( + #[allow(clippy::type_complexity)] + FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>>, +); + +impl Memoization { + pub fn new() -> Self { + Self(FxHashMap::default()) + } +} + +impl Default for Memoization { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Deref for Memoization { + type Target = FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Memoization { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Debug, Clone)] +pub struct CycleCheck(FxHashMap); +impl CycleCheck { + pub fn new() -> Self { + Self(FxHashMap::default()) + } +} + +impl Default for CycleCheck { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Deref for CycleCheck { + type Target = FxHashMap; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for CycleCheck { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub trait Metadata: erased_serde::Serialize + Any + Send + Sync + fmt::Debug {} +erased_serde::serialize_trait_object!(Metadata); + +impl Metadata for M where M: erased_serde::Serialize + Any + Send + Sync + fmt::Debug {} diff --git a/crates/kgraph_utils/Cargo.toml b/crates/kgraph_utils/Cargo.toml index 4ad5ef04f4..86de6002c3 100644 --- a/crates/kgraph_utils/Cargo.toml +++ b/crates/kgraph_utils/Cargo.toml @@ -13,6 +13,7 @@ connector_choice_mca_id = ["api_models/connector_choice_mca_id", "euclid/connect [dependencies] api_models = { version = "0.1.0", path = "../api_models", package = "api_models" } common_enums = { version = "0.1.0", path = "../common_enums" } +hyperswitch_constraint_graph = { version = "0.1.0", path = "../hyperswitch_constraint_graph" } euclid = { version = "0.1.0", path = "../euclid" } masking = { version = "0.1.0", path = "../masking/" } diff --git a/crates/kgraph_utils/benches/evaluation.rs b/crates/kgraph_utils/benches/evaluation.rs index 6105dc85d7..9921ee7af3 100644 --- a/crates/kgraph_utils/benches/evaluation.rs +++ b/crates/kgraph_utils/benches/evaluation.rs @@ -8,13 +8,17 @@ use api_models::{ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use euclid::{ dirval, - dssa::graph::{self, Memoization}, + dssa::graph::{self, CgraphExt}, frontend::dir, types::{NumValue, NumValueRefinement}, }; +use hyperswitch_constraint_graph::{CycleCheck, Memoization}; use kgraph_utils::{error::KgraphError, transformers::IntoDirValue}; -fn build_test_data<'a>(total_enabled: usize, total_pm_types: usize) -> graph::KnowledgeGraph<'a> { +fn build_test_data<'a>( + total_enabled: usize, + total_pm_types: usize, +) -> hyperswitch_constraint_graph::ConstraintGraph<'a, dir::DirValue> { use api_models::{admin::*, payment_methods::*}; let mut pms_enabled: Vec = Vec::new(); @@ -88,6 +92,8 @@ fn evaluation(c: &mut Criterion) { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); }); }); @@ -105,6 +111,8 @@ fn evaluation(c: &mut Criterion) { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); }); }); diff --git a/crates/kgraph_utils/src/error.rs b/crates/kgraph_utils/src/error.rs index 5a16c6375b..95450fbe35 100644 --- a/crates/kgraph_utils/src/error.rs +++ b/crates/kgraph_utils/src/error.rs @@ -1,4 +1,4 @@ -use euclid::dssa::{graph::GraphError, types::AnalysisErrorType}; +use euclid::{dssa::types::AnalysisErrorType, frontend::dir}; #[derive(Debug, thiserror::Error, serde::Serialize)] #[serde(tag = "type", content = "info", rename_all = "snake_case")] @@ -6,7 +6,7 @@ pub enum KgraphError { #[error("Invalid connector name encountered: '{0}'")] InvalidConnectorName(String), #[error("There was an error constructing the graph: {0}")] - GraphConstructionError(GraphError), + GraphConstructionError(hyperswitch_constraint_graph::GraphError), #[error("There was an error constructing the context")] ContextConstructionError(AnalysisErrorType), #[error("there was an unprecedented indexing error")] diff --git a/crates/kgraph_utils/src/mca.rs b/crates/kgraph_utils/src/mca.rs index 8542437a5a..14a88dd1c6 100644 --- a/crates/kgraph_utils/src/mca.rs +++ b/crates/kgraph_utils/src/mca.rs @@ -4,42 +4,39 @@ use api_models::{ admin as admin_api, enums as api_enums, payment_methods::RequestPaymentMethodTypes, }; use euclid::{ - dssa::graph::{self, DomainIdentifier}, frontend::{ast, dir}, types::{NumValue, NumValueRefinement}, }; +use hyperswitch_constraint_graph as cgraph; use crate::{error::KgraphError, transformers::IntoDirValue}; pub const DOMAIN_IDENTIFIER: &str = "payment_methods_enabled_for_merchantconnectoraccount"; fn compile_request_pm_types( - builder: &mut graph::KnowledgeGraphBuilder<'_>, + builder: &mut cgraph::ConstraintGraphBuilder<'_, dir::DirValue>, pm_types: RequestPaymentMethodTypes, pm: api_enums::PaymentMethod, -) -> Result { - let mut agg_nodes: Vec<(graph::NodeId, graph::Relation, graph::Strength)> = Vec::new(); +) -> Result { + let mut agg_nodes: Vec<(cgraph::NodeId, cgraph::Relation, cgraph::Strength)> = Vec::new(); let pmt_info = "PaymentMethodType"; - let pmt_id = builder - .make_value_node( - (pm_types.payment_method_type, pm) - .into_dir_value() - .map(Into::into)?, - Some(pmt_info), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let pmt_id = builder.make_value_node( + (pm_types.payment_method_type, pm) + .into_dir_value() + .map(Into::into)?, + Some(pmt_info), + None::<()>, + ); agg_nodes.push(( pmt_id, - graph::Relation::Positive, + cgraph::Relation::Positive, match pm_types.payment_method_type { api_enums::PaymentMethodType::Credit | api_enums::PaymentMethodType::Debit => { - graph::Strength::Weak + cgraph::Strength::Weak } - _ => graph::Strength::Strong, + _ => cgraph::Strength::Strong, }, )); @@ -52,13 +49,13 @@ fn compile_request_pm_types( let card_network_info = "Card Networks"; let card_network_id = builder - .make_in_aggregator(dir_vals, Some(card_network_info), None::<()>, Vec::new()) + .make_in_aggregator(dir_vals, Some(card_network_info), None::<()>) .map_err(KgraphError::GraphConstructionError)?; agg_nodes.push(( card_network_id, - graph::Relation::Positive, - graph::Strength::Weak, + cgraph::Relation::Positive, + cgraph::Strength::Weak, )); } } @@ -71,7 +68,7 @@ fn compile_request_pm_types( .map(IntoDirValue::into_dir_value) .collect::>() .ok()?, - graph::Relation::Positive, + cgraph::Relation::Positive, )), admin_api::AcceptedCurrencies::DisableOnly(curr) if !curr.is_empty() => Some(( @@ -79,7 +76,7 @@ fn compile_request_pm_types( .map(IntoDirValue::into_dir_value) .collect::>() .ok()?, - graph::Relation::Negative, + cgraph::Relation::Negative, )), _ => None, @@ -88,15 +85,10 @@ fn compile_request_pm_types( if let Some((currencies, relation)) = currencies_data { let accepted_currencies_info = "Accepted Currencies"; let accepted_currencies_id = builder - .make_in_aggregator( - currencies, - Some(accepted_currencies_info), - None::<()>, - Vec::new(), - ) + .make_in_aggregator(currencies, Some(accepted_currencies_info), None::<()>) .map_err(KgraphError::GraphConstructionError)?; - agg_nodes.push((accepted_currencies_id, relation, graph::Strength::Strong)); + agg_nodes.push((accepted_currencies_id, relation, cgraph::Strength::Strong)); } let mut amount_nodes = Vec::with_capacity(2); @@ -108,14 +100,11 @@ fn compile_request_pm_types( }; let min_amt_info = "Minimum Amount"; - let min_amt_id = builder - .make_value_node( - dir::DirValue::PaymentAmount(num_val).into(), - Some(min_amt_info), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let min_amt_id = builder.make_value_node( + dir::DirValue::PaymentAmount(num_val).into(), + Some(min_amt_info), + None::<()>, + ); amount_nodes.push(min_amt_id); } @@ -127,14 +116,11 @@ fn compile_request_pm_types( }; let max_amt_info = "Maximum Amount"; - let max_amt_id = builder - .make_value_node( - dir::DirValue::PaymentAmount(num_val).into(), - Some(max_amt_info), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let max_amt_id = builder.make_value_node( + dir::DirValue::PaymentAmount(num_val).into(), + Some(max_amt_info), + None::<()>, + ); amount_nodes.push(max_amt_id); } @@ -145,14 +131,11 @@ fn compile_request_pm_types( refinement: None, }; - let zero_amt_id = builder - .make_value_node( - dir::DirValue::PaymentAmount(zero_num_val).into(), - Some("zero_amount"), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let zero_amt_id = builder.make_value_node( + dir::DirValue::PaymentAmount(zero_num_val).into(), + Some("zero_amount"), + None::<()>, + ); let or_node_neighbor_id = if amount_nodes.len() == 1 { amount_nodes @@ -163,7 +146,13 @@ fn compile_request_pm_types( let nodes = amount_nodes .iter() .copied() - .map(|node_id| (node_id, graph::Relation::Positive, graph::Strength::Strong)) + .map(|node_id| { + ( + node_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + ) + }) .collect::>(); builder @@ -171,7 +160,7 @@ fn compile_request_pm_types( &nodes, Some("amount_constraint_aggregator"), None::<()>, - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None, ) .map_err(KgraphError::GraphConstructionError)? }; @@ -179,37 +168,40 @@ fn compile_request_pm_types( let any_aggregator = builder .make_any_aggregator( &[ - (zero_amt_id, graph::Relation::Positive), - (or_node_neighbor_id, graph::Relation::Positive), + ( + zero_amt_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + ), + ( + or_node_neighbor_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + ), ], Some("zero_plus_limits_amount_aggregator"), None::<()>, - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None, ) .map_err(KgraphError::GraphConstructionError)?; agg_nodes.push(( any_aggregator, - graph::Relation::Positive, - graph::Strength::Strong, + cgraph::Relation::Positive, + cgraph::Strength::Strong, )); } let pmt_all_aggregator_info = "All Aggregator for PaymentMethodType"; builder - .make_all_aggregator( - &agg_nodes, - Some(pmt_all_aggregator_info), - None::<()>, - Vec::new(), - ) + .make_all_aggregator(&agg_nodes, Some(pmt_all_aggregator_info), None::<()>, None) .map_err(KgraphError::GraphConstructionError) } fn compile_payment_method_enabled( - builder: &mut graph::KnowledgeGraphBuilder<'_>, + builder: &mut cgraph::ConstraintGraphBuilder<'_, dir::DirValue>, enabled: admin_api::PaymentMethodsEnabled, -) -> Result, KgraphError> { +) -> Result, KgraphError> { let agg_id = if !enabled .payment_method_types .as_ref() @@ -217,48 +209,44 @@ fn compile_payment_method_enabled( .unwrap_or(true) { let pm_info = "PaymentMethod"; - let pm_id = builder - .make_value_node( - enabled.payment_method.into_dir_value().map(Into::into)?, - Some(pm_info), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let pm_id = builder.make_value_node( + enabled.payment_method.into_dir_value().map(Into::into)?, + Some(pm_info), + None::<()>, + ); - let mut agg_nodes: Vec<(graph::NodeId, graph::Relation)> = Vec::new(); + let mut agg_nodes: Vec<(cgraph::NodeId, cgraph::Relation, cgraph::Strength)> = Vec::new(); if let Some(pm_types) = enabled.payment_method_types { for pm_type in pm_types { let node_id = compile_request_pm_types(builder, pm_type, enabled.payment_method)?; - agg_nodes.push((node_id, graph::Relation::Positive)); + agg_nodes.push(( + node_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + )); } } let any_aggregator_info = "Any aggregation for PaymentMethodsType"; let pm_type_agg_id = builder - .make_any_aggregator( - &agg_nodes, - Some(any_aggregator_info), - None::<()>, - Vec::new(), - ) + .make_any_aggregator(&agg_nodes, Some(any_aggregator_info), None::<()>, None) .map_err(KgraphError::GraphConstructionError)?; let all_aggregator_info = "All aggregation for PaymentMethod"; let enabled_pm_agg_id = builder .make_all_aggregator( &[ - (pm_id, graph::Relation::Positive, graph::Strength::Strong), + (pm_id, cgraph::Relation::Positive, cgraph::Strength::Strong), ( pm_type_agg_id, - graph::Relation::Positive, - graph::Strength::Strong, + cgraph::Relation::Positive, + cgraph::Strength::Strong, ), ], Some(all_aggregator_info), None::<()>, - Vec::new(), + None, ) .map_err(KgraphError::GraphConstructionError)?; @@ -271,26 +259,30 @@ fn compile_payment_method_enabled( } fn compile_merchant_connector_graph( - builder: &mut graph::KnowledgeGraphBuilder<'_>, + builder: &mut cgraph::ConstraintGraphBuilder<'_, dir::DirValue>, mca: admin_api::MerchantConnectorResponse, ) -> Result<(), KgraphError> { let connector = common_enums::RoutableConnectors::from_str(&mca.connector_name) .map_err(|_| KgraphError::InvalidConnectorName(mca.connector_name.clone()))?; - let mut agg_nodes: Vec<(graph::NodeId, graph::Relation)> = Vec::new(); + let mut agg_nodes: Vec<(cgraph::NodeId, cgraph::Relation, cgraph::Strength)> = Vec::new(); if let Some(pms_enabled) = mca.payment_methods_enabled { for pm_enabled in pms_enabled { let maybe_pm_enabled_id = compile_payment_method_enabled(builder, pm_enabled)?; if let Some(pm_enabled_id) = maybe_pm_enabled_id { - agg_nodes.push((pm_enabled_id, graph::Relation::Positive)); + agg_nodes.push(( + pm_enabled_id, + cgraph::Relation::Positive, + cgraph::Strength::Strong, + )); } } } let aggregator_info = "Available Payment methods for connector"; let pms_enabled_agg_id = builder - .make_any_aggregator(&agg_nodes, Some(aggregator_info), None::<()>, Vec::new()) + .make_any_aggregator(&agg_nodes, Some(aggregator_info), None::<()>, None) .map_err(KgraphError::GraphConstructionError)?; let connector_dir_val = dir::DirValue::Connector(Box::new(ast::ConnectorChoice { @@ -300,21 +292,16 @@ fn compile_merchant_connector_graph( })); let connector_info = "Connector"; - let connector_node_id = builder - .make_value_node( - connector_dir_val.into(), - Some(connector_info), - vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], - None::<()>, - ) - .map_err(KgraphError::GraphConstructionError)?; + let connector_node_id = + builder.make_value_node(connector_dir_val.into(), Some(connector_info), None::<()>); builder .make_edge( pms_enabled_agg_id, connector_node_id, - graph::Strength::Normal, - graph::Relation::Positive, + cgraph::Strength::Normal, + cgraph::Relation::Positive, + None::, ) .map_err(KgraphError::GraphConstructionError)?; @@ -323,11 +310,11 @@ fn compile_merchant_connector_graph( pub fn make_mca_graph<'a>( accts: Vec, -) -> Result, KgraphError> { - let mut builder = graph::KnowledgeGraphBuilder::new(); +) -> Result, KgraphError> { + let mut builder = cgraph::ConstraintGraphBuilder::new(); let _domain = builder.make_domain( - DomainIdentifier::new(DOMAIN_IDENTIFIER), - "Payment methods enabled for MerchantConnectorAccount".to_string(), + DOMAIN_IDENTIFIER, + "Payment methods enabled for MerchantConnectorAccount", ); for acct in accts { compile_merchant_connector_graph(&mut builder, acct)?; @@ -343,12 +330,13 @@ mod tests { use api_models::enums as api_enums; use euclid::{ dirval, - dssa::graph::{AnalysisContext, Memoization}, + dssa::graph::{AnalysisContext, CgraphExt}, }; + use hyperswitch_constraint_graph::{ConstraintGraph, CycleCheck, Memoization}; use super::*; - fn build_test_data<'a>() -> graph::KnowledgeGraph<'a> { + fn build_test_data<'a>() -> ConstraintGraph<'a, dir::DirValue> { use api_models::{admin::*, payment_methods::*}; let stripe_account = MerchantConnectorResponse { @@ -428,6 +416,8 @@ mod tests { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -448,6 +438,8 @@ mod tests { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_ok()); @@ -468,6 +460,8 @@ mod tests { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -488,6 +482,8 @@ mod tests { dirval!(PaymentAmount = 7), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_err()); @@ -507,6 +503,8 @@ mod tests { dirval!(PaymentAmount = 7), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); //println!("{:#?}", result); @@ -529,6 +527,8 @@ mod tests { dirval!(PaymentAmount = 100), ]), &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); //println!("{:#?}", result); @@ -725,6 +725,8 @@ mod tests { dirval!(Connector = Stripe), &context, &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_ok(), "stripe validation failed"); @@ -733,6 +735,8 @@ mod tests { dirval!(Connector = Bluesnap), &context, &mut Memoization::new(), + &mut CycleCheck::new(), + None, ); assert!(result.is_err(), "bluesnap validation failed"); } diff --git a/crates/router/Cargo.toml b/crates/router/Cargo.toml index f8e2cfad12..7ff47b927d 100644 --- a/crates/router/Cargo.toml +++ b/crates/router/Cargo.toml @@ -103,6 +103,7 @@ analytics = { version = "0.1.0", path = "../analytics", optional = true } cards = { version = "0.1.0", path = "../cards" } common_enums = { version = "0.1.0", path = "../common_enums" } common_utils = { version = "0.1.0", path = "../common_utils", features = ["signals", "async_ext", "logs"] } +hyperswitch_constraint_graph = { version = "0.1.0", path = "../hyperswitch_constraint_graph" } currency_conversion = { version = "0.1.0", path = "../currency_conversion" } hyperswitch_domain_models = { version = "0.1.0", path = "../hyperswitch_domain_models", default-features = false } diesel_models = { version = "0.1.0", path = "../diesel_models", features = ["kv_store"] } diff --git a/crates/router/src/core/payments/routing.rs b/crates/router/src/core/payments/routing.rs index ff7303c900..6967c97775 100644 --- a/crates/router/src/core/payments/routing.rs +++ b/crates/router/src/core/payments/routing.rs @@ -17,9 +17,9 @@ use diesel_models::enums as storage_enums; use error_stack::ResultExt; use euclid::{ backend::{self, inputs as dsl_inputs, EuclidBackend}, - dssa::graph::{self as euclid_graph, Memoization}, + dssa::graph::{self as euclid_graph, CgraphExt}, enums as euclid_enums, - frontend::ast, + frontend::{ast, dir as euclid_dir}, }; use kgraph_utils::{ mca as mca_graph, @@ -82,7 +82,9 @@ pub struct SessionRoutingPmTypeInput<'a> { profile_id: Option, } static ROUTING_CACHE: StaticCache = StaticCache::new(); -static KGRAPH_CACHE: StaticCache> = StaticCache::new(); +static KGRAPH_CACHE: StaticCache< + hyperswitch_constraint_graph::ConstraintGraph<'_, euclid_dir::DirValue>, +> = StaticCache::new(); type RoutingResult = oss_errors::CustomResult; @@ -542,7 +544,7 @@ pub async fn get_merchant_kgraph<'a>( merchant_last_modified: i64, #[cfg(feature = "business_profile_routing")] profile_id: Option, transaction_type: &api_enums::TransactionType, -) -> RoutingResult>> { +) -> RoutingResult>> { let merchant_id = &key_store.merchant_id; #[cfg(feature = "business_profile_routing")] @@ -690,7 +692,13 @@ async fn perform_kgraph_filtering( .into_dir_value() .change_context(errors::RoutingError::KgraphAnalysisError)?; let kgraph_eligible = cached_kgraph - .check_value_validity(dir_val, &context, &mut Memoization::new()) + .check_value_validity( + dir_val, + &context, + &mut hyperswitch_constraint_graph::Memoization::new(), + &mut hyperswitch_constraint_graph::CycleCheck::new(), + None, + ) .change_context(errors::RoutingError::KgraphAnalysisError)?; let filter_eligible =