1use std::{
2 collections::HashMap,
3 fmt::Display,
4 hash::Hash,
5 ops::Add,
6 path::{Path, PathBuf},
7};
8
9use indexmap::{IndexMap, IndexSet};
10use itertools::Itertools;
11
12mod dnf;
13mod formula;
14mod parser;
15mod property;
16mod rule_tree;
17
/// Shared, immutable string used for key names and values; cloning only bumps
/// the reference count of the underlying `Arc`.
pub type PersistentStr = std::sync::Arc<str>;
21
22pub use dnf::to_dnf;
23pub use formula::{Clause, Formula};
24pub use property::{Property, PropertyValue};
25pub use rule_tree::RuleTreeNode;
26
27pub fn parse(file_contents: impl AsRef<str>, resolver: impl ImportResolver) -> AstResult<Nested> {
28 parser::parse(file_contents, resolver, &mut vec![])
29}
30
/// Errors produced while parsing a file and resolving its imports.
#[derive(thiserror::Error, Debug)]
pub enum AstError {
    /// A parser failure, displayed exactly as the wrapped `parser::ParseError`.
    #[error(transparent)]
    ParseError(#[from] parser::ParseError),
    /// An import names a location that is already being resolved further up
    /// the import chain.
    #[error("Circular import detected from file {0}")]
    CircularImport(PathBuf),
    /// An import could not be resolved. Not constructed in this file;
    /// presumably returned by `ImportResolver` implementations.
    #[error("Failed to resolve import {0}")]
    ImportFailed(PathBuf),
}
43
/// Result alias whose error type is [`AstError`].
pub type AstResult<T> = Result<T, AstError>;
45
/// Four-component specificity score.
///
/// The derived ordering compares the fields in declaration order:
/// `override_level`, then `positive`, then `negative`, then `wildcard`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Specificity {
    override_level: u32,
    positive: u32,
    negative: u32,
    wildcard: u32,
}

impl Specificity {
    /// Builds a specificity from its four components.
    pub const fn new(override_level: u32, positive: u32, negative: u32, wildcard: u32) -> Self {
        Self {
            override_level,
            positive,
            negative,
            wildcard,
        }
    }

    /// Specificity with every component set to zero.
    pub const fn zero() -> Self {
        Self {
            override_level: 0,
            positive: 0,
            negative: 0,
            wildcard: 0,
        }
    }

    /// Specificity with only `positive` set to one.
    pub const fn positive_lit() -> Self {
        Self::new(0, 1, 0, 0)
    }

    /// Specificity with only `wildcard` set to one.
    pub const fn wildcard() -> Self {
        Self::new(0, 0, 0, 1)
    }
}

impl Add for Specificity {
    type Output = Self;

