use std::sync::{Arc, Weak}; use rustc_hash::{FxHashMap, FxHashSet}; use crate::{ builder, dense_map::DenseMap, error::{self, AnalysisTrace, GraphError}, types::{ CheckingContext, CycleCheck, DomainId, DomainIdentifier, DomainInfo, Edge, EdgeId, Memoization, Metadata, Node, NodeId, NodeType, NodeValue, Relation, RelationResolution, Strength, ValueNode, }, }; #[derive(Debug)] struct CheckNodeContext<'a, V: ValueNode, C: CheckingContext> { ctx: &'a C, node: &'a Node, node_id: NodeId, relation: Relation, strength: Strength, memo: &'a mut Memoization, cycle_map: &'a mut CycleCheck, domains: Option<&'a [DomainId]>, } #[derive(Debug)] pub struct ConstraintGraph { pub domain: DenseMap, pub domain_identifier_map: FxHashMap, pub nodes: DenseMap>, pub edges: DenseMap, pub value_map: FxHashMap, NodeId>, pub node_info: DenseMap>, pub node_metadata: DenseMap>>, } impl ConstraintGraph where V: ValueNode, { fn get_predecessor_edges_by_domain( &self, node_id: NodeId, domains: Option<&[DomainId]>, ) -> Result, GraphError> { let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; let mut final_list = Vec::new(); for &pred in &node.preds { let edge = self.edges.get(pred).ok_or(GraphError::EdgeNotFound)?; if let Some((domain_id, domains)) = edge.domain.zip(domains) { if domains.contains(&domain_id) { final_list.push(edge); } } else if edge.domain.is_none() { final_list.push(edge); } } Ok(final_list) } #[allow(clippy::too_many_arguments)] pub fn check_node( &self, ctx: &C, node_id: NodeId, relation: Relation, strength: Strength, memo: &mut Memoization, cycle_map: &mut CycleCheck, domains: Option<&[String]>, ) -> Result<(), GraphError> where C: CheckingContext, { let domains = domains .map(|domain_idents| { domain_idents .iter() .map(|domain_ident| { self.domain_identifier_map .get(&DomainIdentifier::new(domain_ident.to_string())) .copied() .ok_or(GraphError::DomainNotFound) }) .collect::, _>>() }) .transpose()?; self.check_node_inner( ctx, node_id, relation, strength, memo, cycle_map, domains.as_deref(), ) } #[allow(clippy::too_many_arguments)] pub fn check_node_inner( &self, ctx: &C, node_id: NodeId, relation: Relation, strength: Strength, memo: &mut Memoization, cycle_map: &mut CycleCheck, domains: Option<&[DomainId]>, ) -> Result<(), GraphError> where C: CheckingContext, { let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; if let Some(already_memo) = memo.get(&(node_id, relation, strength)) { already_memo .clone() .map_err(|err| GraphError::AnalysisError(Arc::downgrade(&err))) } else if let Some((initial_strength, initial_relation)) = cycle_map.get(&node_id).copied() { let strength_relation = Strength::get_resolved_strength(initial_strength, strength); let relation_resolve = RelationResolution::get_resolved_relation(initial_relation, relation.into()); cycle_map.entry(node_id).and_modify(|value| { value.0 = strength_relation; value.1 = relation_resolve }); Ok(()) } else { let check_node_context = CheckNodeContext { node, node_id, relation, strength, memo, cycle_map, ctx, domains, }; match &node.node_type { NodeType::AllAggregator => self.validate_all_aggregator(check_node_context), NodeType::AnyAggregator => self.validate_any_aggregator(check_node_context), NodeType::InAggregator(expected) => { self.validate_in_aggregator(check_node_context, expected) } NodeType::Value(val) => self.validate_value_node(check_node_context, val), } } } fn validate_all_aggregator( &self, vald: CheckNodeContext<'_, V, C>, ) -> Result<(), GraphError> where C: CheckingContext, { let mut unsatisfied = Vec::>>::new(); for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { vald.cycle_map .insert(vald.node_id, (vald.strength, vald.relation.into())); if let Err(e) = self.check_node_inner( vald.ctx, edge.pred, edge.relation, edge.strength, vald.memo, vald.cycle_map, vald.domains, ) { unsatisfied.push(e.get_analysis_trace()?); } if let Some((_resolved_strength, resolved_relation)) = vald.cycle_map.remove(&vald.node_id) { if resolved_relation == RelationResolution::Contradiction { let err = Arc::new(AnalysisTrace::Contradiction { relation: resolved_relation, }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); return Err(GraphError::AnalysisError(Arc::downgrade(&err))); } } } if !unsatisfied.is_empty() { let err = Arc::new(AnalysisTrace::AllAggregation { unsatisfied, info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err))) } else { vald.memo .insert((vald.node_id, vald.relation, vald.strength), Ok(())); Ok(()) } } fn validate_any_aggregator( &self, vald: CheckNodeContext<'_, V, C>, ) -> Result<(), GraphError> where C: CheckingContext, { let mut unsatisfied = Vec::>>::new(); let mut matched_one = false; for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { vald.cycle_map .insert(vald.node_id, (vald.strength, vald.relation.into())); if let Err(e) = self.check_node_inner( vald.ctx, edge.pred, edge.relation, edge.strength, vald.memo, vald.cycle_map, vald.domains, ) { unsatisfied.push(e.get_analysis_trace()?); } else { matched_one = true; } if let Some((_resolved_strength, resolved_relation)) = vald.cycle_map.remove(&vald.node_id) { if resolved_relation == RelationResolution::Contradiction { let err = Arc::new(AnalysisTrace::Contradiction { relation: resolved_relation, }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); return Err(GraphError::AnalysisError(Arc::downgrade(&err))); } } } if matched_one || vald.node.preds.is_empty() { vald.memo .insert((vald.node_id, vald.relation, vald.strength), Ok(())); Ok(()) } else { let err = Arc::new(AnalysisTrace::AnyAggregation { unsatisfied: unsatisfied.clone(), info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err))) } } fn validate_in_aggregator( &self, vald: CheckNodeContext<'_, V, C>, expected: &FxHashSet, ) -> Result<(), GraphError> where C: CheckingContext, { let the_key = expected .iter() .next() .ok_or_else(|| GraphError::MalformedGraph { reason: "An OnlyIn aggregator node must have at least one expected value" .to_string(), })? .get_key(); let ctx_vals = if let Some(vals) = vald.ctx.get_values_by_key(&the_key) { vals } else { return if let Strength::Weak = vald.strength { vald.memo .insert((vald.node_id, vald.relation, vald.strength), Ok(())); Ok(()) } else { let err = Arc::new(AnalysisTrace::InAggregation { expected: expected.iter().cloned().collect(), found: None, relation: vald.relation, info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err))) }; }; let relation_bool: bool = vald.relation.into(); for ctx_value in ctx_vals { if expected.contains(&ctx_value) != relation_bool { let err = Arc::new(AnalysisTrace::InAggregation { expected: expected.iter().cloned().collect(), found: Some(ctx_value.clone()), relation: vald.relation, info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; } } vald.memo .insert((vald.node_id, vald.relation, vald.strength), Ok(())); Ok(()) } fn validate_value_node( &self, vald: CheckNodeContext<'_, V, C>, val: &NodeValue, ) -> Result<(), GraphError> where C: CheckingContext, { let mut errors = Vec::>>::new(); let mut matched_one = false; self.context_analysis( vald.node_id, vald.relation, vald.strength, vald.ctx, val, vald.memo, )?; for edge in self.get_predecessor_edges_by_domain(vald.node_id, vald.domains)? { vald.cycle_map .insert(vald.node_id, (vald.strength, vald.relation.into())); let result = self.check_node_inner( vald.ctx, edge.pred, edge.relation, edge.strength, vald.memo, vald.cycle_map, vald.domains, ); if let Some((resolved_strength, resolved_relation)) = vald.cycle_map.remove(&vald.node_id) { if resolved_relation == RelationResolution::Contradiction { let err = Arc::new(AnalysisTrace::Contradiction { relation: resolved_relation, }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); return Err(GraphError::AnalysisError(Arc::downgrade(&err))); } else if resolved_strength != vald.strength { self.context_analysis( vald.node_id, vald.relation, resolved_strength, vald.ctx, val, vald.memo, )? } } match (edge.strength, result) { (Strength::Strong, Err(trace)) => { let err = Arc::new(AnalysisTrace::Value { value: val.clone(), relation: vald.relation, info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), predecessors: Some(error::ValueTracePredecessor::Mandatory(Box::new( trace.get_analysis_trace()?, ))), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; } (Strength::Strong, Ok(_)) => { matched_one = true; } (Strength::Normal | Strength::Weak, Err(trace)) => { errors.push(trace.get_analysis_trace()?); } (Strength::Normal | Strength::Weak, Ok(_)) => { matched_one = true; } } } if matched_one || vald.node.preds.is_empty() { vald.memo .insert((vald.node_id, vald.relation, vald.strength), Ok(())); Ok(()) } else { let err = Arc::new(AnalysisTrace::Value { value: val.clone(), relation: vald.relation, info: self.node_info.get(vald.node_id).copied().flatten(), metadata: self.node_metadata.get(vald.node_id).cloned().flatten(), predecessors: Some(error::ValueTracePredecessor::OneOf(errors.clone())), }); vald.memo.insert( (vald.node_id, vald.relation, vald.strength), Err(Arc::clone(&err)), ); Err(GraphError::AnalysisError(Arc::downgrade(&err))) } } fn context_analysis( &self, node_id: NodeId, relation: Relation, strength: Strength, ctx: &C, val: &NodeValue, memo: &mut Memoization, ) -> Result<(), GraphError> where C: CheckingContext, { let in_context = ctx.check_presence(val, strength); let relation_bool: bool = relation.into(); if in_context != relation_bool { let err = Arc::new(AnalysisTrace::Value { value: val.clone(), relation, predecessors: None, info: self.node_info.get(node_id).copied().flatten(), metadata: self.node_metadata.get(node_id).cloned().flatten(), }); memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; } if !relation_bool { memo.insert((node_id, relation, strength), Ok(())); return Ok(()); } Ok(()) } pub fn combine(g1: &Self, g2: &Self) -> Result> { let mut node_builder = builder::ConstraintGraphBuilder::new(); let mut g1_old2new_id = DenseMap::::new(); let mut g2_old2new_id = DenseMap::::new(); let mut g1_old2new_domain_id = DenseMap::::new(); let mut g2_old2new_domain_id = DenseMap::::new(); let add_domain = |node_builder: &mut builder::ConstraintGraphBuilder, domain: DomainInfo| -> Result> { node_builder.make_domain( domain.domain_identifier.into_inner(), &domain.domain_description, ) }; let add_node = |node_builder: &mut builder::ConstraintGraphBuilder, node: &Node| -> Result> { match &node.node_type { NodeType::Value(node_value) => { Ok(node_builder.make_value_node(node_value.clone(), None, None::<()>)) } NodeType::AllAggregator => { Ok(node_builder.make_all_aggregator(&[], None, None::<()>, None)?) } NodeType::AnyAggregator => { Ok(node_builder.make_any_aggregator(&[], None, None::<()>, None)?) } NodeType::InAggregator(expected) => Ok(node_builder.make_in_aggregator( expected.iter().cloned().collect(), None, None::<()>, )?), } }; for (_old_domain_id, domain) in g1.domain.iter() { let new_domain_id = add_domain(&mut node_builder, domain.clone())?; g1_old2new_domain_id.push(new_domain_id); } for (_old_domain_id, domain) in g2.domain.iter() { let new_domain_id = add_domain(&mut node_builder, domain.clone())?; g2_old2new_domain_id.push(new_domain_id); } for (_old_node_id, node) in g1.nodes.iter() { let new_node_id = add_node(&mut node_builder, node)?; g1_old2new_id.push(new_node_id); } for (_old_node_id, node) in g2.nodes.iter() { let new_node_id = add_node(&mut node_builder, node)?; g2_old2new_id.push(new_node_id); } for edge in g1.edges.values() { let new_pred_id = g1_old2new_id .get(edge.pred) .ok_or(GraphError::NodeNotFound)?; let new_succ_id = g1_old2new_id .get(edge.succ) .ok_or(GraphError::NodeNotFound)?; let domain_ident = edge .domain .map(|domain_id| g1.domain.get(domain_id).ok_or(GraphError::DomainNotFound)) .transpose()? .map(|domain| domain.domain_identifier.clone()); node_builder.make_edge( *new_pred_id, *new_succ_id, edge.strength, edge.relation, domain_ident, )?; } for edge in g2.edges.values() { let new_pred_id = g2_old2new_id .get(edge.pred) .ok_or(GraphError::NodeNotFound)?; let new_succ_id = g2_old2new_id .get(edge.succ) .ok_or(GraphError::NodeNotFound)?; let domain_ident = edge .domain .map(|domain_id| g2.domain.get(domain_id).ok_or(GraphError::DomainNotFound)) .transpose()? .map(|domain| domain.domain_identifier.clone()); node_builder.make_edge( *new_pred_id, *new_succ_id, edge.strength, edge.relation, domain_ident, )?; } Ok(node_builder.build()) } } #[cfg(feature = "viz")] mod viz { use graphviz_rust::{ dot_generator::*, dot_structures::*, printer::{DotPrinter, PrinterContext}, }; use crate::{dense_map::EntityId, types, ConstraintGraph, NodeViz, ValueNode}; fn get_node_id(node_id: types::NodeId) -> String { format!("N{}", node_id.get_id()) } impl ConstraintGraph where V: ValueNode + NodeViz, ::Key: NodeViz, { fn get_node_label(node: &types::Node) -> String { let label = match &node.node_type { types::NodeType::Value(types::NodeValue::Key(key)) => format!("any {}", key.viz()), types::NodeType::Value(types::NodeValue::Value(val)) => { format!("{} = {}", val.get_key().viz(), val.viz()) } types::NodeType::AllAggregator => "&&".to_string(), types::NodeType::AnyAggregator => "| |".to_string(), types::NodeType::InAggregator(agg) => { let key = if let Some(val) = agg.iter().next() { val.get_key().viz() } else { return "empty in".to_string(); }; let nodes = agg.iter().map(NodeViz::viz).collect::>(); format!("{key} in [{}]", nodes.join(", ")) } }; format!("\"{label}\"") } fn build_node(cg_node_id: types::NodeId, cg_node: &types::Node) -> Node { let viz_node_id = get_node_id(cg_node_id); let viz_node_label = Self::get_node_label(cg_node); node!(viz_node_id; attr!("label", viz_node_label)) } fn build_edge(cg_edge: &types::Edge) -> Edge { let pred_vertex = get_node_id(cg_edge.pred); let succ_vertex = get_node_id(cg_edge.succ); let arrowhead = match cg_edge.strength { types::Strength::Weak => "onormal", types::Strength::Normal => "normal", types::Strength::Strong => "normalnormal", }; let color = match cg_edge.relation { types::Relation::Positive => "blue", types::Relation::Negative => "red", }; edge!( node_id!(pred_vertex) => node_id!(succ_vertex); attr!("arrowhead", arrowhead), attr!("color", color) ) } pub fn get_viz_digraph(&self) -> Graph { graph!( strict di id!("constraint_graph"), self.nodes .iter() .map(|(node_id, node)| Self::build_node(node_id, node)) .map(Stmt::Node) .chain(self.edges.values().map(Self::build_edge).map(Stmt::Edge)) .collect::>() ) } pub fn get_viz_digraph_string(&self) -> String { let mut ctx = PrinterContext::default(); let digraph = self.get_viz_digraph(); digraph.print(&mut ctx) } } }