ccs2/
dag.rs

1use std::{collections::BTreeSet, fmt::Display, hash::Hash};
2
3use indexmap::{IndexMap, IndexSet};
4
5use crate::ast::{
6    Clause, Constraint, Formula, Key, Op, PersistentStr, Property, RuleTreeNode, Specificity,
7};
8
9#[derive(Debug)]
10pub struct Dag {
11    pub children: IndexMap<PersistentStr, LiteralMatcher>,
12    pub prop_node: Node,
13    pub next_node_id: usize,
14    pub node_data: IndexMap<Node, NodeData>,
15}
16impl Default for Dag {
17    fn default() -> Self {
18        let mut dag = Self {
19            children: IndexMap::default(),
20            prop_node: Node(0),
21            next_node_id: 0, // Temporary
22            node_data: IndexMap::default(),
23        };
24
25        dag.prop_node = dag.new_node(NodeData::or());
26        dag
27    }
28}
29impl Dag {
30    pub fn stats(&self) -> Stats {
31        let mut stats = Stats::default();
32        let mut visited = Default::default();
33
34        self.prop_node
35            .accumulate_stats(self, &mut stats, &mut visited);
36
37        for (_, matcher) in &self.children {
38            stats.literals += 1;
39            if let Some(wildcard) = matcher.wildcard {
40                wildcard.accumulate_stats(self, &mut stats, &mut visited);
41            }
42            for (_, nodes) in &matcher.positive_values {
43                for node in nodes {
44                    node.accumulate_stats(self, &mut stats, &mut visited);
45                }
46            }
47            // TODO: handle negatives as well
48        }
49        stats
50    }
51
52    fn new_node(&mut self, data: NodeData) -> Node {
53        let node = Node(self.next_id());
54        debug_assert!(!self.node_data.contains_key(&node));
55        self.node_data.insert(node, data);
56        node
57    }
58
59    fn next_id(&mut self) -> usize {
60        let id = self.next_node_id;
61        self.next_node_id += 1;
62        id
63    }
64
65    pub fn get_data(&self, node: Node) -> &NodeData {
66        &self.node_data[&node]
67    }
68
69    fn get_data_mut(&mut self, node: Node) -> &mut NodeData {
70        &mut self.node_data[&node]
71    }
72
73    pub fn build(rule_tree_node: RuleTreeNode) -> Self {
74        let mut dag = Self::default();
75        let mut lit_nodes = IndexMap::<Key, Node>::default();
76
77        // obviously there are better ways of gathering the unique literals and unique clauses,
78        // if performance needs to be improved...
79        let mut sorted_formulae: Vec<_> = rule_tree_node.iter().cloned().collect();
80        sorted_formulae.sort_by(|lhs, rhs| lhs.formula.cmp(&rhs.formula));
81
82        let mut all_clauses: Vec<_> = sorted_formulae
83            .iter()
84            .flat_map(|f| f.formula.elements().union(f.formula.shared()))
85            .collect();
86
87        all_clauses.sort();
88
89        let all_elements = all_clauses.iter().flat_map(|c| c.elements()).cloned();
90        // This dedup is very important
91        for lit in IndexSet::<Key>::from_iter(all_elements) {
92            lit_nodes.insert(lit.clone(), dag.add_literal(&lit));
93        }
94
95        let mut clause_nodes = IndexMap::<Clause, Node>::default();
96        for clause in all_clauses.into_iter() {
97            if !clause.is_empty() {
98                let specificity = clause.specificity();
99                let expr = dag.build_expr(
100                    clause.clone(),
101                    || NodeData::and(specificity),
102                    &mut lit_nodes,
103                    &mut clause_nodes,
104                );
105                clause_nodes.insert(clause.clone(), expr);
106            }
107        }
108
109        let mut form_nodes = IndexMap::<Formula, Node>::default();
110        for rule in sorted_formulae {
111            let node = if rule.formula.is_empty() {
112                dag.prop_node
113            } else {
114                let node = dag.build_expr(
115                    rule.formula.clone(),
116                    NodeData::or,
117                    &mut clause_nodes,
118                    &mut form_nodes,
119                );
120                form_nodes.insert(rule.formula, node);
121                node
122            };
123            let data = dag.get_data_mut(node);
124            data.props.extend(rule.props.iter().cloned());
125            data.constraints
126                .extend(rule.constraints.iter().cloned().map(Into::into));
127        }
128
129        dag
130    }
131
132    fn add_literal(&mut self, lit: &Key) -> Node {
133        let mut node_data = NodeData::and(lit.specificity);
134        node_data.add_link();
135        let node = self.new_node(node_data);
136        self.children
137            .entry(lit.name.clone())
138            .or_default()
139            .add_values(lit.values.clone(), node);
140        node
141    }
142
143    fn build_expr<E: NodeCreatorCollection>(
144        &mut self,
145        expr: E,
146        constructor: impl Fn() -> NodeData,
147        base_nodes: &mut IndexMap<E::Item, Node>, // Existing graph state
148        these_nodes: &mut IndexMap<E, Node>,      // Accumulator for current "layer"
149    ) -> Node {
150        assert!(!expr.is_empty());
151
152        if expr.len() == 1 {
153            return base_nodes[&expr.first().unwrap()];
154        } else if let Some(existing) = these_nodes.get(&expr) {
155            return *existing;
156        } else if expr.len() == 2 {
157            let mut node_data = constructor();
158            node_data.add_links(expr.len());
159            let node: Node = self.new_node(node_data);
160
161            for el in expr.elements().iter() {
162                self.get_data_mut(base_nodes[el]).children.push(node);
163            }
164            return node;
165        }
166
167        let mut item_collection_indices = IndexMap::<E::Item, Vec<usize>>::default();
168        for (i, c) in these_nodes.keys().enumerate() {
169            if c.is_subset(&expr) {
170                assert!(
171                    c.len() < expr.len(),
172                    "Exact equality should be handled above"
173                );
174                for el in c.elements() {
175                    item_collection_indices
176                        .entry(el.clone())
177                        .or_default()
178                        .push(i);
179                }
180            }
181        }
182
183        let mut covered = BTreeSet::new();
184        let node: Node = self.new_node(constructor());
185        let mut collections: Vec<_> = these_nodes.keys().cloned().collect();
186        collections.sort_by(|l, r| r.cmp(l)); // TODO: Confusing and maybe inefficient?
187
188        // TODO: this constant re-ranking is not very nice, but it should be roughly the same
189        // algorithmic complexity as the Python implementation, which re-heapifies the ranking list
190        // on every iteration. Still, this should be rethought
191        let biggest_collection = |covered_elements: &BTreeSet<E::Item>| {
192            let collection_rank = |collection: &E| {
193                collection
194                    .elements()
195                    .iter()
196                    .filter(|e| !covered_elements.contains(*e))
197                    .count()
198            };
199            collections
200                .iter()
201                .filter(|collection| collection.is_subset(&expr))
202                .map(|collection| (collection_rank(collection), collection))
203                .filter(|(rank, _)| *rank != 0)
204                .max_by_key(|(rank, _)| *rank) // TODO: Tie?
205                .map(|(_, collection)| collection)
206        };
207
208        while let Some(best) = biggest_collection(&covered) {
209            self.get_data_mut(these_nodes[best]).children.push(node);
210            self.get_data_mut(node).add_link();
211            for el in best.elements() {
212                if !covered.contains(el) {
213                    covered.insert(el.clone());
214                }
215            }
216        }
217
218        let remaining = expr.elements() - &covered;
219        self.get_data_mut(node).add_links(remaining.len());
220        for el in remaining {
221            self.get_data_mut(base_nodes[&el]).children.push(node);
222        }
223
224        node
225    }
226}
227
228trait NodeCreator: Hash + Eq + Ord + Clone + std::fmt::Debug + std::fmt::Display {}
229impl NodeCreator for Key {}
230impl NodeCreator for Clause {}
231impl NodeCreator for Formula {}
232
233trait NodeCreatorCollection: NodeCreator {
234    type Item: NodeCreator;
235
236    fn len(&self) -> usize;
237    fn is_empty(&self) -> bool {
238        self.len() == 0
239    }
240    fn elements(&self) -> &BTreeSet<Self::Item>;
241    fn first(&self) -> Option<Self::Item> {
242        self.elements().iter().next().cloned()
243    }
244    fn is_subset(&self, other: &Self) -> bool;
245}
246impl NodeCreatorCollection for Clause {
247    type Item = Key;
248    fn len(&self) -> usize {
249        self.len()
250    }
251    fn elements(&self) -> &BTreeSet<Self::Item> {
252        self.elements()
253    }
254    fn is_subset(&self, other: &Self) -> bool {
255        self.is_subset(other)
256    }
257}
258impl NodeCreatorCollection for Formula {
259    type Item = Clause;
260    fn len(&self) -> usize {
261        self.len()
262    }
263    fn elements(&self) -> &BTreeSet<Self::Item> {
264        self.elements()
265    }
266    fn is_subset(&self, other: &Self) -> bool {
267        self.is_subset(other)
268    }
269}
270
/// Writes one `name: value` line (newline-terminated) for a field of a stats value.
macro_rules! write_stat {
    ($f:expr, $self:ident. $stat:ident) => {
        writeln!($f, "{}: {}", stringify!($stat), $self.$stat)
    };
}

