ccs2/ast/
mod.rs

1use std::{
2    collections::HashMap,
3    fmt::Display,
4    hash::Hash,
5    ops::Add,
6    path::{Path, PathBuf},
7};
8
9use indexmap::{IndexMap, IndexSet};
10use itertools::Itertools;
11
12mod dnf;
13mod formula;
14mod parser;
15mod property;
16mod rule_tree;
17
18// TODO: Another opt-in thread safety type
19/// A shared string reference, used internally for keys and properties
20pub type PersistentStr = std::sync::Arc<str>;
21
22pub use dnf::to_dnf;
23pub use formula::{Clause, Formula};
24pub use property::{Property, PropertyValue};
25pub use rule_tree::RuleTreeNode;
26
27pub fn parse(file_contents: impl AsRef<str>, resolver: impl ImportResolver) -> AstResult<Nested> {
28    parser::parse(file_contents, resolver, &mut vec![])
29}
30
/// A common error type for all of the things that can go wrong while parsing and building the AST
///
/// See [`AstResult`]
#[derive(thiserror::Error, Debug)]
pub enum AstError {
    /// A failure reported by the parser; displayed exactly as the wrapped error.
    #[error(transparent)]
    ParseError(#[from] parser::ParseError),
    /// An `@import` named a location that is already on the stack of imports being resolved.
    #[error("Circular import detected from file {0}")]
    CircularImport(PathBuf),
    /// The import at the given location could not be resolved.
    // NOTE(review): nothing in this file constructs this variant; presumably resolver
    // implementations return it — confirm.
    #[error("Failed to resolve import {0}")]
    ImportFailed(PathBuf),
}

/// Shorthand for a [`Result`] whose error type is [`AstError`].
pub type AstResult<T> = Result<T, AstError>;
45
/// Specificity counters carried by a selector key.
///
/// The derived ordering compares the fields in declaration order: `override_level` first, then
/// `positive`, `negative`, and finally `wildcard`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Specificity {
    override_level: u32,
    positive: u32,
    negative: u32,
    wildcard: u32,
}

impl Specificity {
    /// A specificity with every counter at zero.
    pub const fn zero() -> Self {
        Self {
            override_level: 0,
            positive: 0,
            negative: 0,
            wildcard: 0,
        }
    }

    /// A specificity whose only non-zero counter is `positive` (set to one).
    pub const fn positive_lit() -> Self {
        Self {
            override_level: 0,
            positive: 1,
            negative: 0,
            wildcard: 0,
        }
    }

    /// A specificity whose only non-zero counter is `wildcard` (set to one).
    pub const fn wildcard() -> Self {
        Self {
            override_level: 0,
            positive: 0,
            negative: 0,
            wildcard: 1,
        }
    }

    /// Builds a specificity from explicit values for every counter.
    pub const fn new(override_level: u32, positive: u32, negative: u32, wildcard: u32) -> Self {
        Self {
            override_level,
            positive,
            negative,
            wildcard,
        }
    }
}

impl Add for Specificity {
    type Output = Self;

