1use std::{collections::BTreeSet, fmt::Display, hash::Hash};
2
3use indexmap::{IndexMap, IndexSet};
4
5use crate::ast::{
6 Clause, Constraint, Formula, Key, Op, PersistentStr, Property, RuleTreeNode, Specificity,
7};
8
9#[derive(Debug)]
10pub struct Dag {
11 pub children: IndexMap<PersistentStr, LiteralMatcher>,
12 pub prop_node: Node,
13 pub next_node_id: usize,
14 pub node_data: IndexMap<Node, NodeData>,
15}
16impl Default for Dag {
17 fn default() -> Self {
18 let mut dag = Self {
19 children: IndexMap::default(),
20 prop_node: Node(0),
21 next_node_id: 0, node_data: IndexMap::default(),
23 };
24
25 dag.prop_node = dag.new_node(NodeData::or());
26 dag
27 }
28}
29impl Dag {
30 pub fn stats(&self) -> Stats {
31 let mut stats = Stats::default();
32 let mut visited = Default::default();
33
34 self.prop_node
35 .accumulate_stats(self, &mut stats, &mut visited);
36
37 for (_, matcher) in &self.children {
38 stats.literals += 1;
39 if let Some(wildcard) = matcher.wildcard {
40 wildcard.accumulate_stats(self, &mut stats, &mut visited);
41 }
42 for (_, nodes) in &matcher.positive_values {
43 for node in nodes {
44 node.accumulate_stats(self, &mut stats, &mut visited);
45 }
46 }
47 }
49 stats
50 }
51
52 fn new_node(&mut self, data: NodeData) -> Node {
53 let node = Node(self.next_id());
54 debug_assert!(!self.node_data.contains_key(&node));
55 self.node_data.insert(node, data);
56 node
57 }
58
59 fn next_id(&mut self) -> usize {
60 let id = self.next_node_id;
61 self.next_node_id += 1;
62 id
63 }
64
65 pub fn get_data(&self, node: Node) -> &NodeData {
66 &self.node_data[&node]
67 }
68
69 fn get_data_mut(&mut self, node: Node) -> &mut NodeData {
70 &mut self.node_data[&node]
71 }
72
73 pub fn build(rule_tree_node: RuleTreeNode) -> Self {
74 let mut dag = Self::default();
75 let mut lit_nodes = IndexMap::<Key, Node>::default();
76
77 let mut sorted_formulae: Vec<_> = rule_tree_node.iter().cloned().collect();
80 sorted_formulae.sort_by(|lhs, rhs| lhs.formula.cmp(&rhs.formula));
81
82 let mut all_clauses: Vec<_> = sorted_formulae
83 .iter()
84 .flat_map(|f| f.formula.elements().union(f.formula.shared()))
85 .collect();
86
87 all_clauses.sort();
88
89 let all_elements = all_clauses.iter().flat_map(|c| c.elements()).cloned();
90 for lit in IndexSet::<Key>::from_iter(all_elements) {
92 lit_nodes.insert(lit.clone(), dag.add_literal(&lit));
93 }
94
95 let mut clause_nodes = IndexMap::<Clause, Node>::default();
96 for clause in all_clauses.into_iter() {
97 if !clause.is_empty() {
98 let specificity = clause.specificity();
99 let expr = dag.build_expr(
100 clause.clone(),
101 || NodeData::and(specificity),
102 &mut lit_nodes,
103 &mut clause_nodes,
104 );
105 clause_nodes.insert(clause.clone(), expr);
106 }
107 }
108
109 let mut form_nodes = IndexMap::<Formula, Node>::default();
110 for rule in sorted_formulae {
111 let node = if rule.formula.is_empty() {
112 dag.prop_node
113 } else {
114 let node = dag.build_expr(
115 rule.formula.clone(),
116 NodeData::or,
117 &mut clause_nodes,
118 &mut form_nodes,
119 );
120 form_nodes.insert(rule.formula, node);
121 node
122 };
123 let data = dag.get_data_mut(node);
124 data.props.extend(rule.props.iter().cloned());
125 data.constraints
126 .extend(rule.constraints.iter().cloned().map(Into::into));
127 }
128
129 dag
130 }
131
132 fn add_literal(&mut self, lit: &Key) -> Node {
133 let mut node_data = NodeData::and(lit.specificity);
134 node_data.add_link();
135 let node = self.new_node(node_data);
136 self.children
137 .entry(lit.name.clone())
138 .or_default()
139 .add_values(lit.values.clone(), node);
140 node
141 }
142
143 fn build_expr<E: NodeCreatorCollection>(
144 &mut self,
145 expr: E,
146 constructor: impl Fn() -> NodeData,
147 base_nodes: &mut IndexMap<E::Item, Node>, these_nodes: &mut IndexMap<E, Node>, ) -> Node {
150 assert!(!expr.is_empty());
151
152 if expr.len() == 1 {
153 return base_nodes[&expr.first().unwrap()];
154 } else if let Some(existing) = these_nodes.get(&expr) {
155 return *existing;
156 } else if expr.len() == 2 {
157 let mut node_data = constructor();
158 node_data.add_links(expr.len());
159 let node: Node = self.new_node(node_data);
160
161 for el in expr.elements().iter() {
162 self.get_data_mut(base_nodes[el]).children.push(node);
163 }
164 return node;
165 }
166
167 let mut item_collection_indices = IndexMap::<E::Item, Vec<usize>>::default();
168 for (i, c) in these_nodes.keys().enumerate() {
169 if c.is_subset(&expr) {
170 assert!(
171 c.len() < expr.len(),
172 "Exact equality should be handled above"
173 );
174 for el in c.elements() {
175 item_collection_indices
176 .entry(el.clone())
177 .or_default()
178 .push(i);
179 }
180 }
181 }
182
183 let mut covered = BTreeSet::new();
184 let node: Node = self.new_node(constructor());
185 let mut collections: Vec<_> = these_nodes.keys().cloned().collect();
186 collections.sort_by(|l, r| r.cmp(l)); let biggest_collection = |covered_elements: &BTreeSet<E::Item>| {
192 let collection_rank = |collection: &E| {
193 collection
194 .elements()
195 .iter()
196 .filter(|e| !covered_elements.contains(*e))
197 .count()
198 };
199 collections
200 .iter()
201 .filter(|collection| collection.is_subset(&expr))
202 .map(|collection| (collection_rank(collection), collection))
203 .filter(|(rank, _)| *rank != 0)
204 .max_by_key(|(rank, _)| *rank) .map(|(_, collection)| collection)
206 };
207
208 while let Some(best) = biggest_collection(&covered) {
209 self.get_data_mut(these_nodes[best]).children.push(node);
210 self.get_data_mut(node).add_link();
211 for el in best.elements() {
212 if !covered.contains(el) {
213 covered.insert(el.clone());
214 }
215 }
216 }
217
218 let remaining = expr.elements() - &covered;
219 self.get_data_mut(node).add_links(remaining.len());
220 for el in remaining {
221 self.get_data_mut(base_nodes[&el]).children.push(node);
222 }
223
224 node
225 }
226}
227
228trait NodeCreator: Hash + Eq + Ord + Clone + std::fmt::Debug + std::fmt::Display {}
229impl NodeCreator for Key {}
230impl NodeCreator for Clause {}
231impl NodeCreator for Formula {}
232
233trait NodeCreatorCollection: NodeCreator {
234 type Item: NodeCreator;
235
236 fn len(&self) -> usize;
237 fn is_empty(&self) -> bool {
238 self.len() == 0
239 }
240 fn elements(&self) -> &BTreeSet<Self::Item>;
241 fn first(&self) -> Option<Self::Item> {
242 self.elements().iter().next().cloned()
243 }
244 fn is_subset(&self, other: &Self) -> bool;
245}
246impl NodeCreatorCollection for Clause {
247 type Item = Key;
248 fn len(&self) -> usize {
249 self.len()
250 }
251 fn elements(&self) -> &BTreeSet<Self::Item> {
252 self.elements()
253 }
254 fn is_subset(&self, other: &Self) -> bool {
255 self.is_subset(other)
256 }
257}
258impl NodeCreatorCollection for Formula {
259 type Item = Clause;
260 fn len(&self) -> usize {
261 self.len()
262 }
263 fn elements(&self) -> &BTreeSet<Self::Item> {
264 self.elements()
265 }
266 fn is_subset(&self, other: &Self) -> bool {
267 self.is_subset(other)
268 }
269}
270
/// Writes `name: value` for a single `Stats` field.
macro_rules! write_stat {
    ($f:expr, $self:ident. $stat:ident) => {
        write!($f, "{}: {}", stringify!($stat), $self.$stat)
    };
}