/// Size statistics accumulated while walking a graph.
#[derive(Default, Debug)]
pub struct Stats {
    /// Number of literal matchers.
    literals: usize,
    /// Number of distinct nodes visited.
    nodes: usize,
    /// Total number of properties over all visited nodes.
    props: usize,
    /// Total number of outgoing edges over all visited nodes.
    edges: usize,
    /// Largest tally count seen on an `And` node.
    tally_max: usize,
    /// Largest number of children seen on a single node.
    fanout_max: usize,
    /// Sum of tally counts over `And` nodes.
    tally_total: usize,
    /// Sum of child counts over all visited nodes.
    fanout_total: usize,
    /// Number of visited nodes that have at least one child.
    nodes_with_fanout: usize,
}

impl Display for Stats {
    /// Prints one `name: value` entry per line, followed by the two per-node averages.
    /// The output has no trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write_stat!(f, self.literals)?;
        write_stat!(f, self.nodes)?;
        write_stat!(f, self.props)?;
        write_stat!(f, self.edges)?;
        write_stat!(f, self.tally_max)?;
        write_stat!(f, self.fanout_max)?;
        write_stat!(f, self.tally_total)?;
        write_stat!(f, self.fanout_total)?;
        write_stat!(f, self.nodes_with_fanout)?;

