diff --git a/Cargo.lock b/Cargo.lock index e9620bd60..9363a97e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "beef" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" + [[package]] name = "bitflags" version = "1.3.2" @@ -1050,7 +1056,9 @@ dependencies = [ "indexmap 2.2.6", "indextree", "indicatif", + "itertools", "log", + "logos", "maplit", "nom", "num-traits", @@ -1077,6 +1085,8 @@ dependencies = [ "tar", "tempfile", "term_size", + "test-log", + "thiserror", "toml", "tonic", "tonic-build", @@ -1444,6 +1454,39 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "logos" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff1ceb190eb9bdeecdd8f1ad6a71d6d632a50905948771718741b5461fb01e13" +dependencies = [ + "logos-derive", +] + +[[package]] +name = "logos-codegen" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90be66cb7bd40cb5cc2e9cfaf2d1133b04a3d93b72344267715010a466e0915a" +dependencies = [ + "beef", + "fnv", + "lazy_static", + "proc-macro2", + "quote", + "regex-syntax 0.8.4", + "syn 2.0.74", +] + +[[package]] +name = "logos-derive" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45154231e8e96586b39494029e58f12f8ffcb5ecf80333a603a13aa205ea8cbd" +dependencies = [ + "logos-codegen", +] + [[package]] name = "lzma-rs" version = "0.3.0" @@ -1471,6 +1514,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchit" version = "0.7.3" @@ -1506,9 +1558,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi", "libc", @@ -1532,6 +1584,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1629,6 +1691,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "packageurl" version = "0.4.0" @@ -1998,8 +2066,17 @@ checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -2010,9 +2087,15 @@ checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.4", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.4" @@ -2297,6 +2380,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shared_child" version = "1.0.0" @@ -2544,6 +2636,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "test-log" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -2564,6 +2678,16 @@ dependencies = [ "syn 2.0.74", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.36" @@ -2799,6 +2923,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2904,6 +3057,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/hipcheck/Cargo.toml b/hipcheck/Cargo.toml index 296995d04..16e7f5ffb 100644 --- a/hipcheck/Cargo.toml +++ b/hipcheck/Cargo.toml @@ -1,47 +1,95 @@ [package] name = "hipcheck" -description = "Automatically assess and score software repositories for supply chain risk" +description = """ +Automatically assess and score software packages for supply chain risk. +""" keywords = ["security", "sbom"] categories = ["command-line-utilities", "development-tools"] readme = "../README.md" version = "3.5.0" edition = "2021" license = "Apache-2.0" +homepage = "https://mitre.github.io/hipcheck" repository = "https://github.com/mitre/hipcheck" include = ["src/**/*", "../LICENSE", "../README.md"] -[features] -# Print timings feature is used to print timing information throughout hipchecks runtime. -print-timings = ["benchmarking"] -# Benchmarking enables the benchmarking module, containing special utilities for benchmarking. -benchmarking = [] - +# Rename the binary from the default "hipcheck" (based on the package name) +# to "hc". [[bin]] name = "hc" path = "src/main.rs" +[features] + +# Print timings feature is used to print timing information throughout +# Hipcheck's runtime. +print-timings = ["benchmarking"] + +# Benchmarking enables the benchmarking module, containing special utilities +# for benchmarking. +benchmarking = [] + [dependencies] +base64 = "0.22.1" content_inspector = "0.2.4" cyclonedx-bom = "0.7.0" dotenv = "0.15.0" chrono = { version = "0.4.19", features = ["alloc", "serde"] } clap = { version = "4.5.13", features = ["derive"] } +console = { version = "0.15.8", features = ["windows-console-colors"] } +dashmap = { version = "6.0.1", features = ["rayon", "inline"] } +dialoguer = "0.11.0" dirs = "5.0.1" duct = "0.13.5" env_logger = { version = "0.11.5" } +finl_unicode = { version = "1.2.0", default-features = false, features = [ + "grapheme_clusters", +] } +fs_extra = "1.3.0" +# Vendor libgit2 and openssl so that they will be statically included +# and not cause problems on certain systems that might not have one or +# the other. +git2 = { version = "0.19.0", features = [ + "vendored-libgit2", + "vendored-openssl", +] } graphql_client = "0.14.0" +# Include with both a `path` and `version` reference. +# Local builds will use the `path` dependency, which may be a newer +# version than the one published to Crates.io. +# People building from Crates.io will use the published version. +# +# See: https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#multiple-locations hipcheck-macros = { path = "../hipcheck-macros", version = "0.3.1" } +http = "1.1.0" +indexmap = "2.2.6" +indextree = "4.6.1" +indicatif = { version = "0.17.8", features = ["rayon"] } +itertools = "0.13.0" log = "0.4.22" +logos = "0.14.0" maplit = "1.0.2" nom = "7.1.3" +num-traits = "0.2.19" once_cell = "1.10.0" ordered-float = { version = "4.2.2", features = ["serde"] } packageurl = "0.4.0" paste = "1.0.7" pathbuf = "1.0.0" petgraph = { version = "0.6.0", features = ["serde-1"] } +prost = "0.13.1" +rayon = "1.10.0" regex = "1.10.5" +# Exactly matching the version of rustls used by ureq +# Get rid of default features since we don't use the AWS backed crypto +# provider (we use ring) and it breaks stuff on windows. +rustls = { version = "0.23.10", default-features = false, features = [ + "logging", + "std", + "tls12", + "ring", +] } rustls-native-certs = "0.7.1" salsa = "0.16.1" schemars = { version = "0.8.21", default-features = false, features = [ @@ -55,7 +103,12 @@ serde_derive = "1.0.137" serde_json = "1.0.122" smart-default = "0.7.1" spdx-rs = "0.5.0" +tabled = "0.15.0" +tar = "0.4.41" +term_size = "0.3.2" toml = "0.8.19" +tonic = "0.12.1" +thiserror = "1.0.63" unicode-normalization = "0.1.19" ureq = { version = "2.10.0", default-features = false, features = [ "json", @@ -65,51 +118,20 @@ url = "2.5.1" walkdir = "2.5.0" which = { version = "6.0.1", default-features = false } xml-rs = "0.8.20" -rayon = "1.10.0" -indexmap = "2.2.6" -dashmap = { version = "6.0.1", features = ["rayon", "inline"] } -# Vendor libgit2 and openssl so that they will be statically included and not cause problems on certain systems that might not have one or the other. -git2 = { version = "0.19.0", features = ["vendored-libgit2", "vendored-openssl"]} -indicatif = { version = "0.17.8", features = ["rayon"] } -finl_unicode = { version = "1.2.0", default-features = false, features = [ - "grapheme_clusters", -] } -tar = "0.4.41" -zip = "2.1.6" xz2 = "0.1.7" -indextree = "4.6.1" -num-traits = "0.2.19" -console = { version = "0.15.8", features = ["windows-console-colors"] } -term_size = "0.3.2" -base64 = "0.22.1" -http = "1.1.0" -dialoguer = "0.11.0" -tabled = "0.15.0" -fs_extra = "1.3.0" -tonic = "0.12.1" -prost = "0.13.1" - -# Exactly matching the version of rustls used by ureq -# Get rid of default features since we don't use the AWS backed crypto provider (we use ring). -# and it breaks stuff on windows. -[dependencies.rustls] -version = "0.23.10" -default-features = false -features = [ - "logging", - "std", - "tls12", - "ring" -] +zip = "2.1.6" [build-dependencies] + anyhow = "1.0.86" tonic-build = "0.12.1" which = { version = "6.0.1", default-features = false } [dev-dependencies] + dirs = "5.0.1" tempfile = "3.12.0" +test-log = "0.2.16" [package.metadata.dist] diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs index 777d20afe..784a84670 100644 --- a/hipcheck/src/main.rs +++ b/hipcheck/src/main.rs @@ -15,6 +15,7 @@ mod git2_log_shim; mod git2_rustls_transport; mod log_bridge; mod metric; +mod policy_exprs; mod report; mod session; mod setup; @@ -81,14 +82,6 @@ use target::{RemoteGitRepo, TargetSeed, TargetSeedKind, ToTargetSeed}; use util::fs::create_dir_all; use which::which; -fn init_logging() -> std::result::Result<(), log::SetLoggerError> { - let env = Env::new().filter("HC_LOG").write_style("HC_LOG_STYLE"); - - let logger = env_logger::Builder::from_env(env).build(); - - log_bridge::LogWrapper(logger).try_init() -} - /// Entry point for Hipcheck. fn main() -> ExitCode { // Initialize the global shell with normal verbosity by default. @@ -156,6 +149,12 @@ fn main() -> ExitCode { ExitCode::SUCCESS } +fn init_logging() -> std::result::Result<(), log::SetLoggerError> { + let env = Env::new().filter("HC_LOG").write_style("HC_LOG_STYLE"); + let logger = env_logger::Builder::from_env(env).build(); + log_bridge::LogWrapper(logger).try_init() +} + /// Run the `check` command. fn cmd_check(args: &CheckArgs, config: &CliConfig) -> ExitCode { let target = match args.to_target_seed() { diff --git a/hipcheck/src/policy_exprs/bridge.rs b/hipcheck/src/policy_exprs/bridge.rs new file mode 100644 index 000000000..24adef9f6 --- /dev/null +++ b/hipcheck/src/policy_exprs/bridge.rs @@ -0,0 +1,347 @@ +// The following code is copied from the `logos-nom-bridge` crate, which uses +// an outdated version of `logos` and thus can't be used directly here. +// +// The original code which we have copied and modified is MIT licensed, and +// used under the terms of that license here. + +//! # logos-nom-bridge +//! +//! A [`logos::Lexer`] wrapper than can be used as an input for +//! [nom](https://docs.rs/nom/7.0.0/nom/index.html). +//! + +use core::fmt; +use logos::{Lexer, Logos, Span, SpannedIter}; +use nom::{InputIter, InputLength, InputTake}; + +/// A [`logos::Lexer`] wrapper than can be used as an input for +/// [nom](https://docs.rs/nom/7.0.0/nom/index.html). +/// +/// You can find an example in the [module-level docs](..). +pub struct Tokens<'i, T> +where + T: Logos<'i>, +{ + lexer: Lexer<'i, T>, +} + +impl<'i, T> Clone for Tokens<'i, T> +where + T: Logos<'i> + Clone, + T::Extras: Clone, +{ + fn clone(&self) -> Self { + Self { + lexer: self.lexer.clone(), + } + } +} + +// Helper type returned by the logos parser. +type ParseResult<'i, T> = Result>::Error>; + +impl<'i, T> Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + /// Create a new token parser. + pub fn new(input: &'i str) -> Self { + Tokens { + lexer: Lexer::new(input), + } + } + + /// Get the length of the remaining source to parse. + pub fn len(&self) -> usize { + self.lexer.source().len() - self.lexer.span().end + } + + /// See if the remaining length to parse is empty. + #[allow(unused)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Peek at the next token, possibly with a parsing error. + pub fn peek(&self) -> Option<(ParseResult<'i, T>, &'i str)> { + let mut iter = self.lexer.clone().spanned(); + iter.next().map(|(t, span)| (t, &self.lexer.source()[span])) + } + + /// Advance the parser one step. + pub fn advance(mut self) -> Self { + self.lexer.next(); + self + } + + /// Get the underlying lexer. + pub fn lexer(&self) -> &Lexer<'i, T> { + &self.lexer + } +} + +impl<'i, T> PartialEq for Tokens<'i, T> +where + T: PartialEq + Logos<'i> + Clone, + T::Extras: Clone, +{ + fn eq(&self, other: &Self) -> bool { + Iterator::eq(self.lexer.clone(), other.lexer.clone()) + } +} + +impl<'i, T> Eq for Tokens<'i, T> +where + T: Eq + Logos<'i> + Clone, + T::Extras: Clone, +{ +} + +impl<'i, T> fmt::Debug for Tokens<'i, T> +where + T: fmt::Debug + Logos<'i, Source = str>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let source = self.lexer.source(); + let start = self.lexer.span().start; + f.debug_tuple("Tokens").field(&&source[start..]).finish() + } +} + +impl<'i, T> fmt::Display for Tokens<'i, T> +where + T: fmt::Debug + fmt::Display + Logos<'i, Source = str>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +impl<'i, T> Default for Tokens<'i, T> +where + T: Logos<'i, Source = str>, + T::Extras: Default, +{ + fn default() -> Self { + Tokens { + lexer: Lexer::new(""), + } + } +} + +/// An iterator, that (similarly to [`std::iter::Enumerate`]) produces byte offsets of the tokens. +pub struct IndexIterator<'i, T> +where + T: Logos<'i>, +{ + logos: Lexer<'i, T>, +} + +impl<'i, T> Iterator for IndexIterator<'i, T> +where + T: Logos<'i>, +{ + type Item = (usize, (ParseResult<'i, T>, Span)); + + fn next(&mut self) -> Option { + self.logos.next().map(|t| { + let span = self.logos.span(); + (span.start, (t, span)) + }) + } +} + +impl<'i, T> InputIter for Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + type Item = (ParseResult<'i, T>, Span); + type Iter = IndexIterator<'i, T>; + type IterElem = SpannedIter<'i, T>; + + fn iter_indices(&self) -> Self::Iter { + IndexIterator { + logos: self.lexer.clone(), + } + } + + fn iter_elements(&self) -> Self::IterElem { + self.lexer.clone().spanned() + } + + fn position