/// Summary counters produced by `Dag::stats`.
#[derive(Default, Debug)]
pub struct Stats {
    /// Number of distinct literal names.
    literals: usize,
    /// Number of distinct nodes reached.
    nodes: usize,
    /// Total properties attached to the reached nodes.
    props: usize,
    /// Total child links of the reached nodes.
    edges: usize,
    /// Largest tally count of a reached AND node.
    tally_max: usize,
    /// Largest number of children of a single reached node.
    fanout_max: usize,
    /// Sum of tally counts over reached AND nodes.
    tally_total: usize,
    /// Sum of child counts over reached nodes.
    fanout_total: usize,
    /// Number of reached nodes with at least one child.
    nodes_with_fanout: usize,
}

impl Display for Stats {
    /// Renders every counter as `name: value`, comma-separated on one line,
    /// followed by the per-node averages with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write_stat!(f, self.literals)?;
        f.write_str(", ")?;
        write_stat!(f, self.nodes)?;
        f.write_str(", ")?;
        write_stat!(f, self.props)?;
        f.write_str(", ")?;
        write_stat!(f, self.edges)?;
        f.write_str(", ")?;
        write_stat!(f, self.tally_max)?;
        f.write_str(", ")?;
        write_stat!(f, self.fanout_max)?;
        f.write_str(", ")?;
        write_stat!(f, self.tally_total)?;
        f.write_str(", ")?;
        write_stat!(f, self.fanout_total)?;
        f.write_str(", ")?;
        write_stat!(f, self.nodes_with_fanout)?;