    /// Adds the two specificities counter by counter.
    fn add(self, rhs: Self) -> Self::Output {
        let Self {
            override_level,
            positive,
            negative,
            wildcard,
        } = rhs;
        Self::new(
            self.override_level + override_level,
            self.positive + positive,
            self.negative + negative,
            self.wildcard + wildcard,
        )
    }
}
87
88#[derive(Debug, Clone, Eq)]
89pub struct Key {
90    pub name: PersistentStr,
91    pub values: Vec<PersistentStr>,
92    pub specificity: Specificity,
93}
94impl Key {
95    pub fn new_lit(name: impl ToString, values: impl IntoIterator<Item = PersistentStr>) -> Self {
96        Self::create(name, values, Specificity::positive_lit())
97    }
98
99    pub fn new(name: impl ToString, values: impl IntoIterator<Item = PersistentStr>) -> Self {
100        // TODO: This is really error-prone. Require specificity from everyone?
101        let values: Vec<_> = values.into_iter().collect();
102        let specificity = if !values.is_empty() {
103            Specificity::positive_lit()
104        } else {
105            Specificity::wildcard()
106        };
107        Self::create(name, values, specificity)
108    }
109
110    fn create(
111        name: impl ToString,
112        values: impl IntoIterator<Item = PersistentStr>,
113        specificity: Specificity,
114    ) -> Self {
115        let mut values: Vec<_> = values.into_iter().collect();
116        values.sort_unstable(); // Important: values are always sorted!
117        Self {
118            name: name.to_string().into(),
119            specificity,
120            values: values.into_iter().collect(),
121        }
122    }
123
124    /// Key-value parsing shouldn't be done here, really. It's done by the actual parser
125    /// implementation. However, this is handy for expressivity in tests, so we'll just allow it for
126    /// that.
127    #[cfg(test)]
128    pub fn parse(input: impl Into<PersistentStr>) -> Self {
129        let input = input.into();
130        let inputs: Vec<&str> = input.split(".").collect();
131        assert!(!inputs.is_empty());
132        let (key, values) = &inputs.split_at(1);
133        let values: Vec<_> = values.iter().map(|s| s.to_string().into()).collect();
134        Self::new(key[0], values)
135    }
136}
137impl Hash for Key {
138    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139        // Hash implementation needs to parallel PartialEq implementation, so we can't hash the
140        // specificity here unless we include it there. Does that matter?
141        self.name.hash(state);
142        self.values.hash(state);
143        // self.specificity.hash(state);
144    }
145}
146impl PartialEq for Key {
147    fn eq(&self, other: &Self) -> bool {
148        self.name == other.name && self.values == other.values
149    }
150}
151impl PartialOrd for Key {
152    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
153        Some(self.cmp(other))
154    }
155}
156impl Ord for Key {
157    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
158        debug_assert!(self.values.is_sorted());
159        debug_assert!(other.values.is_sorted());
160
161        self.name
162            .cmp(&other.name)
163            .then_with(|| self.values.cmp(&other.values))
164    }
165}
166impl Display for Key {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        // # TODO java code notices if key/val are actually not idents and quotes them
169        if self.values.len() > 1 {
170            let vals_str = self
171                .values
172                .iter()
173                .map(|val| format!("{}.{val}", self.name))
174                .joined_by(", ");
175            write!(f, "({})", vals_str)
176        } else if let Some(first) = self.values.first()
177            && !first.is_empty()
178        {
179            write!(f, "{}.{first}", self.name)
180        } else {
181            write!(f, "{}", self.name)
182        }
183    }
184}
185
/// A map from string names to string values.
// NOTE(review): nothing in this part of the file uses `Env`; presumably it carries
// environment-style variables for other modules — confirm against callers.
pub type Env = HashMap<String, String>;
187
/// The original source location from which a rule/property was parsed
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Origin {
    /// Path of the file the item came from.
    pub filename: PathBuf,
    /// Line number within `filename`.
    pub line_number: u32,
}
impl Display for Origin {
    /// Formats the location as `<filename>:<line_number>`.
    ///
    /// The path is rendered lossily, so a filename that is not valid UTF-8 is printed with
    /// replacement characters rather than causing a panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.filename.display(), self.line_number)
    }
}
204
/// The boolean operator joining the children of a selector expression.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Op {
    /// Conjunction; displayed as `AND`.
    And,
    /// Disjunction; displayed as `OR`.
    Or,
}
impl Display for Op {
    /// Writes the operator's upper-case keyword.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::And => "AND",
            Self::Or => "OR",
        };
        f.write_str(keyword)
    }
}
218
219/// AST nodes representing selector expressions
220#[derive(Clone, Debug, PartialEq, Eq, Hash)]
221pub enum Selector {
222    /// A conjunction or disjunction selector expression
223    Expr(Expr),
224    /// A single-step primitive selector expression
225    Step(Key),
226}
227impl Display for Selector {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        match self {
230            Selector::Expr(expr) => write!(f, "{expr}"),
231            Selector::Step(key) => write!(f, "{key}"),
232        }
233    }
234}
235
/// Provides a binding to paths which are used for resolving `@import` expressions
pub trait ImportResolver {
    /// The file name this resolver is bound to.
    ///
    /// Import resolution stores this value as the origin filename of property definitions.
    fn current_file_name(&self) -> PathBuf;