        // Averages are taken over all visited nodes; report 0 instead of NaN when none were.
        let per_node = |total: usize| {
            if self.nodes == 0 {
                0.0
            } else {
                total as f64 / self.nodes as f64
            }
        };
        let tally_avg = per_node(self.tally_total);
        let fanout_avg = per_node(self.fanout_total);
        writeln!(f, "tally_avg: {tally_avg:.2}")?;
        write!(f, "fanout_avg: {fanout_avg:.2}")
    }
}
309
310#[derive(Debug, Default)]
311pub struct LiteralMatcher {
312    pub wildcard: Option<Node>,
313    pub positive_values: IndexMap<PersistentStr, Vec<Node>>,
314    pub negative_values: Option<()>, // TODO: support this
315}
316impl LiteralMatcher {
317    fn add_values(&mut self, values: Vec<PersistentStr>, node: Node) {
318        // because we find the set of unique literals prior to creating these matchers, we
319        // don't currently need to worry about the added node representing being redundant.
320        // each node will definitely represent a unique set of values for this name. in the
321        // event that the node doesn't end up with any local property settings, building a
322        // separate node for every combination is overkill. it might be nice to detect this
323        // case and elide the subset node, replacing it with individual nodes for each member.
324        // on the other hand, whether this is an improvement depends on whether or not those
325        // individual nodes will actually end up existing either way, or alternatively on the
326        // number of different sets those values appear in. this isn't a tradeoff with an
327        // easy obvious answer.
328        if values.is_empty() {
329            assert!(self.wildcard.is_none());
330            self.wildcard = Some(node);
331        }
332
333        for value in values {
334            self.positive_values.entry(value).or_default().push(node);
335        }
336    }
337}
338
339#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
340pub struct Node(usize);
341impl Node {
342    fn accumulate_stats(self, dag: &Dag, stats: &mut Stats, visited: &mut IndexSet<Node>) {
343        if !visited.contains(&self) {
344            visited.insert(self);
345            dag.get_data(self).accumulate_stats(dag, stats, visited);
346        }
347    }
348}
349
350#[derive(Debug, Clone)]
351pub enum NodeType {
352    Or,
353    And(Specificity),
354}
355impl From<&NodeType> for Op {
356    fn from(value: &NodeType) -> Self {
357        match value {
358            NodeType::Or => Op::Or,
359            NodeType::And(..) => Op::And,
360        }
361    }
362}
363
364#[derive(Debug, Clone)]
365pub struct NodeData {
366    pub children: Vec<Node>,
367    pub props: Vec<Property>,
368    pub constraints: Vec<Constraint>,
369    /// Used for poisoning in case of OrNode
370    pub tally_count: usize,
371    pub op: NodeType,
372}
373impl NodeData {
374    fn and(specificity: Specificity) -> Self {
375        Self::with_op(NodeType::And(specificity))
376    }
377
378    fn or() -> Self {
379        Self::with_op(NodeType::Or)
380    }
381
382    fn with_op(op: NodeType) -> Self {
383        Self {
384            children: Default::default(),
385            props: Default::default(),
386            constraints: Default::default(),
387            tally_count: 0,
388            op,
389        }
390    }
391
392    fn add_link(&mut self) {
393        self.add_links(1)
394    }
395
396    fn add_links(&mut self, num: usize) {
397        self.tally_count += num
398    }
399
400    fn accumulate_subclass_stats(&self, stats: &mut Stats) {
401        if matches!(self.op, NodeType::And(..)) {
402            stats.tally_max = stats.tally_max.max(self.tally_count);
403            stats.tally_total += self.tally_count;
404        }
405    }
406
407    fn accumulate_stats(&self, dag: &Dag, stats: &mut Stats, visited: &mut IndexSet<Node>) {
408        stats.nodes += 1;
409        stats.props += self.props.len();
410        stats.edges += self.children.len();
411        stats.fanout_max = stats.fanout_max.max(self.children.len());
412        stats.fanout_total += self.children.len();
413        self.accumulate_subclass_stats(stats);
414        if !self.children.is_empty() {
415            stats.nodes_with_fanout += 1;
416        }
417        for node in &self.children {
418            node.accumulate_stats(dag, stats, visited);
419        }
420    }
421}
422
423#[cfg(feature = "dot")]
424pub mod dot {
425    use std::{fmt::Display, ops::AddAssign};
426
427    use crate::{ast::JoinedBy, search};
428
429    use super::*;
430    use petgraph::{
431        dot::{Config, Dot},
432        graph::NodeIndex,
433    };
434
435    pub type DiGraph = petgraph::graph::DiGraph<StyledNode, ()>;
436
437    #[allow(dead_code)]
438    pub struct StyledNode {
439        id: Node, // Unique identity in the dag
440        label: String,
441        fillcolor: String,
442        style: String,
443        shape: String,
444    }
445    impl StyledNode {
446        pub fn new(id: Node, name: impl ToString) -> Self {
447            Self::styled(id, name, "", "", "")
448        }
449
450        pub fn styled(
451            id: Node,
452            name: impl ToString,
453            fillcolor: impl ToString,
454            style: impl ToString,
455            shape: impl ToString,
456        ) -> Self {
457            Self {
458                id,
459                label: name.to_string(),
460                fillcolor: fillcolor.to_string(),
461                style: style.to_string(),
462                shape: shape.to_string(),
463            }
464        }
465
466        pub fn to_dot(&self) -> String {
467            [
468                Self::attribute("label", &self.label),
469                Self::attribute("fillcolor", &self.fillcolor),
470                Self::attribute("style", &self.style),
471                Self::attribute("shape", &self.shape),
472            ]
473            .into_iter()
474            .joined_by(" ")
475        }
476
477        fn attribute(name: &str, value: &str) -> String {
478            if !value.is_empty() {
479                format!("{name}=\"{value}\"")
480            } else {
481                "".to_string()
482            }
483        }
484    }
485    impl Display for StyledNode {
486        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
487            write!(f, "{}", self.label)
488        }
489    }
490    impl std::fmt::Debug for StyledNode {
491        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492            Display::fmt(self, f)
493        }
494    }
495
496    impl From<&RuleTreeNode> for DiGraph {
497        fn from(value: &RuleTreeNode) -> Self {
498            let mut g = Self::new();
499
500            let mut uid = 0;
501            fn add_node(g: &mut DiGraph, n: &RuleTreeNode, uid: &mut usize) -> NodeIndex {
502                let nodeid = g.add_node(StyledNode::styled(
503                    Node(*uid),
504                    n.label(),
505                    n.color(),
506                    "filled",
507                    "box",
508                ));
509                uid.add_assign(1);
510                for c in &n.children {
511                    let child = add_node(g, c, uid);
512                    g.add_edge(nodeid, child, ());
513                }
514
515                nodeid
516            }
517            add_node(&mut g, value, &mut uid);
518
519            g
520        }
521    }
522
523    pub fn dag_to_digraph(dag: &Dag, tallies: &search::Tallies) -> DiGraph {
524        let mut g = DiGraph::new();
525
526        let mut node_mapping = IndexMap::default();
527        fn add_nodes(
528            dag: &Dag,
529            g: &mut DiGraph,
530            node_mapping: &mut IndexMap<Node, NodeIndex>,
531            t: &search::Tallies,
532            p: NodeIndex,
533            nodes: &[Node],
534        ) {
535            // TODO: active_only?
536            for n in nodes {
537                let n_data = dag.get_data(*n);
538                let (mut label, color) = if matches!(n_data.op, NodeType::Or) {
539                    let label = "V".to_string();
540                    if t.contains_key(n) {
541                        (label, "palegreen")
542                    } else {
543                        (label, "lightblue")
544                    }
545                } else {
546                    let count = *t.get(n).unwrap_or(&n_data.tally_count);
547                    let mut label = format!("{}", n_data.tally_count);
548                    let color = if count == 0 {
549                        "palegreen"
550                    } else if count != n_data.tally_count {
551                        label = format!("{} / {}", n_data.tally_count - count, label);
552                        "mistyrose"
553                    } else {
554                        "pink2"
555                    };
556                    (label, color)
557                };
558                let mut style = "filled".to_string();
559                if !n_data.props.is_empty() {
560                    label += &format!(" [{}]", n_data.props.iter().joined_by(","));
561                    style += ", bold";
562                }
563                let node_id = if let Some(existing) = node_mapping.get(n) {
564                    *existing
565                } else {
566                    let node_id = g.add_node(StyledNode::styled(*n, label, color, style, ""));
567                    node_mapping.insert(*n, node_id);
568                    node_id
569                };
570
571                if !g.contains_edge(p, node_id) {
572                    g.add_edge(p, node_id, ());
573                }
574                add_nodes(dag, g, node_mapping, t, node_id, &n_data.children);
575            }
576        }
577
578        // These aren't real nodes in the Dag, but we want to draw them anyway
579        let mut lit_id = 1000000;
580
581        for (l, matcher) in &dag.children {
582            let lit_node = Node(lit_id);
583            lit_id += 1;
584            let node_id = g.add_node(StyledNode::new(lit_node, l));
585            node_mapping.insert(lit_node, node_id);
586            if let Some(wildcard) = matcher.wildcard {
587                add_nodes(
588                    dag,
589                    &mut g,
590                    &mut node_mapping,
591                    tallies,
592                    node_id,
593                    &[wildcard],
594                );
595            }
596            for (v, nodes) in &matcher.positive_values {
597                let lit_node = Node(lit_id);
598                lit_id += 1;
599                let node_id_2 = g.add_node(StyledNode::styled(
600                    lit_node,
601                    v,
602                    "lightyellow",
603                    "filled",
604                    "box",
605                ));
606                node_mapping.insert(lit_node, node_id_2);
607
608                if !g.contains_edge(node_id, node_id_2) {
609                    g.add_edge(node_id, node_id_2, ());
610                }
611                add_nodes(dag, &mut g, &mut node_mapping, tallies, node_id_2, nodes);
612            }
613        }
614
615        g
616    }
617
618    impl From<&Dag> for DiGraph {
619        fn from(value: &Dag) -> Self {
620            dag_to_digraph(value, &Default::default())
621        }
622    }
623
624    pub fn to_dot<'a>(graph: &'a DiGraph) -> Dot<'a, &'a DiGraph> {
625        Dot::with_attr_getters(
626            graph,
627            &[Config::EdgeNoLabel, Config::NodeNoLabel],
628            &|_, _| "".to_string(),
629            &|_, (_, style)| style.to_dot(),
630        )
631    }
632
633    pub fn to_dot_str<G: Into<DiGraph>>(graph: G) -> String {
634        format!("{:?}", to_dot(&graph.into()))
635    }
636}
637
#[cfg(test)]
#[cfg(feature = "dot")]
mod dot_tests {
    use crate::{
        ast::{NullResolver, RuleTreeNode},
        dag::{Dag, dot::to_dot_str},
    };

    const MULTILEVEL_EXAMPLE: &str = r#"
        a, f b e, c {
            c d {
                x = y
            }
            e f {
                foobar = abc
            }
        }
        a, c, b e f : baz = quux
    "#;

    /// Parses `source`, then prints the DOT text of its rule tree and of the graph built
    /// from that tree. These are smoke tests: they only check that nothing panics.
    fn print_tree_and_dag(source: &'static str) {
        let n = crate::ast::parse(source, NullResolver()).unwrap();

        let mut tree = RuleTreeNode::default();
        n.add_to(&mut tree);
        println!("{}", to_dot_str(&tree));

        let dag = Dag::build(tree);
        println!("{}", to_dot_str(&dag));
    }

    #[test]
    fn tree_to_dot_1() {
        print_tree_and_dag(MULTILEVEL_EXAMPLE);
    }

    #[test]
    fn tree_to_dot_2() {
        print_tree_and_dag(
            r#"
                a b c d e (f, g, h, i, j) {
                    x = y
                }
            "#,
        );
    }
}
689}