        // Averages are taken over all nodes; with no nodes report 0 instead of NaN.
        // NOTE(review): `tally_total` only sums AND nodes but is divided by the
        // count of all nodes — confirm that is the intended metric.
        let average = |total: usize| {
            if self.nodes == 0 {
                0.0
            } else {
                total as f64 / self.nodes as f64
            }
        };
        let tally_avg = average(self.tally_total);
        let fanout_avg = average(self.fanout_total);
        write!(f, ", tally_avg: {tally_avg:.2}")?;
        write!(f, ", fanout_avg: {fanout_avg:.2}")?;

        Ok(())
    }
}
309
310#[derive(Debug, Default)]
311pub struct LiteralMatcher {
312 pub wildcard: Option<Node>,
313 pub positive_values: IndexMap<PersistentStr, Vec<Node>>,
314 pub negative_values: Option<()>, }
316impl LiteralMatcher {
317 fn add_values(&mut self, values: Vec<PersistentStr>, node: Node) {
318 if values.is_empty() {
329 assert!(self.wildcard.is_none());
330 self.wildcard = Some(node);
331 }
332
333 for value in values {
334 self.positive_values.entry(value).or_default().push(node);
335 }
336 }
337}
338
339#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
340pub struct Node(usize);
341impl Node {
342 fn accumulate_stats(self, dag: &Dag, stats: &mut Stats, visited: &mut IndexSet<Node>) {
343 if !visited.contains(&self) {
344 visited.insert(self);
345 dag.get_data(self).accumulate_stats(dag, stats, visited);
346 }
347 }
348}
349
350#[derive(Debug, Clone)]
351pub enum NodeType {
352 Or,
353 And(Specificity),
354}
355impl From<&NodeType> for Op {
356 fn from(value: &NodeType) -> Self {
357 match value {
358 NodeType::Or => Op::Or,
359 NodeType::And(..) => Op::And,
360 }
361 }
362}
363
364#[derive(Debug, Clone)]
365pub struct NodeData {
366 pub children: Vec<Node>,
367 pub props: Vec<Property>,
368 pub constraints: Vec<Constraint>,
369 pub tally_count: usize,
371 pub op: NodeType,
372}
373impl NodeData {
374 fn and(specificity: Specificity) -> Self {
375 Self::with_op(NodeType::And(specificity))
376 }
377
378 fn or() -> Self {
379 Self::with_op(NodeType::Or)
380 }
381
382 fn with_op(op: NodeType) -> Self {
383 Self {
384 children: Default::default(),
385 props: Default::default(),
386 constraints: Default::default(),
387 tally_count: 0,
388 op,
389 }
390 }
391
392 fn add_link(&mut self) {
393 self.add_links(1)
394 }
395
396 fn add_links(&mut self, num: usize) {
397 self.tally_count += num
398 }
399
400 fn accumulate_subclass_stats(&self, stats: &mut Stats) {
401 if matches!(self.op, NodeType::And(..)) {
402 stats.tally_max = stats.tally_max.max(self.tally_count);
403 stats.tally_total += self.tally_count;
404 }
405 }
406
407 fn accumulate_stats(&self, dag: &Dag, stats: &mut Stats, visited: &mut IndexSet<Node>) {
408 stats.nodes += 1;
409 stats.props += self.props.len();
410 stats.edges += self.children.len();
411 stats.fanout_max = stats.fanout_max.max(self.children.len());
412 stats.fanout_total += self.children.len();
413 self.accumulate_subclass_stats(stats);
414 if !self.children.is_empty() {
415 stats.nodes_with_fanout += 1;
416 }
417 for node in &self.children {
418 node.accumulate_stats(dag, stats, visited);
419 }
420 }
421}
422
#[cfg(feature = "dot")]
pub mod dot {
    use std::{fmt::Display, ops::AddAssign};