    /// Creates a resolver for the file at `location`, the path named by an `@import`.
    ///
    /// # Errors
    ///
    /// Returns an [`AstError`] when no resolver can be produced for `location`.
    fn new_context(&self, location: &Path) -> AstResult<Self>
    where
        Self: Sized;

    /// Loads the text that is handed to the parser for this resolver's file.
    ///
    /// # Errors
    ///
    /// Returns an [`AstError`] when the contents cannot be loaded.
    fn load(&self) -> AstResult<String>;
}
246
247/// Only used for testing, shouldn't be made public
248pub(crate) struct NullResolver();
249impl ImportResolver for NullResolver {
250    fn current_file_name(&self) -> PathBuf {
251        PathBuf::new()
252    }
253    fn new_context(&self, _: &Path) -> AstResult<Self> {
254        Ok(Self())
255    }
256    fn load(&self) -> AstResult<String> {
257        Ok("".to_string())
258    }
259}
260
261#[derive(Clone, Debug, PartialEq, Eq, Hash)]
262pub struct Expr {
263    pub op: Op,
264    pub children: Vec<Selector>,
265}
266impl Expr {
267    pub fn new(op: Op, children: impl IntoIterator<Item = Selector>) -> Self {
268        Self {
269            op,
270            children: children.into_iter().collect(),
271        }
272    }
273}
274impl Display for Expr {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        let children_str = self.children.iter().joined_by(" ");
277        write!(f, "({} {children_str})", self.op)
278    }
279}
280
281pub fn conj(terms: Vec<Selector>) -> Expr {
282    Expr {
283        op: Op::And,
284        children: terms,
285    }
286}
287pub fn disj(terms: Vec<Selector>) -> Expr {
288    Expr {
289        op: Op::Or,
290        children: terms,
291    }
292}
293
/// A wrapper around a single [`Key`].
// NOTE(review): nothing else in this file refers to `Step` (selectors use `Selector::Step(Key)`
// directly); confirm whether other modules still need it.
#[derive(Debug)]
pub struct Step {
    /// The wrapped key.
    pub key: Key,
}
298
299/// AST nodes representing rules
300#[derive(Debug)]
301pub enum AstNode {
302    /// AST node for @import
303    Import(Import),
304    /// AST node for a property setting
305    PropDef(PropDef),
306    /// AST node for @constrain
307    Constraint(Key),
308    /// AST node for a nested ruleset (single or multiple rules)
309    Nested(Nested),
310}
311impl AstNode {
312    pub fn add_to(&self, build_context: &mut RuleTreeNode) {
313        use AstNode::*;
314        match self {
315            Import(import) => {
316                if let Some(ast) = import.ast.as_ref() {
317                    ast.add_to(build_context)
318                } else {
319                    panic!("Attempted to add Import node without a resolved AST context");
320                }
321            }
322            PropDef(prop_def) => build_context.add_property(
323                &prop_def.name,
324                &prop_def.value,
325                prop_def.origin.clone(),
326                prop_def.should_override,
327            ),
328            Constraint(key) => {
329                build_context.add_constraint(key.clone());
330            }
331            Nested(nested) => {
332                nested.add_to(build_context);
333            }
334        }
335    }
336
337    pub fn resolve_imports<R: ImportResolver>(
338        &mut self,
339        resolver: &R,
340        in_progress: &mut Vec<PathBuf>,
341    ) -> AstResult<()> {
342        use AstNode::*;
343        match self {
344            Import(import) => {
345                if in_progress.contains(&import.location) {
346                    Err(AstError::CircularImport(import.location.clone()))
347                } else {
348                    in_progress.push(import.location.clone());
349
350                    let resolver = resolver.new_context(&import.location)?;
351                    let nested = parser::parse(resolver.load()?, resolver, in_progress)?;
352
353                    in_progress.pop(); // TODO: Be more careful?
354                    import.ast = Some(Box::new(AstNode::Nested(nested)));
355                    Ok(())
356                }
357            }
358            Nested(nested) => nested.resolve_imports_internal(resolver, in_progress),
359            PropDef(prop_def) => {
360                prop_def.origin.filename = resolver.current_file_name();
361                Ok(())
362            }
363            Constraint(..) => Ok(()),
364        }
365    }
366}
367impl Display for AstNode {
368    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        use AstNode::*;
370        match self {
371            Import(import) => import.fmt(f),
372            PropDef(prop_def) => prop_def.fmt(f),
373            Constraint(key) => write!(f, "@constrain {key}"),
374            Nested(nested) => nested.fmt(f),
375        }
376    }
377}
378
379#[derive(Debug)]
380pub struct Import {
381    location: PathBuf,
382    ast: Option<Box<AstNode>>,
383}
384impl Import {
385    pub fn new(location: impl AsRef<Path>) -> Self {
386        Self {
387            location: location.as_ref().to_path_buf(),
388            ast: None,
389        }
390    }
391}
392impl Display for Import {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        write!(f, "@import '{}'", self.location.to_string_lossy())
395    }
396}
397
398#[derive(Clone, Debug, PartialEq, Eq, Hash)]
399pub struct PropDef {
400    pub name: String,
401    pub value: String,
402    pub origin: Origin,
403    pub should_override: bool,
404}
405impl Display for PropDef {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        write!(f, "def {}", self.name)
408    }
409}
410
411#[derive(Clone, Debug)]
412pub struct Constraint {
413    pub key: Key,
414}
415impl From<Key> for Constraint {
416    fn from(key: Key) -> Self {
417        Self { key }
418    }
419}
420impl Display for Constraint {
421    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422        self.key.fmt(f)
423    }
424}
425
426#[derive(Default, Debug)]
427pub struct Nested {
428    selector: Option<Selector>,
429    rules: Vec<AstNode>,
430}
431impl Nested {
432    pub fn set_selector(&mut self, selector: Selector) {
433        assert!(self.selector.is_none());
434        self.selector = Some(selector)
435    }
436
437    pub fn append(&mut self, rule: AstNode) {
438        self.rules.push(rule)
439    }
440
441    pub fn add_to(&self, build_context: &mut RuleTreeNode) {
442        let build_context = if let Some(selector) = self.selector.as_ref() {
443            build_context.traverse(selector.clone())
444        } else {
445            build_context
446        };
447        for rule in self.rules.iter() {
448            rule.add_to(build_context);
449        }
450    }
451
452    pub fn resolve_imports<R: ImportResolver>(&mut self, resolver: &R) -> AstResult<()> {
453        let mut in_progress = vec![];
454        self.resolve_imports_internal(resolver, &mut in_progress)
455    }
456
457    pub fn resolve_imports_internal<R: ImportResolver>(
458        &mut self,
459        resolver: &R,
460        in_progress: &mut Vec<PathBuf>,
461    ) -> AstResult<()> {
462        for rule in self.rules.iter_mut() {
463            rule.resolve_imports(resolver, in_progress)?;
464        }
465        Ok(())
466    }
467}
468impl Display for Nested {
469    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470        let selector = self
471            .selector
472            .as_ref()
473            .map(ToString::to_string)
474            .unwrap_or("None".to_string());
475        write!(f, "{selector} {{ {} }}", self.rules.iter().joined_by("; "))
476    }
477}
478
479/// Flatten a selector expression.
480///
481/// A selector is flattened when we've inlined trivially nested expressions. In other
482/// words, a flat selector consists of strictly alternating levels of AND and OR.
483pub fn flatten(expr: Selector) -> Selector {
484    match expr {
485        Selector::Step(k) => Selector::Step(k),
486        Selector::Expr(expr) => {
487            let mut lit_children = IndexMap::<Selector, IndexSet<PersistentStr>>::default();
488            let mut new_children = Vec::<Selector>::default();
489
490            let mut add_child = |e: Selector| {
491                match (e, expr.op) {
492                    (Selector::Step(key), Op::Or) => {
493                        // in this case, we can group matching literals by key to avoid unnecessary dnf expansion.
494                        // it's not totally clear whether it's better to do this here or in to_dnf() (or possibly even in
495                        // normalize()??, so this is a bit of an arbitrary choice...
496                        // TODO negative matches will need to be handled here, probably adding as separate clusters,
497                        // depending on specificity rules?
498                        // TODO wildcard matches also need to be handled specially here, either as a flag on the key or
499                        // a special entry in values...
500                        // TODO if this is done prior to normalize(), that function needs to be changed to understand
501                        // set-valued pos/neg literals... and might need to be changed for negative literals either way?
502                        let key_without_values = Key::new(key.name, []);
503                        lit_children
504                            .entry(Selector::Step(key_without_values))
505                            .or_default()
506                            .extend(key.values.iter().cloned());
507                    }
508                    (e, _) => {
509                        new_children.push(e);
510                    }
511                }
512            };
513
514            for e in expr.children.into_iter().map(flatten) {
515                match e {
516                    Selector::Expr(e) => {
517                        if e.op == expr.op {
518                            for c in e.children {
519                                add_child(c)
520                            }
521                        } else {
522                            add_child(Selector::Expr(e))
523                        }
524                    }
525                    Selector::Step(..) => add_child(e),
526                }
527            }
528
529            for (child, values) in lit_children {
530                match child {
531                    Selector::Step(key) => {
532                        new_children.push(Selector::Step(Key::new(key.name, values)))
533                    }
534                    Selector::Expr(..) => panic!("Attempted to add literal expr!"),
535                }
536            }
537            if new_children.len() == 1 {
538                new_children.into_iter().next().unwrap()
539            } else {
540                Selector::Expr(Expr::new(expr.op, new_children))
541            }
542        }
543    }
544}
545
546pub trait JoinedBy {
547    fn joined_by(self, joiner: impl AsRef<str>) -> String;
548}
549impl<S: ToString, T: Iterator<Item = S>> JoinedBy for T {
550    fn joined_by(self, joiner: impl AsRef<str>) -> String {
551        Itertools::intersperse(self.map(|s| s.to_string()), joiner.as_ref().to_string()).collect()
552    }
553}
554
/// Helper macros for building selectors and keys concisely.
pub mod macros {
    #![allow(unused)] // Mostly just used by tests, but pretty handy for expressivity

