1use std::collections::BTreeSet;
5
6use crate::ast::{
7 Selector,
8 formula::{Clause, Formula},
9};
10
11pub fn to_dnf(expr: Selector, limit: usize) -> Formula {
12 match expr {
13 Selector::Step(key) => Formula::with_clause(Clause::with_literal(key)),
14 Selector::Expr(expr) => match expr.op {
15 super::Op::Or => merge(expr.children.into_iter().map(|e| to_dnf(e, limit))),
16 super::Op::And => expand(expr.children.into_iter().map(|e| to_dnf(e, limit)), limit),
17 },
18 }
19}
20
21pub fn merge(forms: impl Iterator<Item = Formula>) -> Formula {
23 forms
24 .into_iter()
25 .reduce(|acc, f| acc.merge(f))
26 .unwrap_or(Formula::default())
27}
28
29pub fn expand(forms: impl Iterator<Item = Formula>, limit: usize) -> Formula {
34 let forms: Vec<_> = forms.into_iter().collect();
35
36 let mut nontrivial = 0u32;
40 let mut common = Clause::default();
41 let mut result_size = 1;
42 for f in forms.iter() {
43 result_size *= f.len();
44 if f.len() == 1 {
45 common = common.union(f.first())
46 } else {
47 nontrivial += 1;
48 }
49 }
50
51 if result_size > limit {
52 panic!(
53 "Expanded form would have {result_size} clauses, which is more than the limit of \
54 {limit}. Consider increasing the limit or stratifying this rule."
55 );
56 }
57
58 fn exprec(forms: &[Formula]) -> Formula {
60 if forms.is_empty() {
61 return Formula::default();
62 }
63 let first = forms.first().unwrap();
64 let rest = exprec(&forms[1..]);
65 let cs = first
66 .elements()
67 .iter()
68 .flat_map(|c1| rest.elements().iter().map(|c2| c1.union(c2)))
69 .collect();
70 Formula::new(cs, first.shared().union(rest.shared()).cloned().collect())
71 }
72
73 let res = exprec(&forms);
74
75 let mut all_shared = BTreeSet::<Clause>::default();
77 if nontrivial > 0 && common.len() > 1 {
78 all_shared.insert(common);
79 }
80 if nontrivial > 1 {
81 for f in forms {
82 if f.len() > 1 {
83 all_shared.extend(f.elements().iter().filter(|c| c.len() > 1).cloned())
84 }
85 }
86 }
87
88 Formula::new(
89 res.elements().clone(),
90 res.shared().union(&all_shared).cloned().collect(),
91 )
92 .normalize()
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{flatten, macros::*};

    // Flattens `$selector`, converts it to DNF with a clause limit of 100, and checks that the
    // rendered formula equals `$expected`.
    macro_rules! assert_dnf {
        ($selector:expr, $expected:expr $(,)?) => {
            assert_eq!(to_dnf(flatten($selector), 100).to_string(), $expected)
        };
    }

    #[test]
    fn dnf() {
        assert_dnf!(OR!(AND!("a", "b"), AND!("c", "d")), "a b, c d");
    }

    #[test]
    fn cnf() {
        assert_dnf!(AND!(OR!("a", "b"), OR!("c", "d")), "a c, a d, b c, b d");
    }

    #[test]
    fn nested_and() {
        assert_dnf!(AND!(AND!("a", "b"), AND!("c", "d")), "a b c d");
    }

    #[test]
    fn sharing() {
        assert_dnf!(
            AND!(AND!("a", "f", OR!("b", "e")), AND!("c", "d")),
            "a b c d f, a c d e f",
        );
    }

    #[test]
    fn flatten_single_key_leaf_disjunctions() {
        assert_dnf!(AND!(OR!("a.x", "a.y", "a.z"), "b"), "(a.x, a.y, a.z) b");
    }

    #[test]
    fn cartesian_product() {
        #[rustfmt::skip]
        let expected = "a d g, a d h, a d i, a e g, a e h, a e i, a f g, a f h, a f i, \
                        b d g, b d h, b d i, b e g, b e h, b e i, b f g, b f h, b f i, \
                        c d g, c d h, c d i, c e g, c e h, c e i, c f g, c f h, c f i";

        assert_dnf!(
            AND!(OR!("a", "b", "c"), OR!("d", "e", "f"), OR!("g", "h", "i")),
            expected,
        );
    }
}