    /// Adds the two specificities component by component.
    fn add(self, rhs: Self) -> Self::Output {
        let Self {
            override_level,
            positive,
            negative,
            wildcard,
        } = rhs;
        Self::new(
            self.override_level + override_level,
            self.positive + positive,
            self.negative + negative,
            self.wildcard + wildcard,
        )
    }
}
87
88#[derive(Debug, Clone, Eq)]
89pub struct Key {
90 pub name: PersistentStr,
91 pub values: Vec<PersistentStr>,
92 pub specificity: Specificity,
93}
94impl Key {
95 pub fn new_lit(name: impl ToString, values: impl IntoIterator<Item = PersistentStr>) -> Self {
96 Self::create(name, values, Specificity::positive_lit())
97 }
98
99 pub fn new(name: impl ToString, values: impl IntoIterator<Item = PersistentStr>) -> Self {
100 let values: Vec<_> = values.into_iter().collect();
102 let specificity = if !values.is_empty() {
103 Specificity::positive_lit()
104 } else {
105 Specificity::wildcard()
106 };
107 Self::create(name, values, specificity)
108 }
109
110 fn create(
111 name: impl ToString,
112 values: impl IntoIterator<Item = PersistentStr>,
113 specificity: Specificity,
114 ) -> Self {
115 let mut values: Vec<_> = values.into_iter().collect();
116 values.sort_unstable(); Self {
118 name: name.to_string().into(),
119 specificity,
120 values: values.into_iter().collect(),
121 }
122 }
123
124 #[cfg(test)]
128 pub fn parse(input: impl Into<PersistentStr>) -> Self {
129 let input = input.into();
130 let inputs: Vec<&str> = input.split(".").collect();
131 assert!(!inputs.is_empty());
132 let (key, values) = &inputs.split_at(1);
133 let values: Vec<_> = values.iter().map(|s| s.to_string().into()).collect();
134 Self::new(key[0], values)
135 }
136}
137impl Hash for Key {
138 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139 self.name.hash(state);
142 self.values.hash(state);
143 }
145}
146impl PartialEq for Key {
147 fn eq(&self, other: &Self) -> bool {
148 self.name == other.name && self.values == other.values
149 }
150}
151impl PartialOrd for Key {
152 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
153 Some(self.cmp(other))
154 }
155}
156impl Ord for Key {
157 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
158 debug_assert!(self.values.is_sorted());
159 debug_assert!(other.values.is_sorted());
160
161 self.name
162 .cmp(&other.name)
163 .then_with(|| self.values.cmp(&other.values))
164 }
165}
166impl Display for Key {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 if self.values.len() > 1 {
170 let vals_str = self
171 .values
172 .iter()
173 .map(|val| format!("{}.{val}", self.name))
174 .joined_by(", ");
175 write!(f, "({})", vals_str)
176 } else if let Some(first) = self.values.first()
177 && !first.is_empty()
178 {
179 write!(f, "{}.{first}", self.name)
180 } else {
181 write!(f, "{}", self.name)
182 }
183 }
184}
185
/// String-to-string map. Not referenced elsewhere in this file; presumably
/// consumed by other modules — confirm against callers.
pub type Env = HashMap<String, String>;
187
/// Source location (file name and line number) that a definition came from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Origin {
    /// File the definition was read from.
    pub filename: PathBuf,
    /// Line number within `filename`.
    pub line_number: u32,
}

impl Display for Origin {
    /// Formats the origin as `<filename>:<line_number>`.
    ///
    /// The path is rendered lossily, so a file name that is not valid UTF-8
    /// is shown with replacement characters instead of causing a panic. This
    /// matches how `Import` renders its location.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.filename.display(), self.line_number)
    }
}
204
/// Boolean operator joining the children of an `Expr`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Op {
    /// Conjunction, displayed as `AND`.
    And,
    /// Disjunction, displayed as `OR`.
    Or,
}

impl Display for Op {
    /// Writes the upper-case operator name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::And => "AND",
            Self::Or => "OR",
        };
        f.write_str(text)
    }
}
218
219#[derive(Clone, Debug, PartialEq, Eq, Hash)]
221pub enum Selector {
222 Expr(Expr),
224 Step(Key),
226}
227impl Display for Selector {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 match self {
230 Selector::Expr(expr) => write!(f, "{expr}"),
231 Selector::Step(key) => write!(f, "{key}"),
232 }
233 }
234}
235
/// Supplies file contents and file names while imports are being resolved.
pub trait ImportResolver {
    /// Name of the file this resolver refers to; used to stamp the origin of
    /// property definitions.
    fn current_file_name(&self) -> PathBuf;

    /// Creates a resolver for the file named by an import's `location`.
    fn new_context(&self, location: &Path) -> AstResult<Self>
    where
        Self: Sized;