    use crate::{ast::JoinedBy, search};

    use super::*;
    use petgraph::{
        dot::{Config, Dot},
        graph::NodeIndex,
    };

    pub type DiGraph = petgraph::graph::DiGraph<StyledNode, ()>;

    /// A graph node plus the Graphviz attributes used to draw it.
    // `id` is not read within this module.
    #[allow(dead_code)]
    pub struct StyledNode {
        id: Node,
        label: String,
        fillcolor: String,
        style: String,
        shape: String,
    }
    impl StyledNode {
        /// Node with only a label and default Graphviz styling.
        pub fn new(id: Node, name: impl ToString) -> Self {
            Self::styled(id, name, "", "", "")
        }

        /// Node with a label and explicit `fillcolor`, `style` and `shape`
        /// attributes; empty strings leave an attribute unset.
        pub fn styled(
            id: Node,
            name: impl ToString,
            fillcolor: impl ToString,
            style: impl ToString,
            shape: impl ToString,
        ) -> Self {
            Self {
                id,
                label: name.to_string(),
                fillcolor: fillcolor.to_string(),
                style: style.to_string(),
                shape: shape.to_string(),
            }
        }

        /// The node's DOT attribute list, e.g. `label="x" style="filled"`.
        pub fn to_dot(&self) -> String {
            [
                Self::attribute("label", &self.label),
                Self::attribute("fillcolor", &self.fillcolor),
                Self::attribute("style", &self.style),
                Self::attribute("shape", &self.shape),
            ]
            .into_iter()
            .joined_by(" ")
        }