    /// Turns a literal such as `"a.x"` into a step selector via `Key::parse`; any other
    /// expression is passed through untouched.
    ///
    /// `Key::parse` only exists in test builds, so the literal form is usable from tests only.
    macro_rules! selector {
        ($item:literal) => {
            crate::ast::Selector::Step(crate::ast::Key::parse($item))
        };
        ($item:expr) => {
            $item
        };
    }

    /// Builds a `Key` from a name and an optional value list: `key!("a": ("x" "y"))`.
    ///
    /// Each value is converted to the shared string type that `Key::new` expects.
    macro_rules! key {
        ($item:literal $(: ($($val:literal)+))?) => {
            crate::ast::Key::new($item, [$($($val.to_string().into(),)+)?])
        };
    }

    /// Like `key!`, but wraps the resulting key in `Selector::Step`.
    macro_rules! kv_selector {
        ($item:literal $(: ($($val:literal)+))?) => {
            crate::ast::Selector::Step(crate::ast::Key::new($item, [$($($val.to_string().into(),)+)?]))
        };
    }

    /// Builds a `Selector::Expr` with the given `Op` variant name, passing every operand through
    /// `selector!`.
    macro_rules! expr {
        ($operator:ident, $op1:expr $(, $ops:expr)*) => {
            crate::ast::Selector::Expr(
                crate::ast::Expr::new(
                    crate::ast::Op::$operator, [selector!($op1) $(, selector!($ops))*]
                )
            )
        }
    }