    /// Returns the contents of the file this resolver refers to.
    fn load(&self) -> AstResult<String>;
}
246
247pub(crate) struct NullResolver();
249impl ImportResolver for NullResolver {
250 fn current_file_name(&self) -> PathBuf {
251 PathBuf::new()
252 }
253 fn new_context(&self, _: &Path) -> AstResult<Self> {
254 Ok(Self())
255 }
256 fn load(&self) -> AstResult<String> {
257 Ok("".to_string())
258 }
259}
260
261#[derive(Clone, Debug, PartialEq, Eq, Hash)]
262pub struct Expr {
263 pub op: Op,
264 pub children: Vec<Selector>,
265}
266impl Expr {
267 pub fn new(op: Op, children: impl IntoIterator<Item = Selector>) -> Self {
268 Self {
269 op,
270 children: children.into_iter().collect(),
271 }
272 }
273}
274impl Display for Expr {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 let children_str = self.children.iter().joined_by(" ");
277 write!(f, "({} {children_str})", self.op)
278 }
279}
280
281pub fn conj(terms: Vec<Selector>) -> Expr {
282 Expr {
283 op: Op::And,
284 children: terms,
285 }
286}
287pub fn disj(terms: Vec<Selector>) -> Expr {
288 Expr {
289 op: Op::Or,
290 children: terms,
291 }
292}
293
/// Wrapper around a single [`Key`]. Not referenced elsewhere in this file.
#[derive(Debug)]
pub struct Step {
    /// The wrapped key.
    pub key: Key,
}
298
/// One rule inside a [`Nested`] block.
#[derive(Debug)]
pub enum AstNode {
    /// An `@import` of another file.
    Import(Import),
    /// A property definition.
    PropDef(PropDef),
    /// An `@constrain` directive for the given key.
    Constraint(Key),
    /// A nested block with its own optional selector and rules.
    Nested(Nested),
}
311impl AstNode {
312 pub fn add_to(&self, build_context: &mut RuleTreeNode) {
313 use AstNode::*;
314 match self {
315 Import(import) => {
316 if let Some(ast) = import.ast.as_ref() {
317 ast.add_to(build_context)
318 } else {
319 panic!("Attempted to add Import node without a resolved AST context");
320 }
321 }
322 PropDef(prop_def) => build_context.add_property(
323 &prop_def.name,
324 &prop_def.value,
325 prop_def.origin.clone(),
326 prop_def.should_override,
327 ),
328 Constraint(key) => {
329 build_context.add_constraint(key.clone());
330 }
331 Nested(nested) => {
332 nested.add_to(build_context);
333 }
334 }
335 }
336
337 pub fn resolve_imports<R: ImportResolver>(
338 &mut self,
339 resolver: &R,
340 in_progress: &mut Vec<PathBuf>,
341 ) -> AstResult<()> {
342 use AstNode::*;
343 match self {
344 Import(import) => {
345 if in_progress.contains(&import.location) {
346 Err(AstError::CircularImport(import.location.clone()))
347 } else {
348 in_progress.push(import.location.clone());
349
350 let resolver = resolver.new_context(&import.location)?;
351 let nested = parser::parse(resolver.load()?, resolver, in_progress)?;
352
353 in_progress.pop(); import.ast = Some(Box::new(AstNode::Nested(nested)));
355 Ok(())
356 }
357 }
358 Nested(nested) => nested.resolve_imports_internal(resolver, in_progress),
359 PropDef(prop_def) => {
360 prop_def.origin.filename = resolver.current_file_name();
361 Ok(())
362 }
363 Constraint(..) => Ok(()),
364 }
365 }
366}
367impl Display for AstNode {
368 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 use AstNode::*;
370 match self {
371 Import(import) => import.fmt(f),
372 PropDef(prop_def) => prop_def.fmt(f),
373 Constraint(key) => write!(f, "@constrain {key}"),
374 Nested(nested) => nested.fmt(f),
375 }
376 }
377}
378
379#[derive(Debug)]
380pub struct Import {
381 location: PathBuf,
382 ast: Option<Box<AstNode>>,
383}
384impl Import {
385 pub fn new(location: impl AsRef<Path>) -> Self {
386 Self {
387 location: location.as_ref().to_path_buf(),
388 ast: None,
389 }
390 }
391}
392impl Display for Import {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 write!(f, "@import '{}'", self.location.to_string_lossy())
395 }
396}
397
398#[derive(Clone, Debug, PartialEq, Eq, Hash)]
399pub struct PropDef {
400 pub name: String,
401 pub value: String,
402 pub origin: Origin,
403 pub should_override: bool,
404}
405impl Display for PropDef {
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 write!(f, "def {}", self.name)
408 }
409}
410
411#[derive(Clone, Debug)]
412pub struct Constraint {
413 pub key: Key,
414}
415impl From<Key> for Constraint {
416 fn from(key: Key) -> Self {
417 Self { key }
418 }
419}
420impl Display for Constraint {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 self.key.fmt(f)
423 }
424}
425
426#[derive(Default, Debug)]
427pub struct Nested {
428 selector: Option<Selector>,
429 rules: Vec<AstNode>,
430}
431impl Nested {
432 pub fn set_selector(&mut self, selector: Selector) {
433 assert!(self.selector.is_none());
434 self.selector = Some(selector)
435 }
436
437 pub fn append(&mut self, rule: AstNode) {
438 self.rules.push(rule)
439 }
440
441 pub fn add_to(&self, build_context: &mut RuleTreeNode) {
442 let build_context = if let Some(selector) = self.selector.as_ref() {
443 build_context.traverse(selector.clone())
444 } else {
445 build_context
446 };
447 for rule in self.rules.iter() {
448 rule.add_to(build_context);
449 }
450 }
451
452 pub fn resolve_imports<R: ImportResolver>(&mut self, resolver: &R) -> AstResult<()> {
453 let mut in_progress = vec![];
454 self.resolve_imports_internal(resolver, &mut in_progress)
455 }
456
457 pub fn resolve_imports_internal<R: ImportResolver>(
458 &mut self,
459 resolver: &R,
460 in_progress: &mut Vec<PathBuf>,
461 ) -> AstResult<()> {
462 for rule in self.rules.iter_mut() {
463 rule.resolve_imports(resolver, in_progress)?;
464 }
465 Ok(())
466 }
467}
468impl Display for Nested {
469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470 let selector = self
471 .selector
472 .as_ref()
473 .map(ToString::to_string)
474 .unwrap_or("None".to_string());
475 write!(f, "{selector} {{ {} }}", self.rules.iter().joined_by("; "))
476 }
477}
478
479pub fn flatten(expr: Selector) -> Selector {
484 match expr {
485 Selector::Step(k) => Selector::Step(k),
486 Selector::Expr(expr) => {
487 let mut lit_children = IndexMap::<Selector, IndexSet<PersistentStr>>::default();
488 let mut new_children = Vec::<Selector>::default();
489
490 let mut add_child = |e: Selector| {
491 match (e, expr.op) {
492 (Selector::Step(key), Op::Or) => {
493 let key_without_values = Key::new(key.name, []);
503 lit_children
504 .entry(Selector::Step(key_without_values))
505 .or_default()
506 .extend(key.values.iter().cloned());
507 }
508 (e, _) => {
509 new_children.push(e);
510 }
511 }
512 };
513
514 for e in expr.children.into_iter().map(flatten) {
515 match e {
516 Selector::Expr(e) => {
517 if e.op == expr.op {
518 for c in e.children {
519 add_child(c)
520 }
521 } else {
522 add_child(Selector::Expr(e))
523 }
524 }
525 Selector::Step(..) => add_child(e),
526 }
527 }
528
529 for (child, values) in lit_children {
530 match child {
531 Selector::Step(key) => {
532 new_children.push(Selector::Step(Key::new(key.name, values)))
533 }
534 Selector::Expr(..) => panic!("Attempted to add literal expr!"),
535 }
536 }
537 if new_children.len() == 1 {
538 new_children.into_iter().next().unwrap()
539 } else {
540 Selector::Expr(Expr::new(expr.op, new_children))
541 }
542 }
543 }
544}
545
546pub trait JoinedBy {
547 fn joined_by(self, joiner: impl AsRef<str>) -> String;
548}
549impl<S: ToString, T: Iterator<Item = S>> JoinedBy for T {
550 fn joined_by(self, joiner: impl AsRef<str>) -> String {
551 Itertools::intersperse(self.map(|s| s.to_string()), joiner.as_ref().to_string()).collect()
552 }
553}
554
/// Helper macros for building keys and selectors concisely.
pub mod macros {
    // Not every macro is used in every build configuration (presumably most
    // are only used from tests), so unused warnings are silenced here.
    #![allow(unused)]