        /// `name="value"`, or an empty string when `value` is empty.
        fn attribute(name: &str, value: &str) -> String {
            if value.is_empty() {
                String::new()
            } else {
                // A raw `"` would end the DOT string early (labels can contain
                // user-provided property text), so escape it.
                let escaped = value.replace('"', "\\\"");
                format!("{name}=\"{escaped}\"")
            }
        }
    }
    impl Display for StyledNode {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.label)
        }
    }
    impl std::fmt::Debug for StyledNode {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Display::fmt(self, f)
        }
    }

    impl From<&RuleTreeNode> for DiGraph {
        /// Draws the rule tree, one filled box per tree node.
        fn from(value: &RuleTreeNode) -> Self {
            let mut g = Self::new();

            let mut uid = 0;
            fn add_node(g: &mut DiGraph, n: &RuleTreeNode, uid: &mut usize) -> NodeIndex {
                let nodeid = g.add_node(StyledNode::styled(
                    Node(*uid),
                    n.label(),
                    n.color(),
                    "filled",
                    "box",
                ));
                uid.add_assign(1);
                for c in &n.children {
                    let child = add_node(g, c, uid);
                    g.add_edge(nodeid, child, ());
                }

                nodeid
            }
            add_node(&mut g, value, &mut uid);

            g
        }
    }

    /// Drawing for DAG node `n`.
    ///
    /// * OR nodes are labelled `V`: pale green if `t` has an entry for them,
    ///   light blue otherwise.
    /// * AND nodes are labelled with their tally count. They are pale green
    ///   when `t` holds 0 for them, and pink when `t` has no entry.
    /// * An AND node with any other count in `t` is misty rose, labelled
    ///   `tally - count / tally`.
    /// * Nodes with properties get the properties appended and a bold outline.
    fn styled_dag_node(n: Node, n_data: &NodeData, t: &search::Tallies) -> StyledNode {
        let (mut label, color) = match n_data.op {
            NodeType::Or => {
                let color = if t.contains_key(&n) {
                    "palegreen"
                } else {
                    "lightblue"
                };
                ("V".to_string(), color)
            }
            NodeType::And(..) => {
                let count = *t.get(&n).unwrap_or(&n_data.tally_count);
                let mut label = format!("{}", n_data.tally_count);
                let color = if count == 0 {
                    "palegreen"
                } else if count != n_data.tally_count {
                    label = format!("{} / {}", n_data.tally_count - count, label);
                    "mistyrose"
                } else {
                    "pink2"
                };
                (label, color)
            }
        };
        let mut style = "filled".to_string();
        if !n_data.props.is_empty() {
            label += &format!(" [{}]", n_data.props.iter().joined_by(","));
            style += ", bold";
        }
        StyledNode::styled(n, label, color, style, "")
    }

    /// Draws `dag`, coloured according to `tallies`.
    ///
    /// Each literal name becomes a root box. Each listed value becomes a
    /// yellow box below its name. DAG nodes hang below the name (wildcard) or
    /// value boxes. Every DAG node is drawn once, however many paths reach it.
    pub fn dag_to_digraph(dag: &Dag, tallies: &search::Tallies) -> DiGraph {
        let mut g = DiGraph::new();

        let mut node_mapping = IndexMap::default();

        /// Links `nodes` below graph node `p`, drawing each node (and its
        /// descendants) the first time it is seen.
        fn add_nodes(
            dag: &Dag,
            g: &mut DiGraph,
            node_mapping: &mut IndexMap<Node, NodeIndex>,
            t: &search::Tallies,
            p: NodeIndex,
            nodes: &[Node],
        ) {
            for &n in nodes {
                if let Some(&existing) = node_mapping.get(&n) {
                    // The node and its whole subgraph were drawn on the first
                    // visit (a DAG from `Dag::build` only links a node to nodes
                    // created after it, so that visit has completed). Only the
                    // edge from `p` can be new. Not re-walking the subgraph
                    // avoids re-traversing shared subgraphs once per path.
                    if !g.contains_edge(p, existing) {
                        g.add_edge(p, existing, ());
                    }
                    continue;
                }

                let n_data = dag.get_data(n);
                let node_id = g.add_node(styled_dag_node(n, n_data, t));
                node_mapping.insert(n, node_id);
                // `node_id` was just created, so this edge cannot exist yet.
                g.add_edge(p, node_id, ());
                add_nodes(dag, g, node_mapping, t, node_id, &n_data.children);
            }
        }

        // Literal name/value boxes have no DAG node. Their ids start past every
        // id the DAG has allocated, so they can never alias a real node. They
        // are deliberately kept out of `node_mapping`, which only tracks DAG
        // nodes.
        let mut next_synthetic_id = dag.next_node_id;
        let mut synthetic_node = || {
            let node = Node(next_synthetic_id);
            next_synthetic_id += 1;
            node
        };

        for (name, matcher) in &dag.children {
            let name_id = g.add_node(StyledNode::new(synthetic_node(), name));
            if let Some(wildcard) = matcher.wildcard {
                add_nodes(
                    dag,
                    &mut g,
                    &mut node_mapping,
                    tallies,
                    name_id,
                    &[wildcard],
                );
            }
            for (value, nodes) in &matcher.positive_values {
                let value_id = g.add_node(StyledNode::styled(
                    synthetic_node(),
                    value,
                    "lightyellow",
                    "filled",
                    "box",
                ));
                g.add_edge(name_id, value_id, ());
                add_nodes(dag, &mut g, &mut node_mapping, tallies, value_id, nodes);
            }
        }

        g
    }

    impl From<&Dag> for DiGraph {
        /// Draws the DAG with no tallies.
        fn from(value: &Dag) -> Self {
            dag_to_digraph(value, &Default::default())
        }
    }

    /// DOT renderer for `graph` using each node's own attributes and no edge labels.
    pub fn to_dot<'a>(graph: &'a DiGraph) -> Dot<'a, &'a DiGraph> {
        Dot::with_attr_getters(
            graph,
            &[Config::EdgeNoLabel, Config::NodeNoLabel],
            &|_, _| "".to_string(),
            &|_, (_, style)| style.to_dot(),
        )
    }

    /// Converts `graph` to a DiGraph and renders it as DOT text.
    pub fn to_dot_str<G: Into<DiGraph>>(graph: G) -> String {
        format!("{:?}", to_dot(&graph.into()))
    }
}
637
#[cfg(test)]
#[cfg(feature = "dot")]
mod dot_tests {
    use crate::{
        ast::{NullResolver, RuleTreeNode},
        dag::{dot::to_dot_str, Dag},
    };

    const MULTILEVEL_EXAMPLE: &str = r#"
 a, f b e, c {
 c d {
 x = y
 }
 e f {
 foobar = abc
 }
 }
 a, c, b e f : baz = quux
 "#;

    /// Parses `source`, then prints the rule tree and the DAG built from it as DOT.
    fn print_tree_and_dag(source: &'static str) {
        let parsed = crate::ast::parse(source, NullResolver()).unwrap();

        let mut tree = RuleTreeNode::default();
        parsed.add_to(&mut tree);
        println!("{}", to_dot_str(&tree));

        let built = Dag::build(tree);
        println!("{}", to_dot_str(&built));
    }

    #[test]
    fn tree_to_dot_1() {
        print_tree_and_dag(MULTILEVEL_EXAMPLE);
    }

    #[test]
    fn tree_to_dot_2() {
        print_tree_and_dag(
            r#"
 a b c d e (f, g, h, i, j) {
 x = y
 }
 "#,
        );
    }
}
689}