    /// Shorthand for `expr!(And, ...)`.
    macro_rules! AND {
        ($op1:expr $(, $ops:expr)*) => {
            expr!(And, $op1 $(, $ops)*)
        }
    }

    /// Shorthand for `expr!(Or, ...)`.
    macro_rules! OR {
        ($op1:expr $(, $ops:expr)*) => {
            expr!(Or, $op1 $(, $ops)*)
        }
    }

    pub(crate) use AND;
    pub(crate) use OR;
    pub(crate) use expr;
    pub(crate) use key;
    pub(crate) use kv_selector;
    pub(crate) use selector;
}
608
// Unit tests for `flatten` and for the `Display` output of a parsed ruleset.
#[cfg(test)]
mod tests {
    use super::{macros::*, *};
    use pretty_assertions::assert_eq;

    // An expression with no nested expressions comes back unchanged.
    #[test]
    fn flatten_already_flattened() {
        let selector = AND!("a", "b", "c", "d");

        assert_eq!(flatten(selector.clone()), selector);
    }

    // AND nested directly inside AND is inlined into the parent.
    #[test]
    fn flatten_and() {
        let selector = AND!(AND!("a", "b"), AND!("c", "d"));
        let expected = AND!("a", "b", "c", "d");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(AND a b c d)");
    }