    /// Turns a string literal into a `Selector::Step` via `Key::parse`, and
    /// passes any other expression through unchanged.
    ///
    /// `Key::parse` only exists under `#[cfg(test)]`, so the literal form is
    /// only usable in test builds.
    macro_rules! selector {
        ($item:literal) => {
            crate::ast::Selector::Step(crate::ast::Key::parse($item))
        };
        ($item:expr) => {
            $item
        };
    }

    /// Builds a `Key` from a name and an optional parenthesised value list,
    /// e.g. `key!("a": ("x" "y"))`.
    ///
    /// Each value is converted with `.to_string().into()` so it has the
    /// `PersistentStr` item type that `Key::new` requires; a bare `String`
    /// would not type-check.
    macro_rules! key {
        ($item:literal $(: ($($val:literal)+))?) => {
            crate::ast::Key::new($item, [$($($val.to_string().into(),)+)?])
        };
    }

    /// Like `key!`, but wraps the resulting key in `Selector::Step`.
    macro_rules! kv_selector {
        ($item:literal $(: ($($val:literal)+))?) => {
            crate::ast::Selector::Step(crate::ast::Key::new($item, [$($($val.to_string().into(),)+)?]))
        };
    }

    /// Builds a `Selector::Expr` with the given `Op` variant name; every
    /// operand goes through `selector!`.
    macro_rules! expr {
        ($operator:ident, $op1:expr $(, $ops:expr)*) => {
            crate::ast::Selector::Expr(
                crate::ast::Expr::new(
                    crate::ast::Op::$operator, [selector!($op1) $(, selector!($ops))*]
                )
            )
        }
    }