(&self, predicate: P) -> Option + where + P: Fn(Self::Item) -> bool, + { + let mut iter = self.lexer.clone().spanned(); + iter.find(|t| predicate(t.clone())) + .map(|(_, span)| span.start) + } + + fn slice_index(&self, count: usize) -> Result { + let mut cnt = 0; + for (_, span) in self.lexer.clone().spanned() { + if cnt == count { + return Ok(span.start); + } + cnt += 1; + } + if cnt == count { + return Ok(self.len()); + } + Err(nom::Needed::Unknown) + } +} + +impl<'i, T> InputLength for Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + fn input_len(&self) -> usize { + self.len() + } +} + +impl<'i, T> InputTake for Tokens<'i, T> +where + T: Logos<'i, Source = str>, + T::Extras: Default, +{ + fn take(&self, count: usize) -> Self { + Tokens { + lexer: Lexer::new(&self.lexer.source()[..count]), + } + } + + fn take_split(&self, count: usize) -> (Self, Self) { + let (a, b) = self.lexer.source().split_at(count); + ( + Tokens { + lexer: Lexer::new(a), + }, + Tokens { + lexer: Lexer::new(b), + }, + ) + } +} + +#[macro_export] +#[doc(hidden)] +macro_rules! token_parser { + ( + token: $token_ty:ty $(,)? + ) => { + $crate::token_parser!( + token: $token_ty, + error<'source>(input, token): ::nom::error::Error<$crate::policy_exprs::Tokens<'source, $token_ty>> = + nom::error::Error::new(input, nom::error::ErrorKind::IsA), + ); + }; + + ( + token: $token_ty:ty, + error: $error_ty:ty = $error:expr $(,)? + ) => { + $crate::token_parser!( + token: $token_ty, + error<'source>(input, token): $error_ty = $error, + ); + }; + + ( + token: $token_ty:ty, + error<$lt:lifetime>($input:ident, $token:ident): $error_ty:ty = $error:expr $(,)? + ) => { + #[allow(unused)] + impl<$lt> ::nom::Parser< + $crate::policy_exprs::Tokens<$lt, $token_ty>, + &$lt str, + $error_ty, + > for $token_ty { + fn parse( + &mut self, + $input: $crate::policy_exprs::Tokens<$lt, $token_ty>, + ) -> ::nom::IResult< + $crate::policy_exprs::Tokens<$lt, $token_ty>, + &$lt str, + $error_ty, + > { + match $input.peek() { + ::std::option::Option::Some((::std::result::Result::Ok(__token), __s)) if __token == *self => { + ::std::result::Result::Ok(($input.advance(), __s)) + } + ::std::option::Option::Some((::std::result::Result::Err(__err), __s)) => { + // Technically this could just be the subsequent case as well, but I am + // deciding to distinguish it here. + ::std::result::Result::Err(::nom::Err::Error($error)) + } + _ => { + // This was in the original code. It appears to be unused, but I am leaving it here + // as a sort of Chesterton's Fence situation. + let $token = self; + ::std::result::Result::Err(::nom::Err::Error($error)) + }, + } + } + } + }; +} + +/// Generates a nom parser function to parse an enum variant that contains data. +#[macro_export] +#[doc(hidden)] +macro_rules! data_variant_parser { + ( + fn $fn_name:ident($input:ident) -> Result<$ok_ty:ty>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + ) => { + $crate::data_variant_parser! { + fn $fn_name<'src>($input) -> Result< + $ok_ty, + ::nom::error::Error<$crate::policy_exprs::Tokens<'src, $type>>, + >; + + pattern = $type :: $variant $data => $res; + error = ::nom::error::Error::new($input, ::nom::error::ErrorKind::IsA); + } + }; + + ( + fn $fn_name:ident($input:ident) -> Result<$ok_ty:ty, $error_ty:ty $(,)?>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + error = $error:expr; + ) => { + $crate::data_variant_parser! { + fn $fn_name<'src>($input) -> Result<$ok_ty, $error_ty>; + + pattern = $type :: $variant $data => $res; + error = $error; + } + }; + + ( + fn $fn_name:ident<$lt:lifetime>($input:ident) -> Result<$ok_ty:ty, $error_ty:ty $(,)?>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + error = $error:expr; + ) => { + fn $fn_name<$lt>($input: $crate::policy_exprs::Tokens<$lt, $type>) -> ::nom::IResult< + $crate::policy_exprs::Tokens<$lt, $type>, + $ok_ty, + $error_ty, + > { + match $input.peek() { + ::std::option::Option::Some((::std::result::Result::Ok($type::$variant $data), _)) => { + Ok(($input.advance(), $res)) + } + _ => ::std::result::Result::Err(::nom::Err::Error($error)), + } + } + }; +} diff --git a/hipcheck/src/policy_exprs/env.rs b/hipcheck/src/policy_exprs/env.rs new file mode 100644 index 000000000..bb22184ca --- /dev/null +++ b/hipcheck/src/policy_exprs/env.rs @@ -0,0 +1,774 @@ +use crate::policy_exprs::eval; +use crate::policy_exprs::Error; +use crate::policy_exprs::Expr; +use crate::policy_exprs::Ident; +use crate::policy_exprs::Primitive; +use crate::policy_exprs::Result; +use crate::policy_exprs::F64; +use itertools::Itertools as _; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::ops::Not as _; +use Expr::*; +use Primitive::*; + +/// Environment, containing bindings of names to functions and variables. +pub struct Env<'parent> { + /// Map of bindings,. + bindings: HashMap, + + /// Possible pointer to parent, for lexical scope. + parent: Option<&'parent Env<'parent>>, +} + +/// A binding in the environment. +#[derive(Clone)] +pub enum Binding { + /// A function. + Fn(Op), + + /// A primitive value. + Var(Primitive), +} + +/// Helper type for operation function pointer. +type Op = fn(&Env, &[Expr]) -> Result; + +impl<'parent> Env<'parent> { + /// Create an empty environment. + fn empty() -> Self { + Env { + bindings: HashMap::new(), + parent: None, + } + } + + /// Create the standard environment. + pub fn std() -> Self { + let mut env = Env::empty(); + + // Comparison functions. + env.add_fn("gt", gt); + env.add_fn("lt", lt); + env.add_fn("gte", gte); + env.add_fn("lte", lte); + env.add_fn("eq", eq); + env.add_fn("neq", neq); + + // Math functions. + env.add_fn("add", add); + env.add_fn("sub", sub); + + // Logical functions. + env.add_fn("and", and); + env.add_fn("or", or); + env.add_fn("not", not); + + // Array math functions. + env.add_fn("max", max); + env.add_fn("min", min); + env.add_fn("avg", avg); + env.add_fn("median", median); + env.add_fn("count", count); + + // Array logic functions. + env.add_fn("all", all); + env.add_fn("nall", nall); + env.add_fn("some", some); + env.add_fn("none", none); + + // Array higher-order functions. + env.add_fn("filter", filter); + env.add_fn("foreach", foreach); + + // Debugging functions. + env.add_fn("dbg", dbg); + + env + } + + /// Create a child environment. + pub fn child(&self) -> Env<'_> { + Env { + bindings: HashMap::new(), + parent: Some(self), + } + } + + /// Add a variable to the environment. + pub fn add_var(&mut self, name: &str, value: Primitive) -> Option { + self.bindings.insert(name.to_owned(), Binding::Var(value)) + } + + /// Add a function to the environment. + pub fn add_fn(&mut self, name: &str, op: Op) -> Option { + self.bindings.insert(name.to_owned(), Binding::Fn(op)) + } + + /// Get a binding from the environment, walking up the scopes. + pub fn get(&self, name: &str) -> Option { + self.bindings + .get(name) + .cloned() + .or_else(|| self.parent.and_then(|parent| parent.get(name))) + } +} + +/// Check the number of args provided to the function. +fn check_num_args(name: &str, args: &[Expr], expected: usize) -> Result<()> { + let given = args.len(); + + match expected.cmp(&given) { + Ordering::Equal => Ok(()), + Ordering::Less => Err(Error::TooManyArgs { + name: name.to_string(), + expected, + given, + }), + Ordering::Greater => Err(Error::NotEnoughArgs { + name: name.to_string(), + expected, + given, + }), + } +} + +/// Partially evaluate a binary operation on primitives. +fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { + let var_name = "x"; + let var = Ident(String::from(var_name)); + let func = Ident(String::from(fn_name)); + let op = Function(func, vec![Primitive(Identifier(var.clone())), arg]); + let lambda = Lambda(var, Box::new(op)); + Ok(lambda) +} + +/// Define binary operations on primitives. +fn binary_primitive_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(Primitive, Primitive) -> Result, +{ + if args.len() == 1 { + return partially_evaluate(name, args[0].clone()); + } + + check_num_args(name, args, 2)?; + + let arg_1 = match eval(env, &args[0])? { + Primitive(p) => p, + _ => return Err(Error::BadType(name)), + }; + + let arg_2 = match eval(env, &args[1])? { + Primitive(p) => p, + _ => return Err(Error::BadType(name)), + }; + + let primitive = match (&arg_1, &arg_2) { + (Int(_), Int(_)) | (Float(_), Float(_)) | (Bool(_), Bool(_)) => op(arg_1, arg_2)?, + _ => return Err(Error::BadType(name)), + }; + + Ok(Primitive(primitive)) +} + +/// Define unary operations on primitives. +fn unary_primitive_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(Primitive) -> Result, +{ + check_num_args(name, args, 1)?; + + let primitive = match eval(env, &args[0])? { + Primitive(arg) => arg, + _ => return Err(Error::BadType(name)), + }; + + Ok(Expr::Primitive(op(primitive)?)) +} + +/// Define unary operations on arrays. +fn unary_array_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(ArrayType) -> Result, +{ + check_num_args(name, args, 1)?; + + let arr = match eval(env, &args[0])? { + Array(arg) => array_type(&arg[..])?, + _ => return Err(Error::BadType(name)), + }; + + op(arr) +} + +/// Define a higher-order operation over arrays. +fn higher_order_array_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(ArrayType, Ident, Box) -> Result, +{ + check_num_args(name, args, 2)?; + + let (ident, body) = match eval(env, &args[0])? { + Lambda(ident, body) => (ident, body), + _ => return Err(Error::BadType(name)), + }; + + let arr = match eval(env, &args[1])? { + Array(arr) => array_type(&arr[..])?, + _ => return Err(Error::BadType(name)), + }; + + op(arr, ident, body) +} + +/// A fully-typed array. +enum ArrayType { + /// An array of ints. + Int(Vec), + + /// An array of floats. + Float(Vec), + + /// An array of bools. + Bool(Vec), + + /// An empty array (no type hints). + Empty, +} + +/// Process an array into a singular type, or error out. +fn array_type(arr: &[Primitive]) -> Result { + if arr.is_empty() { + return Ok(ArrayType::Empty); + } + + match &arr[0] { + Int(_) => { + let mut result: Vec = Vec::with_capacity(arr.len()); + for elem in arr { + if let Int(val) = elem { + result.push(*val); + } else { + return Err(Error::InconsistentArrayTypes); + } + } + Ok(ArrayType::Int(result)) + } + Float(_) => { + let mut result: Vec = Vec::with_capacity(arr.len()); + for elem in arr { + if let Float(val) = elem { + result.push(*val); + } else { + return Err(Error::InconsistentArrayTypes); + } + } + Ok(ArrayType::Float(result)) + } + Bool(_) => { + let mut result: Vec = Vec::with_capacity(arr.len()); + for elem in arr { + if let Bool(val) = elem { + result.push(*val); + } else { + return Err(Error::InconsistentArrayTypes); + } + } + Ok(ArrayType::Bool(result)) + } + Identifier(_) => unimplemented!("we don't currently support idents in arrays"), + } +} + +/// Evaluate the lambda, injecting into the environment. +fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result { + let mut child = env.child(); + + if child.add_var(&ident.0, val).is_some() { + return Err(Error::AlreadyBound); + } + + eval(&child, &body) +} + +#[allow(clippy::bool_comparison)] +fn gt(env: &Env, args: &[Expr]) -> Result { + let name = "gt"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 > arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 > arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 > arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +#[allow(clippy::bool_comparison)] +fn lt(env: &Env, args: &[Expr]) -> Result { + let name = "lt"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 < arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 < arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 < arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +#[allow(clippy::bool_comparison)] +fn gte(env: &Env, args: &[Expr]) -> Result { + let name = "gte"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 >= arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 >= arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 >= arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +#[allow(clippy::bool_comparison)] +fn lte(env: &Env, args: &[Expr]) -> Result { + let name = "lte"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 <= arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 <= arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 <= arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +#[allow(clippy::bool_comparison)] +fn eq(env: &Env, args: &[Expr]) -> Result { + let name = "eq"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 == arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 == arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 == arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +#[allow(clippy::bool_comparison)] +fn neq(env: &Env, args: &[Expr]) -> Result { + let name = "neq"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 != arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 != arg_2)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 != arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +fn add(env: &Env, args: &[Expr]) -> Result { + let name = "add"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Int(arg_1 + arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Float(arg_1 + arg_2)), + (Bool(_), Bool(_)) => Err(Error::BadType(name)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +fn sub(env: &Env, args: &[Expr]) -> Result { + let name = "sub"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(arg_1), Int(arg_2)) => Ok(Int(arg_1 - arg_2)), + (Float(arg_1), Float(arg_2)) => Ok(Float(arg_1 - arg_2)), + (Bool(_), Bool(_)) => Err(Error::BadType(name)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +fn and(env: &Env, args: &[Expr]) -> Result { + let name = "and"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(_), Int(_)) => Err(Error::BadType(name)), + (Float(_), Float(_)) => Err(Error::BadType(name)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 && arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +fn or(env: &Env, args: &[Expr]) -> Result { + let name = "or"; + + let op = |arg_1, arg_2| match (arg_1, arg_2) { + (Int(_), Int(_)) => Err(Error::BadType(name)), + (Float(_), Float(_)) => Err(Error::BadType(name)), + (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 || arg_2)), + _ => unreachable!(), + }; + + binary_primitive_op(name, env, args, op) +} + +fn not(env: &Env, args: &[Expr]) -> Result { + let name = "not"; + + let op = |arg| match arg { + Int(_) => Err(Error::BadType(name)), + Float(_) => Err(Error::BadType(name)), + Bool(arg) => Ok(Primitive::Bool(arg.not())), + Identifier(_) => unreachable!("no idents should be here"), + }; + + unary_primitive_op(name, env, args, op) +} + +fn max(env: &Env, args: &[Expr]) -> Result { + let name = "max"; + + let op = |arg| match arg { + ArrayType::Int(ints) => ints + .iter() + .copied() + .max() + .ok_or(Error::NoMax) + .map(|m| Primitive(Int(m))), + + ArrayType::Float(floats) => floats + .iter() + .copied() + .max() + .ok_or(Error::NoMax) + .map(|m| Primitive(Float(m))), + + ArrayType::Bool(_) => Err(Error::BadType(name)), + ArrayType::Empty => Err(Error::NoMax), + }; + + unary_array_op(name, env, args, op) +} + +fn min(env: &Env, args: &[Expr]) -> Result { + let name = "min"; + + let op = |arg| match arg { + ArrayType::Int(ints) => ints + .iter() + .copied() + .min() + .ok_or(Error::NoMin) + .map(|m| Primitive(Int(m))), + + ArrayType::Float(floats) => floats + .iter() + .copied() + .min() + .ok_or(Error::NoMin) + .map(|m| Primitive(Float(m))), + + ArrayType::Bool(_) => Err(Error::BadType(name)), + ArrayType::Empty => Err(Error::NoMin), + }; + + unary_array_op(name, env, args, op) +} + +fn avg(env: &Env, args: &[Expr]) -> Result { + let name = "avg"; + + let op = |arg| match arg { + ArrayType::Int(ints) => { + let count = ints.len() as i64; + let sum = ints.iter().copied().sum::(); + Ok(Primitive(Float(F64::new(sum as f64 / count as f64)?))) + } + + ArrayType::Float(floats) => { + let count = floats.len() as i64; + let sum = floats.iter().copied().sum::(); + Ok(Primitive(Float(F64::new(sum.into_inner() / count as f64)?))) + } + + ArrayType::Bool(_) => Err(Error::BadType(name)), + ArrayType::Empty => Err(Error::NoAvg), + }; + + unary_array_op(name, env, args, op) +} + +fn median(env: &Env, args: &[Expr]) -> Result { + let name = "median"; + + let op = |arg| match arg { + ArrayType::Int(mut ints) => { + ints.sort(); + let mid = ints.len() / 2; + Ok(Primitive(Int(ints[mid]))) + } + ArrayType::Float(mut floats) => { + floats.sort(); + let mid = floats.len() / 2; + Ok(Primitive(Float(floats[mid]))) + } + ArrayType::Bool(_) => Err(Error::BadType(name)), + ArrayType::Empty => Err(Error::NoMedian), + }; + + unary_array_op(name, env, args, op) +} + +fn count(env: &Env, args: &[Expr]) -> Result { + let name = "count"; + + let op = |arg| match arg { + ArrayType::Int(ints) => Ok(Primitive(Int(ints.len() as i64))), + ArrayType::Float(floats) => Ok(Primitive(Int(floats.len() as i64))), + ArrayType::Bool(bools) => Ok(Primitive(Int(bools.len() as i64))), + ArrayType::Empty => Ok(Primitive(Int(0))), + }; + + unary_array_op(name, env, args, op) +} + +fn all(env: &Env, args: &[Expr]) -> Result { + let name = "all"; + + let op = |arr, ident: Ident, body: Box| { + let result = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Empty => true, + }; + + Ok(Primitive(Bool(result))) + }; + + higher_order_array_op(name, env, args, op) +} + +fn nall(env: &Env, args: &[Expr]) -> Result { + let name = "nall"; + + let op = |arr, ident: Ident, body: Box| { + let result = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .process_results(|mut iter| { + iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Empty => false, + }; + + Ok(Primitive(Bool(result))) + }; + + higher_order_array_op(name, env, args, op) +} + +fn some(env: &Env, args: &[Expr]) -> Result { + let name = "some"; + + let op = |arr, ident: Ident, body: Box| { + let result = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))) + })?, + ArrayType::Empty => false, + }; + + Ok(Primitive(Bool(result))) + }; + + higher_order_array_op(name, env, args, op) +} + +fn none(env: &Env, args: &[Expr]) -> Result { + let name = "none"; + + let op = |arr, ident: Ident, body: Box| { + let result = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .process_results(|mut iter| { + iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() + })?, + ArrayType::Empty => true, + }; + + Ok(Primitive(Bool(result))) + }; + + higher_order_array_op(name, env, args, op) +} + +fn filter(env: &Env, args: &[Expr]) -> Result { + let name = "filter"; + + let op = |arr, ident: Ident, body: Box| { + let arr = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| Ok((val, eval_lambda(env, &ident, Int(*val), (*body).clone())))) + .filter_map_ok(|(val, expr)| { + if let Ok(Primitive(Bool(true))) = expr { + Some(Primitive::Int(*val)) + } else { + None + } + }) + .collect::>>()?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| Ok((val, eval_lambda(env, &ident, Float(*val), (*body).clone())))) + .filter_map_ok(|(val, expr)| { + if let Ok(Primitive(Bool(true))) = expr { + Some(Primitive::Float(*val)) + } else { + None + } + }) + .collect::>>()?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| Ok((val, eval_lambda(env, &ident, Bool(*val), (*body).clone())))) + .filter_map_ok(|(val, expr)| { + if let Ok(Primitive(Bool(true))) = expr { + Some(Primitive::Bool(*val)) + } else { + None + } + }) + .collect::>>()?, + ArrayType::Empty => Vec::new(), + }; + + Ok(Array(arr)) + }; + + higher_order_array_op(name, env, args, op) +} + +fn foreach(env: &Env, args: &[Expr]) -> Result { + let name = "foreach"; + + let op = |arr, ident: Ident, body: Box| { + let arr = match arr { + ArrayType::Int(ints) => ints + .iter() + .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|expr| match expr { + Ok(Primitive(inner)) => Ok(inner), + Ok(_) => Err(Error::BadType(name)), + Err(err) => Err(err), + }) + .collect::>>()?, + ArrayType::Float(floats) => floats + .iter() + .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|expr| match expr { + Ok(Primitive(inner)) => Ok(inner), + Ok(_) => Err(Error::BadType(name)), + Err(err) => Err(err), + }) + .collect::>>()?, + ArrayType::Bool(bools) => bools + .iter() + .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|expr| match expr { + Ok(Primitive(inner)) => Ok(inner), + Ok(_) => Err(Error::BadType(name)), + Err(err) => Err(err), + }) + .collect::>>()?, + ArrayType::Empty => Vec::new(), + }; + + Ok(Array(arr)) + }; + + higher_order_array_op(name, env, args, op) +} + +fn dbg(env: &Env, args: &[Expr]) -> Result { + let name = "dbg"; + check_num_args(name, args, 1)?; + let arg = &args[0]; + let result = eval(env, arg)?; + log::debug!("{arg} = {result}"); + Ok(result) +} diff --git a/hipcheck/src/policy_exprs/error.rs b/hipcheck/src/policy_exprs/error.rs new file mode 100644 index 000000000..586fb3b7c --- /dev/null +++ b/hipcheck/src/policy_exprs/error.rs @@ -0,0 +1,97 @@ +use crate::policy_exprs::{Expr, Ident, LexingError}; +use nom::{error::ErrorKind, Needed}; +use ordered_float::FloatIsNan; + +/// `Result` which uses [`Error`]. +pub type Result = std::result::Result; + +/// An error arising during program execution. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("missing close paren")] + MissingOpenParen, + + #[error("missing open paren")] + MissingCloseParen, + + #[error("missing ident")] + MissingIdent, + + #[error("wrong type in ident spot")] + WrongTypeInIdentSpot, + + #[error("missing args")] + MissingArgs, + + #[error(transparent)] + Lex(#[from] LexingError), + + #[error("expression returned '{0:?}', not a boolean")] + DidNotReturnBool(Expr), + + #[error("tried to call unknown function '{0}'")] + UnknownFunction(String), + + #[error("ident '{0}' resolved to a variable, not a function")] + FoundVarExpectedFunc(String), + + #[error("parsing did not consume the entire input {}", needed_str(.0))] + IncompleteParse(Needed), + + #[error("parse failed with kind '{kind:?}', with '{remaining}' remaining")] + Parse { remaining: String, kind: ErrorKind }, + + #[error(transparent)] + FloatIsNan(#[from] FloatIsNan), + + #[error("too many args to '{name}'; expected {expected}, got {given}")] + TooManyArgs { + name: String, + expected: usize, + given: usize, + }, + + #[error("not enough args to '{name}'; expected {expected}, got {given}")] + NotEnoughArgs { + name: String, + expected: usize, + given: usize, + }, + + #[error("called '{0}' with mismatched types")] + BadType(&'static str), + + #[error("no max value found in array")] + NoMax, + + #[error("no min value found in array")] + NoMin, + + #[error("no avg value found for array")] + NoAvg, + + #[error("no median value found for array")] + NoMedian, + + #[error("array mixing multiple primitive types")] + InconsistentArrayTypes, + + #[error("variable '{0}' is not bound")] + UnboundVar(Ident), + + #[error("variable '{0}' conflicts with function")] + VarConflictsWithFunc(Ident), + + #[error("variable '{checked}' resolves to another variable '{found}'")] + VarResolvesToVar { checked: Ident, found: Ident }, + + #[error("variable is already bound")] + AlreadyBound, +} + +fn needed_str(needed: &Needed) -> String { + match needed { + Needed::Unknown => String::from(""), + Needed::Size(bytes) => format!(", needed {} more bytes", bytes), + } +} diff --git a/hipcheck/src/policy_exprs/expr.rs b/hipcheck/src/policy_exprs/expr.rs new file mode 100644 index 000000000..a04c3d982 --- /dev/null +++ b/hipcheck/src/policy_exprs/expr.rs @@ -0,0 +1,280 @@ +use crate::policy_exprs::env::Binding; +use crate::policy_exprs::env::Env; +use crate::policy_exprs::token::Token; +use crate::policy_exprs::Error; +use crate::policy_exprs::Result; +use crate::policy_exprs::Tokens; +use itertools::Itertools; +use nom::branch::alt; +use nom::combinator::all_consuming; +use nom::combinator::map; +use nom::multi::many0; +use nom::sequence::tuple; +use nom::Finish as _; +use nom::IResult; +use ordered_float::NotNan; +use std::fmt::Display; +use std::ops::Deref; + +/// A `deke` expression to evaluate. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr { + /// Primitive data (ints, floats, bool). + Primitive(Primitive), + + /// An array of primitive data. + Array(Vec), + + /// Stores the name of the function, followed by the args. + Function(Ident, Vec), + + /// Stores the name of the input variable, followed by the lambda body. + Lambda(Ident, Box), +} + +/// Primitive data. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Primitive { + /// Identifier in a lambda, to be substituted. + Identifier(Ident), + + /// Signed 64-bit integer. + Int(i64), + + /// 64-bit float, not allowed to be NaN. + Float(F64), + + /// Boolean. + Bool(bool), +} + +/// A variable or function identifier. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ident(pub String); + +/// A non-NaN 64-bit floating point number. +pub type F64 = NotNan; + +impl Display for Expr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Expr::Primitive(primitive) => write!(f, "{}", primitive), + Expr::Array(array) => { + write!(f, "[{}]", array.iter().map(ToString::to_string).join(" ")) + } + Expr::Function(ident, args) => { + let args = args.iter().map(ToString::to_string).join(" "); + write!(f, "({} {})", ident, args) + } + Expr::Lambda(arg, body) => write!(f, "(lambda ({}) {}", arg, body), + } + } +} + +impl Display for Primitive { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Primitive::Identifier(ident) => write!(f, "{}", ident), + Primitive::Int(i) => write!(f, "{}", i), + Primitive::Float(fl) => write!(f, "{}", fl), + Primitive::Bool(b) => write!(f, "{}", if *b { "#t" } else { "#f" }), + } + } +} + +impl Primitive { + pub fn resolve(&self, env: &Env) -> Result { + match self { + Primitive::Identifier(ident) => match env.get(ident) { + Some(Binding::Var(Primitive::Identifier(found))) => Err(Error::VarResolvesToVar { + checked: ident.clone(), + found, + }), + Some(Binding::Var(var)) => Ok(var), + Some(Binding::Fn(_)) => Err(Error::VarConflictsWithFunc(ident.clone())), + None => Err(Error::UnboundVar(ident.clone())), + }, + _ => Ok(self.clone()), + } + } +} + +impl Display for Ident { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Deref for Ident { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +crate::token_parser!(token: Token); + +crate::data_variant_parser! { + fn parse_integer(input) -> Result; + pattern = Token::Integer(n) => Primitive::Int(n); +} + +crate::data_variant_parser! { + fn parse_float(input) -> Result; + pattern = Token::Float(f) => Primitive::Float(f); +} + +crate::data_variant_parser! { + fn parse_bool(input) -> Result; + pattern = Token::Bool(b) => Primitive::Bool(b); +} + +crate::data_variant_parser! { + fn parse_ident(input) -> Result; + pattern = Token::Ident(s) => s.to_owned(); +} + +// Helper type for token parsing. +pub type Input<'source> = Tokens<'source, Token>; + +/// Parse a single piece of primitive data. +fn parse_primitive(input: Input<'_>) -> IResult, Primitive> { + alt((parse_integer, parse_float, parse_bool))(input) +} + +/// Parse an array. +fn parse_array(input: Input<'_>) -> IResult, Expr> { + let parser = tuple((Token::OpenBrace, many0(parse_primitive), Token::CloseBrace)); + let mut parser = map(parser, |(_, inner, _)| Expr::Array(inner)); + parser(input) +} + +/// Parse an expression. +fn parse_expr(input: Input<'_>) -> IResult, Expr> { + let primitive = map(parse_primitive, Expr::Primitive); + alt((primitive, parse_array, parse_function))(input) +} + +/// Parse a function call. +fn parse_function(input: Input<'_>) -> IResult, Expr> { + let parser = tuple(( + Token::OpenParen, + parse_ident, + many0(parse_expr), + Token::CloseParen, + )); + let mut parser = map(parser, |(_, ident, args, _)| { + Expr::Function(Ident(ident), args) + }); + parser(input) +} + +pub fn parse(input: &str) -> Result { + let tokens = Tokens::new(input); + let mut parser = all_consuming(parse_function); + + match parser(tokens).finish() { + Ok((rest, expr)) if rest.is_empty() => Ok(expr), + Ok(_) => Err(Error::IncompleteParse(nom::Needed::Unknown)), + Err(err) => { + let remaining = err.input.lexer().slice().to_string(); + let kind = err.code; + Err(Error::Parse { remaining, kind }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + trait IntoExpr { + fn into_expr(self) -> Expr; + } + + impl IntoExpr for Expr { + fn into_expr(self) -> Expr { + self + } + } + + impl IntoExpr for Primitive { + fn into_expr(self) -> Expr { + Expr::Primitive(self) + } + } + + fn func(name: &str, args: Vec) -> Expr { + let args = args.into_iter().map(|arg| arg.into_expr()).collect(); + Expr::Function(Ident(String::from(name)), args) + } + + fn int(val: i64) -> Primitive { + Primitive::Int(val) + } + + fn float(val: f64) -> Primitive { + Primitive::Float(F64::new(val).unwrap()) + } + + #[allow(unused)] + fn boolean(val: bool) -> Primitive { + Primitive::Bool(val) + } + + fn array(vals: Vec) -> Expr { + Expr::Array(vals) + } + + #[test] + fn parse_function() { + let input = "(add 2 3)"; + let expected = func("add", vec![int(2), int(3)]); + let result = parse(input).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn parse_nested_function() { + let input = "(add (add 1 2) 3)"; + let expected = func( + "add", + vec![func("add", vec![int(1), int(2)]), int(3).into_expr()], + ); + let result = parse(input).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn parse_array() { + let input = "(eq 0 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))"; + + let expected = func( + "eq", + vec![ + int(0).into_expr(), + func( + "count", + vec![func( + "filter", + vec![ + func("gt", vec![float(8.0)]), + array(vec![ + float(1.0), + float(2.0), + float(10.0), + float(20.0), + float(30.0), + ]), + ], + )], + ), + ], + ); + + let result = parse(input).unwrap(); + assert_eq!(result, expected); + } +} diff --git a/hipcheck/src/policy_exprs/mod.rs b/hipcheck/src/policy_exprs/mod.rs new file mode 100644 index 000000000..57c6b91a5 --- /dev/null +++ b/hipcheck/src/policy_exprs/mod.rs @@ -0,0 +1,141 @@ +#![allow(unused)] + +mod bridge; +mod env; +mod error; +mod expr; +mod token; + +pub(crate) use crate::policy_exprs::bridge::Tokens; +use crate::policy_exprs::env::Env; +pub use crate::policy_exprs::error::Error; +pub use crate::policy_exprs::error::Result; +pub use crate::policy_exprs::expr::Expr; +pub use crate::policy_exprs::expr::Ident; +pub(crate) use crate::policy_exprs::expr::F64; +pub use crate::policy_exprs::token::LexingError; +use env::Binding; +use expr::parse; +pub use expr::Primitive; +use std::ops::Deref; + +/// Evaluates `deke` expressions. +pub struct Executor { + env: Env<'static>, +} + +impl Executor { + /// Create an `Executor` with the standard set of functions defined. + pub fn std() -> Self { + Executor { env: Env::std() } + } + + /// Run a `deke` program. + pub fn run(&self, raw_program: &str) -> Result { + match self.parse_and_eval(raw_program)? { + Expr::Primitive(Primitive::Bool(b)) => Ok(b), + result => Err(Error::DidNotReturnBool(result)), + } + } + + /// Run a `deke` program, but don't try to convert the result to a `bool`. + pub fn parse_and_eval(&self, raw_program: &str) -> Result { + let program = parse(raw_program)?; + let expr = eval(&self.env, &program)?; + Ok(expr) + } +} + +/// Evaluate the `Expr`, returning a boolean. +pub(crate) fn eval(env: &Env, program: &Expr) -> Result { + let output = match program { + Expr::Primitive(primitive) => Ok(Expr::Primitive(primitive.resolve(env)?)), + Expr::Array(_) => Ok(program.clone()), + Expr::Function(name, args) => { + let binding = env + .get(name) + .ok_or_else(|| Error::UnknownFunction(name.deref().to_owned()))?; + + if let Binding::Fn(op) = binding { + op(env, args) + } else { + Err(Error::FoundVarExpectedFunc(name.deref().to_owned())) + } + } + Expr::Lambda(_, body) => Ok((**body).clone()), + }; + + log::debug!("input: {program:?}, output: {output:?}"); + + output +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test] + fn run_basic() { + let program = "(eq (add 1 2) 3)"; + let is_true = Executor::std().run(program).unwrap(); + assert!(is_true); + } + + #[test] + fn eval_basic() { + let program = "(add 1 2)"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!(result, Expr::Primitive(Primitive::Int(3))); + } + + #[test] + fn eval_bools() { + let program = "(neq 1 2)"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + } + + #[test] + fn eval_array() { + let program = "(max [1 4 6 10 2 3 0])"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!(result, Expr::Primitive(Primitive::Int(10))); + } + + #[test] + fn run_array() { + let program = "(eq 7 (count [1 4 6 10 2 3 0]))"; + let is_true = Executor::std().run(program).unwrap(); + assert!(is_true); + } + + #[test] + fn eval_higher_order_func() { + let program = "(eq 3 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + } + + #[test] + fn eval_foreach() { + let program = + "(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + } + + #[test] + fn eval_basic_filter() { + let program = "(filter (eq 0) [1 0 1 0 0 1 2])"; + let result = Executor::std().parse_and_eval(program).unwrap(); + assert_eq!( + result, + Expr::Array(vec![ + Primitive::Int(0), + Primitive::Int(0), + Primitive::Int(0) + ]) + ); + } +} diff --git a/hipcheck/src/policy_exprs/token.rs b/hipcheck/src/policy_exprs/token.rs new file mode 100644 index 000000000..0d8b60caf --- /dev/null +++ b/hipcheck/src/policy_exprs/token.rs @@ -0,0 +1,164 @@ +use crate::policy_exprs::F64; +use logos::Lexer; +use logos::Logos; +use ordered_float::FloatIsNan; +use std::fmt::Display; +use std::num::ParseFloatError; +use std::num::ParseIntError; + +type Result = std::result::Result; + +#[derive(Logos, Clone, Debug, PartialEq)] +#[logos(skip r"[ \t\n\f]+", error = LexingError)] +pub enum Token { + #[token("(")] + OpenParen, + + #[token(")")] + CloseParen, + + #[token("[")] + OpenBrace, + + #[token("]")] + CloseBrace, + + #[regex(r"\#[tf]", lex_bool)] + Bool(bool), + + #[regex(r"-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?", lex_float)] + Float(F64), + + #[regex(r"([1-9]?[0-9]*)", lex_integer, priority = 20)] + Integer(i64), + + #[regex("([a-zA-Z]+)", lex_ident)] + Ident(String), +} + +/// Lex a single boolean. +fn lex_bool(input: &mut Lexer<'_, Token>) -> Result { + match input.slice() { + "#t" => Ok(true), + "#f" => Ok(false), + value => Err(LexingError::InvalidBool(String::from(value))), + } +} + +/// Lex a single integer. +fn lex_integer(input: &mut Lexer<'_, Token>) -> Result { + let s = input.slice(); + let i = s + .parse::() + .map_err(|err| LexingError::InvalidInteger(s.to_string(), err))?; + Ok(i) +} + +/// Lex a single float. +fn lex_float(input: &mut Lexer<'_, Token>) -> Result { + let s = input.slice(); + let f = s + .parse::() + .map_err(|err| LexingError::InvalidFloat(s.to_string(), err))?; + Ok(F64::new(f)?) +} + +/// Lex a single identifier. +fn lex_ident(input: &mut Lexer<'_, Token>) -> Result { + Ok(input.slice().to_owned()) +} + +impl Display for Token { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::OpenParen => write!(f, "("), + Token::CloseParen => write!(f, ")"), + Token::OpenBrace => write!(f, "["), + Token::CloseBrace => write!(f, "]"), + Token::Bool(true) => write!(f, "#t"), + Token::Bool(false) => write!(f, "#f"), + Token::Integer(i) => write!(f, "{i}"), + Token::Float(fl) => write!(f, "{fl}"), + Token::Ident(i) => write!(f, "{i}"), + } + } +} + +/// Error arising during lexing. +#[derive(Default, Debug, Clone, PartialEq, thiserror::Error)] +pub enum LexingError { + #[error("an unknown lexing error occured")] + #[default] + UnknownError, + + #[error("failed to parse integer")] + InvalidInteger(String, ParseIntError), + + #[error("failed to parse float")] + InvalidFloat(String, ParseFloatError), + + #[error("float is not a number")] + FloatIsNan(#[from] FloatIsNan), + + #[error("invalid boolean, found '{0}'")] + InvalidBool(String), +} + +#[cfg(test)] +mod tests { + use crate::policy_exprs::token::Token; + use crate::policy_exprs::Result; + use crate::policy_exprs::F64; + use logos::Logos as _; + use test_log::test; + + // Helper function for running the lexer to get all tokens. + fn lex(input: &str) -> Result> { + let tokens = Token::lexer(input) + .map(|res| res.map_err(Into::into)) + .collect::>>()?; + Ok(tokens) + } + + #[test] + fn basic_lexing() { + let raw_program = "(add 1 2)"; + let expected = vec![ + Token::OpenParen, + Token::Ident(String::from("add")), + Token::Integer(1), + Token::Integer(2), + Token::CloseParen, + ]; + let tokens = lex(raw_program).unwrap(); + assert_eq!(tokens, expected); + } + + #[test] + fn basic_lexing_with_floats() { + let raw_program = "(add 1.0 2.0)"; + let expected = vec![ + Token::OpenParen, + Token::Ident(String::from("add")), + Token::Float(F64::new(1.0).unwrap()), + Token::Float(F64::new(2.0).unwrap()), + Token::CloseParen, + ]; + let tokens = lex(raw_program).unwrap(); + assert_eq!(tokens, expected); + } + + #[test] + fn basic_lexing_with_bools() { + let raw_program = "(eq #t #f)"; + let expected = vec![ + Token::OpenParen, + Token::Ident(String::from("eq")), + Token::Bool(true), + Token::Bool(false), + Token::CloseParen, + ]; + let tokens = lex(raw_program).unwrap(); + assert_eq!(tokens, expected); + } +}