    // OR nested directly inside OR is inlined into the parent.
    #[test]
    fn flatten_or() {
        let selector = OR!(OR!("a", "b"), OR!("c", "d"));
        let expected = OR!("a", "b", "c", "d");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(OR a b c d)");
    }

    // Only same-operator nesting is inlined. Within an OR, literal steps are emitted after the
    // non-literal children, which is why `(AND f g)` precedes `d` and `e` in the expectation.
    #[test]
    fn flatten_mixed() {
        #[rustfmt::skip]
        let selector = AND!(OR!("a", "b", "c"), AND!("c", "d"), OR!("d", OR!("e", AND!("f", "g"))));
        let expected = AND!(OR!("a", "b", "c"), "c", "d", OR!(AND!("f", "g"), "d", "e"));

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(
            flattened.to_string(),
            "(AND (OR a b c) c d (OR (AND f g) d e))"
        );
    }

    // OR-ed literals that share a key name are merged into one multi-valued step.
    #[test]
    fn flatten_single_key_leaf_disjunctions() {
        let selector = AND!(OR!("a.x", "a.y", "a.z"), "b");
        let expected = AND!(kv_selector!("a": ("x" "y" "z")), "b");

        let flattened = flatten(selector);
        assert_eq!(flattened, expected);
        assert_eq!(flattened.to_string(), "(AND (a.x, a.y, a.z) b)");
    }

    // Parses a multi-rule document and checks the `Display` rendering of the resulting AST.
    #[test]
    fn nested_string() {
        let ccs = r#"
            a, f b e, c {
                c d {
                    x = y
                }
                e f {
                    foobar = abc
                }
            }
            a, c, b e f : baz = quux

            x = outerx
            baz = outerbaz
            foobar = outerfoobar
            noothers = val
            
            multi {
                x = failure
                level {
                    x = success
                }
            }

            z.underconstraint {
                c = success
            }
            @constrain z.underconstraint
            c = failure
        "#;

        // Note: This isn't 100% the same as Python, because we do an extra layer of flattening
        // while parsing, which messes up the ordering of some of the nested selectors.
        let expected = "None { (OR (AND f b e) a c) { (AND c d) { def x }; (AND e f) { def foobar \
                        } }; (OR (AND b e f) a c) { def baz }; def x; def baz; def foobar; def \
                        noothers; multi { def x; level { def x } }; z.underconstraint { def c }; \
                        @constrain z.underconstraint; def c }";
        let parsed = parse(ccs, NullResolver()).unwrap();

        assert_eq!(parsed.to_string(), expected);
    }
}
707}