    /// Shorthand for `expr!(And, ...)`.
    macro_rules! AND {
        ($op1:expr $(, $ops:expr)*) => {
            expr!(And, $op1 $(, $ops)*)
        }
    }

    /// Shorthand for `expr!(Or, ...)`.
    macro_rules! OR {
        ($op1:expr $(, $ops:expr)*) => {
            expr!(Or, $op1 $(, $ops)*)
        }
    }

    pub(crate) use AND;
    pub(crate) use OR;
    pub(crate) use expr;
    pub(crate) use key;
    pub(crate) use kv_selector;
    pub(crate) use selector;
}
608
// Unit tests for `flatten` and for the `Display` output of a parsed file.
#[cfg(test)]
mod tests {
    use super::{macros::*, *};
    use pretty_assertions::assert_eq;

    // An expression without nested expressions is returned unchanged.
    #[test]
    fn flatten_already_flattened() {
        let selector = AND!("a", "b", "c", "d");

        assert_eq!(flatten(selector.clone()), selector);
    }

    // Nested ANDs are spliced into a single AND.
    #[test]
    fn flatten_and() {
        let selector = AND!(AND!("a", "b"), AND!("c", "d"));
        let expected = AND!("a", "b", "c", "d");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(AND a b c d)");
    }

    // Nested ORs are spliced into a single OR.
    #[test]
    fn flatten_or() {
        let selector = OR!(OR!("a", "b"), OR!("c", "d"));
        let expected = OR!("a", "b", "c", "d");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(OR a b c d)");
    }

    // Only same-operator nesting is spliced; inside an OR the plain steps end
    // up after the non-step children.
    #[test]
    fn flatten_mixed() {
        #[rustfmt::skip]
        let selector = AND!(OR!("a", "b", "c"), AND!("c", "d"), OR!("d", OR!("e", AND!("f", "g"))));
        let expected = AND!(OR!("a", "b", "c"), "c", "d", OR!(AND!("f", "g"), "d", "e"));

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(
            flattened.to_string(),
            "(AND (OR a b c) c d (OR (AND f g) d e))"
        );
    }

    // OR-ed steps sharing a key name merge into one multi-valued step.
    #[test]
    fn flatten_single_key_leaf_disjunctions() {
        let selector = AND!(OR!("a.x", "a.y", "a.z"), "b");
        let expected = AND!(kv_selector!("a": ("x" "y" "z")), "b");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(AND (a.x, a.y, a.z) b)");
    }

    // Parses a larger input and checks the `Display` form of the resulting tree.
    #[test]
    fn nested_string() {
        let ccs = r#"
 a, f b e, c {
 c d {
 x = y
 }
 e f {
 foobar = abc
 }
 }
 a, c, b e f : baz = quux

 x = outerx
 baz = outerbaz
 foobar = outerfoobar
 noothers = val

 multi {
 x = failure
 level {
 x = success
 }
 }

 z.underconstraint {
 c = success
 }
 @constrain z.underconstraint
 c = failure
 "#;

        let expected = "None { (OR (AND f b e) a c) { (AND c d) { def x }; (AND e f) { def foobar \
 } }; (OR (AND b e f) a c) { def baz }; def x; def baz; def foobar; def \
 noothers; multi { def x; level { def x } }; z.underconstraint { def c }; \
 @constrain z.underconstraint; def c }";
        let parsed = parse(ccs, NullResolver()).unwrap();

        assert_eq!(parsed.to_string(), expected);
    }
}