diff --git a/schema_salad/codegen.py b/schema_salad/codegen.py index f81dc5ca..dd005830 100644 --- a/schema_salad/codegen.py +++ b/schema_salad/codegen.py @@ -15,6 +15,7 @@ from .java_codegen import JavaCodeGen from .python_codegen import PythonCodeGen from .ref_resolver import Loader +from .rust_codegen import RustCodeGen from .schema import shortname from .typescript_codegen import TypeScriptCodeGen from .utils import aslist @@ -99,6 +100,10 @@ def codegen( gen = TypeScriptCodeGen(base, target=target, package=pkg, examples=examples) elif lang == "dotnet": gen = DotNetCodeGen(base, target=target, package=pkg, examples=examples) + elif lang == "rust": + gen = RustCodeGen(base_uri=base, package=pkg, salad_version=salad_version, target=target) + gen.parse(j) + return else: raise SchemaSaladException(f"Unsupported code generation language {lang!r}") diff --git a/schema_salad/rust/.gitignore b/schema_salad/rust/.gitignore new file mode 100644 index 00000000..2c96eb1b --- /dev/null +++ b/schema_salad/rust/.gitignore @@ -0,0 +1,2 @@ +target/ +Cargo.lock diff --git a/schema_salad/rust/Cargo.toml b/schema_salad/rust/Cargo.toml new file mode 100644 index 00000000..98a79e83 --- /dev/null +++ b/schema_salad/rust/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "{package_name}" +version = "{package_version}" +publish = false +edition = "2021" + +[features] +# TBD ... + +[dependencies] +# This is supposed to be imported from `crates.io` +salad-core = { path = "salad-core" } + diff --git a/schema_salad/rust/salad-core/.gitignore b/schema_salad/rust/salad-core/.gitignore new file mode 100644 index 00000000..4214654e --- /dev/null +++ b/schema_salad/rust/salad-core/.gitignore @@ -0,0 +1,14 @@ +.idea/ +*.iml +.vscode/ +.history/ +.zed/ +.DS_Store +[Dd]esktop.ini +*~ + +debug/ +target/ +Cargo.lock +**/*.rs.bk +*.pdb diff --git a/schema_salad/rust/salad-core/Cargo.toml b/schema_salad/rust/salad-core/Cargo.toml new file mode 100644 index 00000000..b9bfee9e --- /dev/null +++ b/schema_salad/rust/salad-core/Cargo.toml @@ -0,0 +1,35 @@ +[workspace] +members = ["crates/*"] +resolver = "2" + +[workspace.package] +license = "Apache-2.0" +authors = ["Giuseppe Eletto "] +edition = "2021" + +[workspace.dependencies] +compact_str = { version = "0.9" } +fxhash = { version = "0.2" } +proc-macro2 = { version = "1.0" } +quote = { version = "1.0" } +salad-serde = { path = "crates/serde" } +salad-types = { path = "crates/types" } +serde = { version = "1.0" } +serde_yaml_ng = { version = "0.10" } +syn = { version = "2.0" } + +[package] +name = "salad-core" +version = "0.1.0" +description = "Core block for Schema Salad generated parsers." +license.workspace = true +authors.workspace = true +edition.workspace = true + +[lib] +name = "salad_core" +path = "src/lib.rs" + +[dependencies] +salad-serde.workspace = true +salad-types.workspace = true diff --git a/schema_salad/rust/salad-core/LICENSE b/schema_salad/rust/salad-core/LICENSE new file mode 100644 index 00000000..c98d27d4 --- /dev/null +++ b/schema_salad/rust/salad-core/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/LICENSE-2.0 + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/schema_salad/rust/salad-core/crates/serde/Cargo.toml b/schema_salad/rust/salad-core/crates/serde/Cargo.toml new file mode 100644 index 00000000..fdedac0b --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "salad-serde" +version = "0.1.0" +license.workspace = true +authors.workspace = true +edition.workspace = true +publish = false + +[dependencies] +salad-types.workspace = true +serde.workspace = true + +[dev-dependencies] +serde_yaml_ng.workspace = true diff --git a/schema_salad/rust/salad-core/crates/serde/LICENSE b/schema_salad/rust/salad-core/crates/serde/LICENSE new file mode 120000 index 00000000..30cff740 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/schema_salad/rust/salad-core/crates/serde/src/de/data.rs b/schema_salad/rust/salad-core/crates/serde/src/de/data.rs new file mode 100644 index 00000000..14fb05f6 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/src/de/data.rs @@ -0,0 +1,2 @@ +// WIP +pub struct SeedData; diff --git a/schema_salad/rust/salad-core/crates/serde/src/de/list.rs b/schema_salad/rust/salad-core/crates/serde/src/de/list.rs new file mode 100644 index 00000000..42ba11a0 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/src/de/list.rs @@ -0,0 +1,258 @@ +use std::{fmt, marker::PhantomData}; + +use salad_types::SaladType; +use serde::de; + +use super::{IntoDeserializeSeed, SeedData}; + +/// A list helper deserializer for handling both a single value or a list of values +/// of type `T`. +/// +/// This is useful for configurations where a field might accept either a single value +/// or a list of values with the same semantics. +pub struct SingleOrManySeed<'sd, T> { + pub(super) data: &'sd SeedData, + pub(super) _phant: PhantomData, +} + +impl<'de, 'sd, T> de::DeserializeSeed<'de> for SingleOrManySeed<'sd, T> +where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, +{ + type Value = Box<[T]>; + + fn deserialize(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct SingleOrManySeedVisitor<'sd, T> { + data: &'sd SeedData, + _phant: PhantomData, + } + + impl<'de, 'sd, T> SingleOrManySeedVisitor<'sd, T> + where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, + { + // Private helper method to reduce duplication + #[inline] + fn visit_single_value(&self, value: V) -> Result, E> + where + E: de::Error, + V: de::IntoDeserializer<'de, E>, + { + let deserializer = de::IntoDeserializer::into_deserializer(value); + de::DeserializeSeed::deserialize(T::deserialize_seed(self.data), deserializer) + .map(|t| Box::from([t])) + } + } + + impl<'de, 'sd, T> de::Visitor<'de> for SingleOrManySeedVisitor<'sd, T> + where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, + { + type Value = Box<[T]>; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("one or a list of values") + } + + fn visit_bool(self, v: bool) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_i8(self, v: i8) -> Result + where + E: de::Error, + { + self.visit_i32(v as i32) + } + + fn visit_i16(self, v: i16) -> Result + where + E: de::Error, + { + self.visit_i32(v as i32) + } + + fn visit_i32(self, v: i32) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_i64(self, v: i64) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_u8(self, v: u8) -> Result + where + E: de::Error, + { + self.visit_i32(v as i32) + } + + fn visit_u16(self, v: u16) -> Result + where + E: de::Error, + { + self.visit_i32(v as i32) + } + + fn visit_u64(self, v: u64) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_f32(self, v: f32) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_f64(self, v: f64) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_string(self, v: String) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: de::Error, + { + self.visit_single_value(v) + } + + fn visit_map(self, map: A) -> Result + where + A: de::MapAccess<'de>, + { + let deserializer = de::value::MapAccessDeserializer::new(map); + de::DeserializeSeed::deserialize(T::deserialize_seed(self.data), deserializer) + .map(|t| Box::from([t])) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let capacity = seq.size_hint().unwrap_or(8); + let mut entries = Vec::with_capacity(capacity); + + while let Some(entry) = seq.next_element_seed(T::deserialize_seed(self.data))? { + entries.push(entry); + } + + Ok(entries.into_boxed_slice()) + } + } + + deserializer.deserialize_any(SingleOrManySeedVisitor { + data: self.data, + _phant: PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use salad_types::SaladAny; + use serde::__private::de::{Content, ContentDeserializer}; + + use super::*; + + #[test] + fn single_object_entry() { + let input = r#" + type: object + key: value + "#; + + let deserializer: ContentDeserializer<'_, serde_yaml_ng::Error> = { + let content: Content<'static> = serde_yaml_ng::from_str::(input).unwrap(); + ContentDeserializer::new(content) + }; + + let object = de::DeserializeSeed::deserialize( + >::deserialize_seed(&SeedData), + deserializer, + ); + + assert!(object.is_ok_and(|r| matches!(r[0], SaladAny::Object(_)))) + } + + #[test] + fn single_primitive_entry() { + let input = r#"Hello, World!"#; + + let deserializer: ContentDeserializer<'_, serde_yaml_ng::Error> = { + let content: Content<'static> = serde_yaml_ng::from_str::(input).unwrap(); + ContentDeserializer::new(content) + }; + + let string = de::DeserializeSeed::deserialize( + >::deserialize_seed(&SeedData), + deserializer, + ); + + assert!(string.is_ok_and(|r| matches!(r[0], SaladAny::String(_)))) + } + + #[test] + fn multiple_entries() { + let input = r#" + - 1 + - 2.0 + - true + - Hello, World! + - type: object + key: value + "#; + + let deserializer: ContentDeserializer<'_, serde_yaml_ng::Error> = { + let content: Content<'static> = serde_yaml_ng::from_str::(input).unwrap(); + ContentDeserializer::new(content) + }; + + let string = de::DeserializeSeed::deserialize( + >::deserialize_seed(&SeedData), + deserializer, + ); + + assert!(string.is_ok_and( + |r| matches!(r[1], SaladAny::Float(_)) && matches!(r[3], SaladAny::String(_)) + )) + } +} diff --git a/schema_salad/rust/salad-core/crates/serde/src/de/map.rs b/schema_salad/rust/salad-core/crates/serde/src/de/map.rs new file mode 100644 index 00000000..aefd73ca --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/src/de/map.rs @@ -0,0 +1,263 @@ +use std::{fmt, marker::PhantomData}; + +use salad_types::SaladType; +use serde::de::{self, DeserializeSeed}; + +use super::{IntoDeserializeSeed, SeedData}; + +/// A list helper deserializer, which allows flexible deserialization of data +/// represented either as maps or sequences of objects. +/// +/// Particularly useful when dealing with configurations or data formats +/// that might represent the same logical structure in different ways. +/// For example, in YAML: +/// +/// ```yaml +/// # Format 1: Sequence of maps with explicit keys +/// entries: +/// - key1: value1 +/// key2: value2 +/// +/// # Format 2: Nested map structure with the first value acting as a key +/// entries: +/// value1: +/// key2: value2 +/// +/// # Format 3: Map structure with key-predicate pairs +/// # Where: +/// # - The map key becomes the value for the specified `key` field +/// # - The map value becomes the value for the specified `predicate` field +/// entries: +/// value1: value2 +/// ``` +pub struct MapToListSeed<'sd, T> { + key: &'static str, + data: &'sd SeedData, + pred: Option<&'static str>, + _phant: PhantomData, +} + +impl<'de, 'sd, T> MapToListSeed<'sd, T> +where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, +{ + /// Creates a new [`MapDeserializeSeed`] with the specified key and seed data. + /// + /// # Arguments + /// + /// * `key` - The field name that will be used as the key in the deserialized structure + /// * `data` - Additional seed data needed for deserialization + /// + /// # Examples + /// + /// ```no_run + /// # use serde::de; + /// # use crate::de::{MapDeserializeSeed, SeedData}; + /// # let data = SeedData; + /// // For mapping without a predicate + /// let seed = MapDeserializeSeed::new("class", &data); + /// ``` + pub fn new(key: &'static str, data: &'sd SeedData) -> Self { + Self { + key, + data, + pred: None, + _phant: PhantomData, + } + } + + /// Creates a new [`MapDeserializeSeed`] with the specified key, predicate, and seed data. + /// + /// # Arguments + /// + /// * `key` - The field name that will be used as the key in the deserialized structure + /// * `pred` - Predicate field name that enables the simpler key-value mapping format + /// * `data` - Additional seed data needed for deserialization + /// + /// # Examples + /// + /// ```no_run + /// # use serde::de; + /// # use crate::de::{MapDeserializeSeed, SeedData}; + /// # let data = SeedData; + /// // For mapping with a predicate + /// let seed = MapDeserializeSeed::with_predicate("class", "key", &data); + /// ``` + pub fn with_predicate(key: &'static str, pred: &'static str, data: &'sd SeedData) -> Self { + Self { + key, + data, + pred: Some(pred), + _phant: PhantomData, + } + } +} + +impl<'de, 'sd, T> de::DeserializeSeed<'de> for MapToListSeed<'sd, T> +where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, +{ + type Value = Box<[T]>; + + fn deserialize(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct MapVisitor<'sd, T> { + key: &'static str, + pred: Option<&'static str>, + data: &'sd SeedData, + _phant: PhantomData, + } + + impl<'de, 'sd, T> de::Visitor<'de> for MapVisitor<'sd, T> + where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, + { + type Value = Box<[T]>; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("a map or sequence of objects") + } + + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + use serde::__private::de::{Content, ContentDeserializer}; + + let capacity = map.size_hint().unwrap_or(1); + let mut entries = Vec::with_capacity(capacity); + + while let Some((key, value)) = map.next_entry::, Content<'de>>()? { + let value_map = match (value, self.pred) { + (Content::Map(mut value_map), _) => { + // Format 2: Add the key field to the existing map + let key_field = Content::Str(self.key); + value_map.reserve_exact(1); + value_map.push((key_field, key)); + value_map + } + (value, Some(pred)) => { + // Format 3: Build a map from key-value pair + let key_field = Content::Str(self.key); + let predicate_field = Content::Str(pred); + vec![(key_field, key), (predicate_field, value)] + } + (_, None) => { + return Err(de::Error::custom(format!( + "field `{}` requires a map or predicate value", + self.key + ))); + } + }; + + // Deserialize the created map into the target type + let deserializer = ContentDeserializer::new(Content::Map(value_map)); + let entry = T::deserialize_seed(self.data).deserialize(deserializer)?; + entries.push(entry); + } + + Ok(entries.into_boxed_slice()) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + // Format 1: Sequence of objects + let capacity = seq.size_hint().unwrap_or(0); + let mut entries = Vec::with_capacity(capacity); + + while let Some(entry) = seq.next_element_seed(T::deserialize_seed(self.data))? { + entries.push(entry); + } + + Ok(entries.into_boxed_slice()) + } + } + + deserializer.deserialize_any(MapVisitor { + key: self.key, + pred: self.pred, + data: self.data, + _phant: PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use salad_types::{SaladAny, SaladObject}; + use serde::__private::de::{Content, ContentDeserializer}; + + use super::*; + + fn setup_test_deserializer<'s>( + input: &'s str, + ) -> ContentDeserializer<'s, serde_yaml_ng::Error> { + let content: Content<'s> = serde_yaml_ng::from_str::(input).unwrap(); + ContentDeserializer::new(content) + } + + #[test] + fn list_entries() { + let input = r#" + - class: class_1 + key: value_1 + - class: class_2 + key: value_2 + - class: class_3 + key: value_3 + "#; + + let deserializer = setup_test_deserializer(input); + let to_match = SaladAny::String("value_2".into()); + let object_list = de::DeserializeSeed::deserialize( + MapToListSeed::<'_, SaladObject>::new("class", &SeedData), + deserializer, + ); + + assert!(object_list.is_ok_and(|r| matches!(r[1].get("key"), Some(s) if s == &to_match))) + } + + #[test] + fn map_entries() { + let input = r#" + class_1: + key: value_1 + class_2: + key: value_2 + class_3: + key: value_3 + "#; + + let deserializer = setup_test_deserializer(input); + let to_match = SaladAny::String("value_2".into()); + let object_list = de::DeserializeSeed::deserialize( + MapToListSeed::<'_, SaladObject>::new("class", &SeedData), + deserializer, + ); + + assert!(object_list.is_ok_and(|r| matches!(r[1].get("key"), Some(s) if s == &to_match))) + } + + #[test] + fn map_entries_with_predicate() { + let input = r#" + class_1: value_1 + class_2: value_2 + class_3: + key: value_3 + "#; + + let deserializer = setup_test_deserializer(input); + let to_match = SaladAny::String("value_2".into()); + let object_list = de::DeserializeSeed::deserialize( + MapToListSeed::<'_, SaladObject>::with_predicate("class", "key", &SeedData), + deserializer, + ); + + assert!(object_list.is_ok_and(|r| matches!(r[1].get("key"), Some(s) if s == &to_match))) + } +} diff --git a/schema_salad/rust/salad-core/crates/serde/src/de/mod.rs b/schema_salad/rust/salad-core/crates/serde/src/de/mod.rs new file mode 100644 index 00000000..3936377c --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/src/de/mod.rs @@ -0,0 +1,73 @@ +use std::marker::PhantomData; +use serde::de; + +use salad_types::SaladType; + +mod data; +mod list; +mod map; + +use self::list::SingleOrManySeed; +pub use self::{data::SeedData, map::MapToListSeed}; + +/// Represents a type that can be converted into a serde +/// [`DeserializeSeed`](serde::de::DeserializeSeed). +pub trait IntoDeserializeSeed<'de, 'sd> { + type DeserializeSeed: de::DeserializeSeed<'de, Value = Self>; + + /// Returns a + /// [`DeserializeSeed`](https://docs.rs/serde/latest/serde/de/trait.DeserializeSeed.html) + /// instance from a [`SeedData`] reference that's able to deserialize this type. + fn deserialize_seed(data: &'sd SeedData) -> Self::DeserializeSeed; +} + +impl<'de, 'sd, T> IntoDeserializeSeed<'de, 'sd> for Box<[T]> +where + T: SaladType + IntoDeserializeSeed<'de, 'sd>, +{ + type DeserializeSeed = SingleOrManySeed<'sd, T>; + + #[inline] + fn deserialize_seed(data: &'sd SeedData) -> Self::DeserializeSeed { + SingleOrManySeed { + data, + _phant: PhantomData, + } + } +} + +macro_rules! impl_default_intoseed { + ( $( $ty:path ),* $(,)? ) => { + $( + impl<'sd> IntoDeserializeSeed<'_, 'sd> for $ty { + type DeserializeSeed = std::marker::PhantomData; + + #[inline] + fn deserialize_seed(_: &'sd SeedData) -> Self::DeserializeSeed { + std::marker::PhantomData + } + } + )* + }; +} + +impl_default_intoseed! { + // Any & Object + salad_types::SaladAny, + salad_types::SaladObject, + + // Primitives + salad_types::SaladBool, + salad_types::SaladInt, + salad_types::SaladLong, + salad_types::SaladFloat, + salad_types::SaladDouble, + salad_types::SaladString, + salad_types::SaladPrimitive, + + // Common + salad_types::common::ArrayName, + salad_types::common::EnumName, + salad_types::common::RecordName, + salad_types::common::PrimitiveType, +} diff --git a/schema_salad/rust/salad-core/crates/serde/src/lib.rs b/schema_salad/rust/salad-core/crates/serde/src/lib.rs new file mode 100644 index 00000000..7bbc60fa --- /dev/null +++ b/schema_salad/rust/salad-core/crates/serde/src/lib.rs @@ -0,0 +1 @@ +pub mod de; diff --git a/schema_salad/rust/salad-core/crates/types/Cargo.toml b/schema_salad/rust/salad-core/crates/types/Cargo.toml new file mode 100644 index 00000000..08eb12df --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "salad-types" +version = "0.1.0" +license.workspace = true +authors.workspace = true +edition.workspace = true +publish = false + +[dependencies] +compact_str = { workspace = true, features = ["serde"] } +fxhash = { workspace = true } +serde = { workspace = true } + +[dev-dependencies] +serde_yaml_ng = { workspace = true } diff --git a/schema_salad/rust/salad-core/crates/types/LICENSE b/schema_salad/rust/salad-core/crates/types/LICENSE new file mode 120000 index 00000000..30cff740 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/schema_salad/rust/salad-core/crates/types/src/any/de.rs b/schema_salad/rust/salad-core/crates/types/src/any/de.rs new file mode 100644 index 00000000..9e848b14 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/any/de.rs @@ -0,0 +1,356 @@ +use std::{collections::hash_map, slice}; + +use compact_str::CompactString; +use serde::de; + +use super::{SaladAny, SaladObject}; + +/// Deserializer for converting SaladAny values into SaladTypes +pub(super) struct SaladAnyDeserializer<'de>(pub &'de SaladAny); + +impl<'de> de::Deserializer<'de> for SaladAnyDeserializer<'de> { + type Error = de::value::Error; + + fn deserialize_any>(self, visitor: V) -> Result { + match self.0 { + SaladAny::Bool(b) => visitor.visit_bool(*b), + SaladAny::Int(i) => visitor.visit_i32(*i), + SaladAny::Long(l) => { + if super::INT_RANGE.contains(l) { + visitor.visit_i32(*l as i32) + } else { + visitor.visit_i64(*l) + } + } + SaladAny::Float(f) => visitor.visit_f32(*f), + SaladAny::Double(d) => { + if super::FLOAT_RANGE.contains(d) { + visitor.visit_f32(*d as f32) + } else { + visitor.visit_f64(*d) + } + } + SaladAny::String(s) => visitor.visit_str(s), + SaladAny::Object(o) => visitor.visit_map(SaladObjectMapAccess::new(o)), + SaladAny::List(l) => visitor.visit_seq(SaladAnyListSeqAccess::new(l)), + } + } + + fn deserialize_bool>(self, visitor: V) -> Result { + static ERR_MSG: &str = "boolean"; + + match self.0 { + SaladAny::Bool(b) => visitor.visit_bool(*b), + SaladAny::Int(1) | SaladAny::Long(1) => visitor.visit_bool(true), + SaladAny::Int(0) | SaladAny::Long(0) => visitor.visit_bool(false), + + // Errors + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_i32>(self, visitor: V) -> Result { + static ERR_MSG: &str = "signed integer"; + + match self.0 { + SaladAny::Int(i) => visitor.visit_i32(*i), + SaladAny::Long(l) if super::INT_RANGE.contains(l) => visitor.visit_i32(*l as i32), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_i64>(self, visitor: V) -> Result { + static ERR_MSG: &str = "signed long integer"; + + match self.0 { + SaladAny::Long(l) => visitor.visit_i64(*l), + SaladAny::Int(i) => visitor.visit_i64(*i as i64), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_f32>(self, visitor: V) -> Result { + static ERR_MSG: &str = "float"; + + match self.0 { + SaladAny::Float(f) => visitor.visit_f32(*f), + SaladAny::Double(d) if super::FLOAT_RANGE.contains(d) => visitor.visit_f32(*d as f32), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_f64>(self, visitor: V) -> Result { + static ERR_MSG: &str = "double"; + + match self.0 { + SaladAny::Double(d) => visitor.visit_f64(*d), + SaladAny::Float(f) => visitor.visit_f64(*f as f64), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_str>(self, visitor: V) -> Result { + static ERR_MSG: &str = "UTF-8 string"; + + match self.0 { + SaladAny::String(s) => visitor.visit_str(s), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_map>(self, visitor: V) -> Result { + static ERR_MSG: &str = "key-value map object"; + + match self.0 { + SaladAny::Object(o) => visitor.visit_map(SaladObjectMapAccess::new(o)), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::List(_) => Err(de::Error::invalid_type(de::Unexpected::Seq, &ERR_MSG)), + } + } + + fn deserialize_seq>(self, visitor: V) -> Result { + static ERR_MSG: &str = "list of primitives/objects"; + + match self.0 { + SaladAny::List(l) => visitor.visit_seq(SaladAnyListSeqAccess::new(l)), + + // Errors + SaladAny::Bool(b) => Err(de::Error::invalid_type(de::Unexpected::Bool(*b), &ERR_MSG)), + SaladAny::Int(i) => Err(de::Error::invalid_type( + de::Unexpected::Signed(*i as i64), + &ERR_MSG, + )), + SaladAny::Long(l) => Err(de::Error::invalid_type(de::Unexpected::Signed(*l), &ERR_MSG)), + SaladAny::Float(f) => Err(de::Error::invalid_type( + de::Unexpected::Float(*f as f64), + &ERR_MSG, + )), + SaladAny::Double(d) => Err(de::Error::invalid_type(de::Unexpected::Float(*d), &ERR_MSG)), + SaladAny::String(s) => Err(de::Error::invalid_type(de::Unexpected::Str(s), &ERR_MSG)), + SaladAny::Object(_) => Err(de::Error::invalid_type(de::Unexpected::Map, &ERR_MSG)), + } + } + + // Unimplemented methods with a default implementation + serde::forward_to_deserialize_any! { + i8 i16 u8 u16 u32 u64 char string bytes byte_buf option unit + unit_struct newtype_struct tuple tuple_struct struct enum identifier ignored_any + } +} + +/// Map access implementation for SaladObject deserialization +pub(super) struct SaladObjectMapAccess<'de> { + iter: hash_map::Iter<'de, CompactString, SaladAny>, + value: Option<&'de SaladAny>, +} + +impl<'de> SaladObjectMapAccess<'de> { + pub fn new(obj: &'de SaladObject) -> Self { + Self { + iter: obj.map.iter(), + value: None, + } + } +} + +impl<'de> de::Deserializer<'de> for SaladObjectMapAccess<'de> { + type Error = de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_map(self) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_map(self) + } + + // Forward all other methods to deserialize_any + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct struct enum identifier ignored_any + } +} + +impl<'de> de::MapAccess<'de> for SaladObjectMapAccess<'de> { + type Error = de::value::Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: de::DeserializeSeed<'de>, + { + match self.iter.next() { + Some((k, v)) => { + self.value = Some(v); + seed.deserialize(CompactStringDeserializer(k)).map(Some) + } + None => { + self.value = None; + Ok(None) + } + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + let value = self.value.ok_or_else(|| de::Error::custom("value is missing"))?; + seed.deserialize(SaladAnyDeserializer(value)) + } +} + +/// Deserializer for CompactString values +struct CompactStringDeserializer<'de>(&'de CompactString); + +impl<'de> de::Deserializer<'de> for CompactStringDeserializer<'de> { + type Error = de::value::Error; + + fn deserialize_any>(self, visitor: V) -> Result { + visitor.visit_borrowed_str(self.0.as_str()) + } + + fn deserialize_str>(self, visitor: V) -> Result { + visitor.visit_borrowed_str(self.0.as_str()) + } + + fn deserialize_string>(self, visitor: V) -> Result { + visitor.visit_string(self.0.to_string()) + } + + fn deserialize_bytes>(self, visitor: V) -> Result { + visitor.visit_borrowed_bytes(self.0.as_bytes()) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_byte_buf(self.0.as_bytes().to_vec()) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_borrowed_str(self.0.as_str()) + } + + // Forward all other methods to deserialize_any + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char option unit unit_struct + newtype_struct seq tuple tuple_struct map struct enum ignored_any + } +} + +/// Sequence access implementation for SaladAny list deserialization +struct SaladAnyListSeqAccess<'de> { + iter: slice::Iter<'de, SaladAny>, +} + +impl<'de> SaladAnyListSeqAccess<'de> { + pub fn new(list: &'de [SaladAny]) -> Self { + Self { iter: list.iter() } + } +} + +impl<'de> de::SeqAccess<'de> for SaladAnyListSeqAccess<'de> { + type Error = de::value::Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + self.iter + .next() + .map(|v| seed.deserialize(SaladAnyDeserializer(v))) + .transpose() + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/any/mod.rs b/schema_salad/rust/salad-core/crates/types/src/any/mod.rs new file mode 100644 index 00000000..c13487ae --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/any/mod.rs @@ -0,0 +1,270 @@ +use std::fmt; +use std::convert::TryFrom; + +use serde::{de as des, ser}; + +mod de; +mod object; + +pub use self::object::SaladObject; +use crate::{ + primitive::{SaladBool, SaladDouble, SaladFloat, SaladInt, SaladLong, SaladString}, + util::{FLOAT_RANGE, INT_RANGE}, + SaladType, +}; + +/// The `SaladAny` type validates for any non-null value. +#[derive(Debug, Clone, PartialEq)] +pub enum SaladAny { + /// A binary value. + Bool(SaladBool), + /// 32-bit signed integer. + Int(SaladInt), + /// 64-bit signed integer. + Long(SaladLong), + /// Single precision (32-bit) IEEE 754 floating-point number. + Float(SaladFloat), + /// Double precision (64-bit) IEEE 754 floating-point number. + Double(SaladDouble), + /// Unicode character sequence. + String(SaladString), + /// Unknown object. + Object(SaladObject), + /// List of any values. + List(Box<[SaladAny]>), +} + +impl SaladAny { + /// Attempts to downcast to type `T` from a borrowed `SaladAny`. + /// N.B. When downcasting to a primitive (or object) type, consider using + /// a `match` expression or the `TryFrom::try_from` method. + pub fn downcast<'de, T>(&'de self) -> Result + where + T: SaladType + des::Deserialize<'de>, + { + let deserializer = self::de::SaladAnyDeserializer(self); + T::deserialize(deserializer).map_err(|_| self) + } + + /// Attempts to downcast from a consumed `SaladAny` to type `T`. + /// N.B. When downcasting to a primitive (or object) type, consider using + /// a `match` expression or the `TryFrom::try_from` method. + #[inline] + pub fn downcast_into(self) -> Result + where + for<'de> T: SaladType + des::Deserialize<'de>, + { + // Avoid duplicating the deserialization logic + match Self::downcast(&self) { + Ok(t) => Ok(t), + Err(_) => Err(self), + } + } + + /// Returns true if this value is a boolean + #[inline] + pub fn is_bool(&self) -> bool { + matches!(self, Self::Bool(_)) + } + + /// Returns true if this value is an integer (Int or Long) + #[inline] + pub fn is_integer(&self) -> bool { + matches!(self, Self::Int(_) | Self::Long(_)) + } + + /// Returns true if this value is a floating point number (Float or Double) + #[inline] + pub fn is_float(&self) -> bool { + matches!(self, Self::Float(_) | Self::Double(_)) + } + + /// Returns true if this value is a string + #[inline] + pub fn is_string(&self) -> bool { + matches!(self, Self::String(_)) + } + + /// Returns true if this value is an object + #[inline] + pub fn is_object(&self) -> bool { + matches!(self, Self::Object(_)) + } + + /// Returns true if this value is a list + #[inline] + pub fn is_list(&self) -> bool { + matches!(self, Self::List(_)) + } +} + +impl SaladType for SaladAny {} + +crate::util::impl_from_traits! { + SaladAny { + Bool => SaladBool, + Int => SaladInt, + Long => SaladLong, + Float => SaladFloat, + Double => SaladDouble, + String => SaladString, + Object => SaladObject, + } +} + +impl From> for SaladAny +where + T: SaladType, + Self: From, +{ + fn from(value: Vec) -> Self { + let list = value.into_iter().map(Self::from).collect(); + Self::List(list) + } +} + +impl From> for SaladAny +where + T: SaladType, + Self: From>, +{ + #[inline] + fn from(value: Box<[T]>) -> Self { + Self::from(value.into_vec()) + } +} + +impl ser::Serialize for SaladAny { + fn serialize(&self, serializer: S) -> Result { + match self { + Self::Bool(b) => serializer.serialize_bool(*b), + Self::Int(i) => serializer.serialize_i32(*i), + Self::Long(l) => serializer.serialize_i64(*l), + Self::Float(f) => serializer.serialize_f32(*f), + Self::Double(d) => serializer.serialize_f64(*d), + Self::String(s) => s.serialize(serializer), + Self::Object(o) => o.serialize(serializer), + Self::List(l) => l.serialize(serializer), + } + } +} + +impl<'de> des::Deserialize<'de> for SaladAny { + fn deserialize>(deserializer: D) -> Result { + struct SaladAnyVisitor; + + impl<'de> des::Visitor<'de> for SaladAnyVisitor { + type Value = SaladAny; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("a salad primitive, a key-value object, or a list of them") + } + + fn visit_bool(self, v: bool) -> Result { + Ok(SaladAny::Bool(v)) + } + + fn visit_i8(self, v: i8) -> Result { + Ok(SaladAny::Int(v.into())) + } + + fn visit_i16(self, v: i16) -> Result { + Ok(SaladAny::Int(v.into())) + } + + fn visit_i32(self, v: i32) -> Result { + Ok(SaladAny::Int(v)) + } + + fn visit_i64(self, v: i64) -> Result { + if INT_RANGE.contains(&v) { + Ok(SaladAny::Int(v as i32)) + } else { + Ok(SaladAny::Long(v)) + } + } + + fn visit_u8(self, v: u8) -> Result { + Ok(SaladAny::Int(v.into())) + } + + fn visit_u16(self, v: u16) -> Result { + Ok(SaladAny::Int(v.into())) + } + + fn visit_u32(self, v: u32) -> Result { + if v <= i32::MAX as u32 { + Ok(SaladAny::Int(v as i32)) + } else { + Ok(SaladAny::Long(v.into())) + } + } + + fn visit_u64(self, v: u64) -> Result { + if v <= i32::MAX as u64 { + Ok(SaladAny::Int(v as i32)) + } else if v <= i64::MAX as u64 { + Ok(SaladAny::Long(v as i64)) + } else { + Err(des::Error::invalid_value( + des::Unexpected::Unsigned(v), + &self, + )) + } + } + + fn visit_f32(self, v: f32) -> Result { + Ok(SaladAny::Float(v)) + } + + fn visit_f64(self, v: f64) -> Result { + if FLOAT_RANGE.contains(&v) { + Ok(SaladAny::Float(v as f32)) + } else { + Ok(SaladAny::Double(v)) + } + } + + fn visit_str(self, v: &str) -> Result { + Ok(SaladAny::String(v.into())) + } + + fn visit_string(self, v: String) -> Result { + Ok(SaladAny::String(v.into())) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + match core::str::from_utf8(v) { + Ok(s) => Ok(SaladAny::String(s.into())), + Err(_) => Err(des::Error::invalid_value(des::Unexpected::Bytes(v), &self)), + } + } + + fn visit_map(self, map: A) -> Result + where + A: des::MapAccess<'de>, + { + let deserializer = des::value::MapAccessDeserializer::new(map); + ::deserialize(deserializer).map(SaladAny::Object) + } + + fn visit_seq(self, seq: A) -> Result + where + A: des::SeqAccess<'de>, + { + let deserializer = des::value::SeqAccessDeserializer::new(seq); + as des::Deserialize>::deserialize(deserializer).map(SaladAny::List) + } + + fn visit_none(self) -> Result { + Err(des::Error::invalid_type(des::Unexpected::Option, &self)) + } + + fn visit_unit(self) -> Result { + Err(des::Error::invalid_type(des::Unexpected::Unit, &self)) + } + } + + deserializer.deserialize_any(SaladAnyVisitor) + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/any/object.rs b/schema_salad/rust/salad-core/crates/types/src/any/object.rs new file mode 100644 index 00000000..7f052f30 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/any/object.rs @@ -0,0 +1,139 @@ +use std::{borrow::Borrow, collections::HashMap, fmt, hash::Hash}; + +use compact_str::CompactString; +use fxhash::FxBuildHasher; +use serde::{ + de, + ser::{self, SerializeMap as _}, +}; + +use super::SaladAny; +use crate::SaladType; + +/// A key-value map representing an untyped Schema Salad object. +/// +/// `SaladObject` is a container that maps string keys to heterogeneous +/// values of type [`SaladAny`]. +/// It provides a flexible way to represent arbitrary Schema Salad objects +/// before they are parsed into their specific types. +/// +/// # Examples +/// ```ignore +/// use salad_core::SaladAny; +/// use salad_core::any::SaladObject; +/// +/// let obj = SaladObject::default(); +/// // Given some entries in the object +/// obj.get(key); // Returns Option<&SaladAny> +/// +/// // Downcast to a specific type +/// let typed_obj: Result = obj.downcast(); +/// ``` +#[derive(Clone, Default, PartialEq)] +pub struct SaladObject { + pub(super) map: HashMap, +} + +impl SaladObject { + /// Retrieves a reference to a value in the object by its key. + /// + /// Returns an `Option` containing a reference to the value if found, + /// or `None` if the key does not exist. + pub fn get(&self, key: &Q) -> Option<&SaladAny> + where + CompactString: Borrow, + Q: Hash + Eq + ?Sized, + { + self.map.get(key) + } + + /// Attempts to downcast to type `T` from a borrowed `SaladObject`. + /// + /// Returns a `Result` containing the downcasted value of type `T` if successful, + /// or a `SaladTypeDowncastError` if the downcast fails. + pub fn downcast<'de, T>(&'de self) -> Result + where + T: SaladType + de::Deserialize<'de>, + { + let deserializer = super::de::SaladObjectMapAccess::new(self); + match T::deserialize(deserializer) { + Ok(t) => Ok(t), + Err(_) => Err(self), + } + } + + /// Attempts to downcast from a consumed `SaladObject` to type `T`. + /// + /// Returns a `Result` containing the downcasted value of type `T` if successful, + /// or a `SaladTypeDowncastError` if the downcast fails. + #[inline] + pub fn downcast_into(self) -> Result + where + for<'de> T: SaladType + de::Deserialize<'de>, + { + match Self::downcast(&self) { + Ok(t) => Ok(t), + Err(_) => Err(self), + } + } +} + +impl SaladType for SaladObject {} + +impl fmt::Debug for SaladObject { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut debug_struct = f.debug_struct("SaladObject"); + for (k, v) in self.map.iter() { + debug_struct.field(k.as_str(), v); + } + debug_struct.finish() + } +} + +impl ser::Serialize for SaladObject { + fn serialize(&self, serializer: S) -> Result { + let mut map_serializer = serializer.serialize_map(Some(self.map.len()))?; + self.map + .iter() + .try_for_each(|(k, v)| map_serializer.serialize_entry(k.as_str(), v))?; + map_serializer.end() + } +} + +impl<'de> de::Deserialize<'de> for SaladObject { + fn deserialize>(deserializer: D) -> Result { + struct SaladObjectVisitor; + + impl<'de> de::Visitor<'de> for SaladObjectVisitor { + type Value = SaladObject; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("a Schema Salad key-value object") + } + + fn visit_map(self, mut serde_map: A) -> Result + where + A: de::MapAccess<'de>, + { + // Use size hint to allocate appropriately + let capacity = serde_map.size_hint().unwrap_or_default(); + let mut map = HashMap::with_capacity_and_hasher(capacity, FxBuildHasher::default()); + + // Process all key-value pairs from the input + while let Some(key) = serde_map.next_key::()? { + // Check for duplicate keys + if map.contains_key(&key) { + return Err(de::Error::custom(format_args!("duplicate field `{key}`",))); + } + + let value = serde_map.next_value::()?; + map.insert(key, value); + } + + Ok(SaladObject { map }) + } + } + + deserializer.deserialize_map(SaladObjectVisitor) + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/common/mod.rs b/schema_salad/rust/salad-core/crates/types/src/common/mod.rs new file mode 100644 index 00000000..5c22535b --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/common/mod.rs @@ -0,0 +1,7 @@ +mod names; +mod primitive_type; + +pub use self::{ + names::{ArrayName, EnumName, RecordName}, + primitive_type::PrimitiveType, +}; diff --git a/schema_salad/rust/salad-core/crates/types/src/common/names.rs b/schema_salad/rust/salad-core/crates/types/src/common/names.rs new file mode 100644 index 00000000..1bf294f2 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/common/names.rs @@ -0,0 +1,76 @@ +macro_rules! string_match_struct { + ( + $( + $( #[$attrs:meta] )* + $ident:ident($value:literal) + ),* $(,)? + ) => { + $( + $( #[$attrs] )* + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct $ident; + + impl crate::SaladType for $ident {} + + impl core::fmt::Display for $ident { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str($value) + } + } + + impl serde::ser::Serialize for $ident { + #[inline] + fn serialize(&self, s: S) -> Result + where + S: serde::ser::Serializer, + { + s.serialize_str($value) + } + } + + impl<'de> serde::de::Deserialize<'de> for $ident { + fn deserialize(d: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + struct NameVisitor; + + impl serde::de::Visitor<'_> for NameVisitor { + type Value = $ident; + + fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + f.write_str(concat!("the string `", $value, '`')) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + $value => Ok($ident), + _ => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(v), + &self, + )), + } + } + } + + d.deserialize_str(NameVisitor) + } + } + + )* + }; +} + +string_match_struct! { + /// Matches constant value `array`. + ArrayName("array"), + + /// Matches constant value `enum`. + EnumName("enum"), + + /// Matches constant value `record`. + RecordName("record"), +} diff --git a/schema_salad/rust/salad-core/crates/types/src/common/primitive_type.rs b/schema_salad/rust/salad-core/crates/types/src/common/primitive_type.rs new file mode 100644 index 00000000..f934d579 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/common/primitive_type.rs @@ -0,0 +1,100 @@ +use std::fmt; + +use serde::{de, ser}; + +use crate::SaladType; + +/// Names of salad data primitive types (based on Avro schema declarations). +/// +/// Refer to the [Avro schema declaration documentation](https://avro.apache.org/docs/++version++/specification/#primitive-types) +/// for detailed information. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PrimitiveType { + /// No value. + /// + /// Matches constant value `null`. + Null, + /// A binary value. + /// + /// Matches constant value `boolean`. + Boolean, + /// 32-bit signed integer. + /// + /// Matches constant value `int`. + Int, + /// 64-bit signed integer. + /// + /// Matches constant value `long`. + Long, + /// Single precision (32-bit) IEEE 754 floating-point number. + /// + /// Matches constant value `float`. + Float, + /// Double precision (64-bit) IEEE 754 floating-point number. + /// + /// Matches constant value `double`. + Double, + /// Unicode character sequence. + /// + /// Matches constant value `string`. + String, +} + +impl SaladType for PrimitiveType {} + +impl fmt::Display for PrimitiveType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str({ + match self { + Self::Null => "null", + Self::Boolean => "boolean", + Self::Int => "int", + Self::Long => "long", + Self::Float => "float", + Self::Double => "double", + Self::String => "string", + } + }) + } +} + +impl ser::Serialize for PrimitiveType { + #[inline] + fn serialize(&self, serializer: S) -> Result { + serializer.collect_str(self) + } +} + +impl<'de> de::Deserialize<'de> for PrimitiveType { + fn deserialize>(deserializer: D) -> Result { + struct PrimitiveTypeVisitor; + + impl de::Visitor<'_> for PrimitiveTypeVisitor { + type Value = PrimitiveType; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str( + "any of the following strings: `null`, `boolean`, `int`, `long`, `float`, `double`, `string`" + ) + } + + fn visit_str(self, v: &str) -> Result { + match v { + "null" => Ok(PrimitiveType::Null), + "boolean" => Ok(PrimitiveType::Boolean), + "int" => Ok(PrimitiveType::Int), + "long" => Ok(PrimitiveType::Long), + "float" => Ok(PrimitiveType::Float), + "double" => Ok(PrimitiveType::Double), + "string" => Ok(PrimitiveType::String), + _ => Err(de::Error::invalid_value( + de::Unexpected::Str(v), + &self, + )), + } + } + } + + deserializer.deserialize_str(PrimitiveTypeVisitor) + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/lib.rs b/schema_salad/rust/salad-core/crates/types/src/lib.rs new file mode 100644 index 00000000..e6955f5d --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/lib.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +mod any; +pub mod common; +mod primitive; +mod util; + +pub use self::{ + any::{SaladAny, SaladObject}, + primitive::{ + SaladBool, SaladDouble, SaladFloat, SaladInt, SaladLong, SaladPrimitive, SaladString, + }, +}; + +/// A marker trait for Schema Salad data types. +/// +/// This trait is implemented by all types that represent valid Schema Salad data, +/// including primitives (boolean, int, float, string), objects, and collections. +pub trait SaladType: Sized {} + +impl SaladType for Arc {} +impl SaladType for Arc<[T]> {} + +impl SaladType for Vec {} +impl SaladType for Box<[T]> {} + +#[cfg(test)] +mod tests { + use crate::{primitive, SaladAny, SaladObject}; + + #[test] + fn test_deserialize_bool() { + let yaml = "true"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let b = any.downcast::().unwrap(); + assert!(b); + } + + #[test] + fn test_deserialize_int() { + let yaml = "42"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let i = any.downcast::().unwrap(); + assert_eq!(i, 42); + } + + #[test] + fn test_deserialize_float() { + let yaml = "3.14"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let f = any.downcast::().unwrap(); + assert_eq!((f * 100.0).round(), 314.0); + } + + #[test] + fn test_deserialize_string() { + let yaml = "hello world"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let s = any.downcast::().unwrap(); + assert_eq!(s, "hello world"); + } + + #[test] + fn test_deserialize_primitive() { + let yaml = "42"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let primitive = any.downcast::().unwrap(); + assert_eq!(primitive.to_string(), "42"); + + let yaml = "true"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let primitive = any.downcast::().unwrap(); + assert_eq!(primitive.to_string(), "true"); + + let yaml = "3.14"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let primitive = any.downcast::().unwrap(); + assert_eq!(primitive.to_string(), "3.14"); + + let yaml = "hello"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let primitive = any.downcast::().unwrap(); + assert_eq!(primitive.to_string(), "hello"); + } + + #[test] + fn test_deserialize_object() { + let yaml = r#" + name: John + age: 30 + likes_pizza: true + "#; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + let obj: SaladObject = any.downcast().unwrap(); + + assert_eq!( + obj.get("name") + .unwrap() + .downcast::() + .unwrap(), + "John" + ); + assert_eq!( + obj.get("age") + .unwrap() + .downcast::() + .unwrap(), + 30 + ); + assert!(obj + .get("likes_pizza") + .unwrap() + .downcast::() + .unwrap()); + } + + #[test] + fn test_failed_downcast() { + let yaml = "42"; + let any = serde_yaml_ng::from_str::(yaml).unwrap(); + assert!(any.downcast::().is_err()); + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/primitive/mod.rs b/schema_salad/rust/salad-core/crates/types/src/primitive/mod.rs new file mode 100644 index 00000000..36cf60a8 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/primitive/mod.rs @@ -0,0 +1,167 @@ +use std::fmt; + +use serde::{de, ser}; + +mod string; + +pub use self::string::SaladString; +use crate::{ + util::{FLOAT_RANGE, INT_RANGE}, + SaladType, +}; + +/// A binary value. +pub type SaladBool = bool; +impl SaladType for SaladBool {} + +/// 32-bit signed integer. +pub type SaladInt = i32; +impl SaladType for SaladInt {} + +/// 64-bit signed integer. +pub type SaladLong = i64; +impl SaladType for SaladLong {} + +/// Single precision (32-bit) IEEE 754 floating-point number. +pub type SaladFloat = f32; +impl SaladType for SaladFloat {} + +/// Double precision (64-bit) IEEE 754 floating-point number. +pub type SaladDouble = f64; +impl SaladType for SaladDouble {} + +/// Schema Salad primitives, except `null`. +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub enum SaladPrimitive { + /// A binary value. + Bool(SaladBool), + /// 32-bit signed integer. + Int(SaladInt), + /// 64-bit signed integer. + Long(SaladLong), + /// Single precision (32-bit) IEEE 754 floating-point number. + Float(SaladFloat), + /// Double precision (64-bit) IEEE 754 floating-point number. + Double(SaladDouble), + /// Unicode character sequence. + String(SaladString), +} + +impl SaladType for SaladPrimitive {} + +crate::util::impl_from_traits! { + SaladPrimitive { + Bool => SaladBool, + Int => SaladInt, + Long => SaladLong, + Float => SaladFloat, + Double => SaladDouble, + String => SaladString, + } +} + +impl fmt::Display for SaladPrimitive { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Bool(b) => fmt::Display::fmt(b, f), + Self::Int(i) => fmt::Display::fmt(i, f), + Self::Long(l) => fmt::Display::fmt(l, f), + Self::Float(fl) => fmt::Display::fmt(fl, f), + Self::Double(d) => fmt::Display::fmt(d, f), + Self::String(s) => fmt::Display::fmt(s, f), + } + } +} + +impl ser::Serialize for SaladPrimitive { + fn serialize(&self, serializer: S) -> Result { + match self { + Self::Bool(b) => serializer.serialize_bool(*b), + Self::Int(i) => serializer.serialize_i32(*i), + Self::Long(l) => serializer.serialize_i64(*l), + Self::Float(f) => serializer.serialize_f32(*f), + Self::Double(d) => serializer.serialize_f64(*d), + Self::String(s) => serializer.serialize_str(s), + } + } +} + +impl<'de> de::Deserialize<'de> for SaladPrimitive { + fn deserialize>(deserializer: D) -> Result { + struct SaladPrimitiveVisitor; + + impl de::Visitor<'_> for SaladPrimitiveVisitor { + type Value = SaladPrimitive; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("any of the salad primitives") + } + fn visit_bool(self, v: bool) -> Result { + Ok(SaladPrimitive::Bool(v)) + } + + fn visit_i8(self, v: i8) -> Result { + Ok(SaladPrimitive::Int(v as i32)) + } + + fn visit_i16(self, v: i16) -> Result { + Ok(SaladPrimitive::Int(v as i32)) + } + + fn visit_i32(self, v: i32) -> Result { + Ok(SaladPrimitive::Int(v)) + } + + fn visit_i64(self, v: i64) -> Result { + match v { + l if INT_RANGE.contains(&l) => Ok(SaladPrimitive::Int(v as i32)), + _ => Ok(SaladPrimitive::Long(v)), + } + } + + fn visit_u8(self, v: u8) -> Result { + Ok(SaladPrimitive::Int(v as i32)) + } + + fn visit_u16(self, v: u16) -> Result { + Ok(SaladPrimitive::Int(v as i32)) + } + + fn visit_u64(self, v: u64) -> Result { + match v { + u if u <= i32::MAX as u64 => Ok(SaladPrimitive::Int(v as i32)), + u if u <= i64::MAX as u64 => Ok(SaladPrimitive::Long(v as i64)), + _ => Err(de::Error::invalid_value(de::Unexpected::Unsigned(v), &self)), + } + } + + fn visit_f32(self, v: f32) -> Result { + Ok(SaladPrimitive::Float(v)) + } + + fn visit_f64(self, v: f64) -> Result { + match v { + d if FLOAT_RANGE.contains(&d) => Ok(SaladPrimitive::Float(v as f32)), + _ => Ok(SaladPrimitive::Double(v)), + } + } + + fn visit_str(self, v: &str) -> Result { + Ok(SaladPrimitive::String(v.into())) + } + + fn visit_string(self, v: String) -> Result { + Ok(SaladPrimitive::String(v.into())) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + match std::str::from_utf8(v) { + Ok(s) => Ok(SaladPrimitive::String(s.into())), + Err(_) => Err(de::Error::invalid_value(de::Unexpected::Bytes(v), &self)), + } + } + } + + deserializer.deserialize_any(SaladPrimitiveVisitor) + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/primitive/string.rs b/schema_salad/rust/salad-core/crates/types/src/primitive/string.rs new file mode 100644 index 00000000..3a6cc477 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/primitive/string.rs @@ -0,0 +1,249 @@ +use std::{ + borrow::Borrow, + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, + ops::Deref, + str::FromStr, +}; + +use compact_str::CompactString; +use serde::{de, ser}; + +use crate::SaladType; + +/// Unicode character sequence. +#[repr(transparent)] +#[derive(Clone, Default)] +pub struct SaladString(CompactString); + +impl SaladString { + #[inline] + #[must_use] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl SaladType for SaladString {} + +impl From for SaladString { + fn from(value: String) -> Self { + Self(CompactString::from(value)) + } +} + +impl<'a> From<&'a String> for SaladString { + fn from(value: &'a String) -> Self { + Self(CompactString::new(value)) + } +} + +impl From for String { + fn from(value: SaladString) -> Self { + String::from(value.0) + } +} + +impl<'a> From<&'a str> for SaladString { + fn from(value: &'a str) -> Self { + Self(CompactString::new(value)) + } +} + +impl From> for SaladString { + fn from(value: Box) -> Self { + Self(CompactString::from(value)) + } +} + +impl From for Box { + fn from(value: SaladString) -> Self { + Box::::from(value.0) + } +} + +impl FromStr for SaladString { + type Err = std::convert::Infallible; + + #[inline] + fn from_str(value: &str) -> Result { + Ok(Self::from(value)) + } +} + +impl FromIterator for SaladString { + fn from_iter>(iter: T) -> Self { + Self(CompactString::from_iter(iter)) + } +} + +impl<'a> FromIterator<&'a char> for SaladString { + fn from_iter>(iter: T) -> Self { + Self(CompactString::from_iter(iter)) + } +} + +impl Extend for SaladString { + fn extend>(&mut self, iter: T) { + self.0.extend(iter); + } +} + +impl<'a> Extend<&'a char> for SaladString { + fn extend>(&mut self, iter: T) { + self.0.extend(iter); + } +} + +impl AsRef for SaladString { + #[inline] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl AsRef<[u8]> for SaladString { + #[inline] + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +impl Borrow for SaladString { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl Eq for SaladString {} + +impl + ?Sized> PartialEq for SaladString { + fn eq(&self, other: &T) -> bool { + self.0.as_str() == other.as_ref() + } +} + +impl PartialEq for &SaladString { + fn eq(&self, other: &SaladString) -> bool { + self.0.as_str() == other.0.as_str() + } +} + +impl PartialEq for String { + fn eq(&self, other: &SaladString) -> bool { + self.as_str() == other.0.as_str() + } +} + +impl<'a> PartialEq<&'a SaladString> for String { + fn eq(&self, other: &&'a SaladString) -> bool { + self.as_str() == other.0.as_str() + } +} + +impl PartialEq for &String { + fn eq(&self, other: &SaladString) -> bool { + self.as_str() == other.0.as_str() + } +} + +impl PartialEq for &SaladString { + fn eq(&self, other: &String) -> bool { + self.0.as_str() == other.as_str() + } +} + +impl PartialEq for str { + fn eq(&self, other: &SaladString) -> bool { + self == other.0.as_str() + } +} + +impl<'a> PartialEq<&'a SaladString> for str { + fn eq(&self, other: &&'a SaladString) -> bool { + self == other.0.as_str() + } +} + +impl PartialEq for &str { + fn eq(&self, other: &SaladString) -> bool { + *self == other.0.as_str() + } +} + +impl PartialEq for &&str { + fn eq(&self, other: &SaladString) -> bool { + **self == other.0.as_str() + } +} + +impl Ord for SaladString { + fn cmp(&self, other: &Self) -> Ordering { + self.0.cmp(&other.0) + } +} + +impl PartialOrd for SaladString { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Hash for SaladString { + #[inline] + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl Deref for SaladString { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl fmt::Debug for SaladString { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } +} + +impl fmt::Display for SaladString { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl ser::Serialize for SaladString { + #[inline] + fn serialize(&self, serializer: S) -> Result { + CompactString::serialize(&self.0, serializer) + } +} + +impl<'de> de::Deserialize<'de> for SaladString { + #[inline] + fn deserialize>(deserializer: D) -> Result { + CompactString::deserialize(deserializer).map(Self) + } +} diff --git a/schema_salad/rust/salad-core/crates/types/src/util/macros.rs b/schema_salad/rust/salad-core/crates/types/src/util/macros.rs new file mode 100644 index 00000000..6e1a0c4b --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/util/macros.rs @@ -0,0 +1,40 @@ +macro_rules! impl_from_traits { + ( + $ty:ident { + $( $ident:ident => $subty:ident ),* $(,)? + } + ) => { + $( + impl From<$subty> for $ty { + #[inline] + fn from(value: $subty) -> Self { + Self::$ident(value) + } + } + + impl TryFrom<$ty> for $subty { + type Error = $ty; + + fn try_from(value: $ty) -> Result { + match value { + $ty::$ident(v) => Ok(v), + _ => Err(value), + } + } + } + + impl<'a> TryFrom<&'a $ty> for &'a $subty { + type Error = &'a $ty; + + fn try_from(value: &'a $ty) -> Result { + match value { + $ty::$ident(v) => Ok(v), + _ => Err(value), + } + } + } + )* + }; +} + +pub(crate) use impl_from_traits; diff --git a/schema_salad/rust/salad-core/crates/types/src/util/mod.rs b/schema_salad/rust/salad-core/crates/types/src/util/mod.rs new file mode 100644 index 00000000..15213e24 --- /dev/null +++ b/schema_salad/rust/salad-core/crates/types/src/util/mod.rs @@ -0,0 +1,16 @@ +use std::ops::RangeInclusive; + +mod macros; + +pub(crate) use self::macros::impl_from_traits; +use crate::primitive::{SaladDouble, SaladFloat, SaladInt, SaladLong}; + +/// Range representing the minimum and maximum [SaladLong] +/// values that can be stored in a [SaladInt]. +pub(crate) const INT_RANGE: RangeInclusive = + (SaladInt::MIN as SaladLong)..=(SaladInt::MAX as SaladLong); + +/// Range representing the minimum and maximum [SaladDouble] +/// values that can be stored in a [SaladFloat]. +pub(crate) const FLOAT_RANGE: RangeInclusive = + (SaladFloat::MIN as SaladDouble)..=(SaladFloat::MAX as SaladDouble); diff --git a/schema_salad/rust/salad-core/src/lib.rs b/schema_salad/rust/salad-core/src/lib.rs new file mode 100644 index 00000000..760400cb --- /dev/null +++ b/schema_salad/rust/salad-core/src/lib.rs @@ -0,0 +1,2 @@ +// WIP + diff --git a/schema_salad/rust/src/lib.rs b/schema_salad/rust/src/lib.rs new file mode 100644 index 00000000..30900056 --- /dev/null +++ b/schema_salad/rust/src/lib.rs @@ -0,0 +1,2 @@ +// Generated code + diff --git a/schema_salad/rust_codegen.py b/schema_salad/rust_codegen.py new file mode 100644 index 00000000..e782035f --- /dev/null +++ b/schema_salad/rust_codegen.py @@ -0,0 +1,1132 @@ +"""Rust code generator for schema salad definitions.""" + +__all__ = ["RustCodeGen"] + +import dataclasses +import functools +import itertools +import json +import re +import shutil +import sys +from abc import ABC, abstractmethod +from collections.abc import Iterator, MutableMapping, MutableSequence, Sequence +from importlib.resources import files as resource_files +from io import StringIO +from pathlib import Path +from time import sleep +from typing import ( + Any, + ClassVar, + Optional, + TextIO, + Union, +) + +from . import _logger +from .avro.schema import ( + ArraySchema, + EnumSchema, + JsonDataType, + NamedSchema, + NamedUnionSchema, + PrimitiveSchema, + RecordSchema, + Schema, + UnionSchema, + Field as SaladField, + Names as SaladNames, + make_avsc_object, +) +from .codegen_base import CodeGenBase +from .schema import make_valid_avro +from .validate import avro_shortname + + +def dataclass(*args, **kwargs): + """ + A wrapper around `@dataclass` attribute that automatically enables + `slots` if Python version >= 3.10. + """ + if sys.version_info >= (3, 10): + return dataclasses.dataclass(*args, slots=True, **kwargs) + return dataclasses.dataclass(*args, **kwargs) + + +# +# Util Functions +# + +__RUST_RESERVED_WORDS = [ + "type", "self", "let", "fn", "struct", "impl", "trait", "enum", "pub", + "mut", "true", "false", "return", "match", "if", "else", "for", "in", + "where", "ref", "use", "mod", "const", "static", "as", "move", "async", + "await", "dyn", "loop", "break", "continue", "super", "crate", "unsafe", + "extern", "box", "virtual", "override", "macro", "while", "yield", + "typeof", "sizeof", "final", "pure", "abstract", "become", "do", + "alignof", "offsetof", "priv", "proc", "unsized", +] # fmt: skip + +# __FIELD_NAME_REX_DICT = [ +# (re.compile(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])"), "_"), +# (re.compile(r"([\W_]$)|\W"), lambda m: "" if m.group(1) else "_"), +# (re.compile(r"^([0-9])"), lambda m: f"_{m.group(1)}"), +# ] +__TYPE_NAME_REX_DICT = [ + (re.compile(r"(?:^|[^a-zA-Z0-9.])(\w)"), lambda m: m.group(1).upper()), + (re.compile(r"\.([a-zA-Z])"), lambda m: m.group(1).upper()), + (re.compile(r"(?:^|\.)([0-9])"), lambda m: f"_{m.group(1)}"), +] +__MD_NON_HYPERLINK_REX = re.compile( + r"(?\"])" +) + + +# TODO Check strings for Unicode standard for `XID_Start` and `XID_Continue` +# @functools.cache +def rust_sanitize_field_ident(value: str) -> str: + """ + Checks whether the field name is a Rust reserved world, or escapes it. + """ + # value = functools.reduce(lambda s, r: re.sub(*r, s), __FIELD_NAME_REX_DICT, value) + # value = value.lower() + if value in __RUST_RESERVED_WORDS: + return f"r#{value}" + return value + + +# TODO Check strings for Unicode standard for `XID_Start` and `XID_Continue` +@functools.cache +def rust_sanitize_type_ident(value: str) -> str: + """ + Converts an input string into a valid Rust type name (PascalCase). + Results are cached for performance optimization. + """ + return functools.reduce(lambda s, r: re.sub(*r, s), __TYPE_NAME_REX_DICT, value) + + +def rust_sanitize_doc_iter(value: Union[list[str], str]) -> Iterator[str]: + """ + Sanitizes Markdown doc-strings by splitting lines and wrapping non-hyperlinked + URLs in angle brackets. + """ + return map( + lambda v: re.sub(__MD_NON_HYPERLINK_REX, lambda m: f"<{m.group()}>", v), + itertools.chain.from_iterable(map( # flat_map + lambda v: v.rstrip().split("\n"), + [value] if isinstance(value, str) else value, + )), + ) # fmt: skip + + +@functools.cache +def to_rust_literal(value: Any) -> str: + """ + Convert Python values to their equivalent Rust literal representation. + Results are cached for performance optimization. + """ + if isinstance(value, bool): + return str(value).lower() + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + return json.dumps(value, ensure_ascii=False) + if isinstance(value, (list, tuple)): + list_entries = ", ".join(map(to_rust_literal, value)) + return f"[{list_entries}]" + if value is None: + return "Option::None" + raise TypeError(f"Unsupported type for Rust literal conversion: {type(value).__name__}") + + +def make_avro(items: MutableSequence[JsonDataType]) -> MutableSequence[NamedSchema]: + """ + Processes a list of dictionaries to generate a list of Avro schemas. + """ + + # Same as `from .utils import convert_to_dict`, which, however, is not public + def convert_to_dict(j4: Any) -> Any: + """Convert generic Mapping objects to dicts recursively.""" + if isinstance(j4, MutableMapping): + return {k: convert_to_dict(v) for k, v in j4.items()} + if isinstance(j4, MutableSequence): + return list(map(convert_to_dict, j4)) + return j4 + + name_dict = {entry["name"]: entry for entry in items} + avro = make_valid_avro(items, name_dict, set()) + avro = [ + t + for t in avro + if isinstance(t, MutableMapping) + and not t.get("abstract") + and t.get("type") != "org.w3id.cwl.salad.documentation" + ] + + names = SaladNames() + make_avsc_object(convert_to_dict(avro), names) + return list(names.names.values()) + + +# +# Rust AST Nodes +# + + +# ASSERT: The string is a valid Rust identifier. +RustIdent = str # alias + + +@dataclass # ASSERT: Immutable class +class RustLifetime: + """ + Represents a Rust lifetime parameter (e.g., `'a`). + """ + + ident: RustIdent + + def __hash__(self) -> int: + return hash(self.ident) + + def __str__(self) -> str: + return f"'{str(self.ident)}" + + +class RustType(ABC): + """ + Abstract class for Rust types. + """ + + pass + + +class RustMeta(ABC): + """ + Abstract class for Rust attribute metas. + """ + + pass + + +@dataclass(unsafe_hash=True) # ASSERT: Immutable class +class RustAttribute: + """ + Represents a Rust attribute (e.g., `#[derive(Debug)]`). + """ + + meta: RustMeta + + def __str__(self) -> str: + return f"#[{str(self.meta)}]" + + +RustAttributes = Sequence[RustAttribute] # alias +RustAttributesMut = MutableSequence[RustAttribute] # alias + + +RustGenerics = Sequence[Union[RustLifetime, RustType]] # alias +RustGenericsMut = MutableSequence[Union[RustLifetime, RustType]] # alias + + +@dataclass(unsafe_hash=True) # ASSERT: Immutable class +class RustPathSegment: + """ + Represents a segment in a Rust path with optional generics. + """ + + ident: RustIdent + generics: RustGenerics = dataclasses.field(default_factory=tuple) + + REX: ClassVar[re.Pattern] = re.compile(r"^([a-zA-Z_]\w*)(?:<([ \w\t,'<>]+)>)?$") + + def __str__(self) -> str: + if not self.generics: + return self.ident + generics = sorted(self.generics, key=lambda r: 0 if isinstance(r, RustLifetime) else 1) + generics_str = ", ".join(map(str, generics)) + return f"{self.ident}<{generics_str}>" + + # noinspection PyArgumentList + @classmethod + @functools.cache + def from_str(cls, value: str) -> "RustPathSegment": + """ + Parses a string into RustPathSegment class. + Results are cached for performance optimization. + """ + + def parse_generics_string(value_generics: str) -> RustGenerics: + generics_sequence: Union[MutableSequence[str], RustGenerics] = [] + current, deep = [], 0 + for idx, char in enumerate(value_generics): + deep += (char == "<") - (char == ">") + if deep == 0 and char == ",": + generics_sequence.append("".join(current).strip()) + current = [] + elif deep < 0: + raise ValueError(f"Poorly formatted Rust path generics: '{value}'.") + else: + current.append(char) + if deep > 0: + raise ValueError(f"Poorly formatted Rust path generics: '{value}'.") + generics_sequence.append("".join(current).strip()) + return tuple([ + RustLifetime(g[1:]) if g[0] == "'" else RustPath.from_str(g) + for g in generics_sequence + ]) # fmt: skip + + # + # `from_str(...)` method + if match := re.match(RustPathSegment.REX, value): + ident, generics = match.groups() + return cls(ident, parse_generics_string(generics) if generics else tuple()) + raise ValueError(f"Poorly formatted Rust path segment: '{value}'.") + + +RustPathSegments = Sequence[RustPathSegment] # alias +RustPathSegmentsMut = MutableSequence[RustPathSegment] # alias + + +@dataclass(unsafe_hash=True) # ASSERT: Immutable class +class RustPath(RustType, RustMeta): + """ + Represents a complete Rust path (e.g., `::std::vec::Vec`). + """ + + # ASSERT: Never initialized with an empty sequence + segments: RustPathSegments + leading_colon: bool = False + + def __truediv__(self, other: Union["RustPath", RustPathSegment]) -> "RustPath": + if isinstance(other, RustPath): + if self.segments[-1].generics: + raise ValueError("Cannot chain to a RustPath with generics.") + if other.leading_colon: + raise ValueError("Cannot chain a RustPath with leading colon.") + return RustPath( + segments=tuple([*self.segments, *other.segments]), + leading_colon=self.leading_colon, + ) + if isinstance(other, RustPathSegment): + if self.segments[-1].generics: + raise ValueError("Cannot chain to a RustPath with generics.") + return RustPath( + segments=tuple([*self.segments, other]), + leading_colon=self.leading_colon, + ) + raise TypeError(f"RustPath chaining with type `{type(other).__name__}` not supported.") + + def __str__(self) -> str: + leading_colon = "::" if self.leading_colon else "" + path_str = "::".join(map(str, self.segments)) + return leading_colon + path_str + + # noinspection PyArgumentList + @classmethod + @functools.cache + def from_str(cls, value: str) -> "RustPath": + """ + Parses a string into RustPath class. + Results are cached for performance optimization. + """ + norm_value, leading_colon = (value[2:], True) if value.startswith("::") else (value, False) + segments, segment_with_generics = [], 0 + for segment in map(RustPathSegment.from_str, norm_value.split("::")): + if len(segment.generics): + segment_with_generics += 1 + segments.append(segment) + if segment_with_generics > 1: + raise ValueError(f"Poorly formatted Rust path: '{value}'") + return cls(tuple(segments), leading_colon) + + # def parent(self) -> "RustPath": + # """ + # Returns a new RustPath containing all but the last segment. + # """ + # return RustPath( + # segments=self.segments[:-1], + # leading_colon=self.leading_colon, + # ) + + +@dataclass(unsafe_hash=True) # ASSERT: Immutable class +class RustTypeTuple(RustType): + """ + Represents a Rust tuple type (e.g., `(T, U)`). + """ + + # ASSERT: Never initialized with an empty sequence + types: Sequence[RustType] + + def __str__(self) -> str: + types_str = ", ".join(str(ty) for ty in self.types) + return f"({types_str})" + + +@dataclass # ASSERT: Immutable class +class RustMetaList(RustMeta): + """ + Represents attribute meta list information (e.g., `derive(Debug, Clone)`) + """ + + path: RustPath + metas: Sequence[RustMeta] = tuple() + + def __hash__(self) -> int: + return hash(self.path) + + def __str__(self) -> str: + meta_str = ", ".join(str(meta) for meta in self.metas) + return f"{str(self.path)}(" + meta_str + ")" + + +@dataclass # ASSERT: Immutable class +class RustMetaNameValue(RustMeta): + """ + Represents attribute meta name-value information (e.g., `key = value`) + """ + + path: RustPath + value: Any = True + + def __hash__(self) -> int: + return hash(self.path) + + def __str__(self) -> str: + return f"{str(self.path)} = {to_rust_literal(self.value)}" + + +# +# Rust Type Representations +# + + +@dataclass +class RustNamedType(ABC): # ABC class + """ + Abstract class for Rust struct and enum types. + """ + + ident: RustIdent + attrs: RustAttributes = dataclasses.field(default_factory=list) + visibility: str = "pub" + + def __hash__(self) -> int: + return hash(self.ident) + + @abstractmethod + def write_to(self, writer: TextIO, depth: int = 0) -> None: + pass + + def __str__(self) -> str: + output = StringIO() + self.write_to(output, 0) + return output.getvalue() + + +@dataclass # ASSERT: Immutable class +class RustField: + """ + Represents a field in a Rust struct. + """ + + ident: RustIdent + type: RustType + attrs: RustAttributes = dataclasses.field(default_factory=list) + + def __hash__(self) -> int: + return hash(self.ident) + + def write_to(self, writer: TextIO, depth: int = 0) -> None: + indent = " " * depth + + if self.attrs: + writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n") + writer.write(f"{indent}{self.ident}: {str(self.type)}") + + +RustFields = Union[Sequence[RustField], RustTypeTuple] # alias +RustFieldsMut = Union[MutableSequence[RustField], RustTypeTuple] # alias + + +@dataclass +class RustStruct(RustNamedType): + """ + Represents a Rust struct definition. + """ + + fields: Optional[RustFields] = None + + def write_to(self, writer: TextIO, depth: int = 0) -> None: + indent = " " * depth + + if self.attrs: + writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n") + + writer.write(f"{indent}{self.visibility} struct {self.ident}") + if self.fields is None: + writer.write(";\n") + elif isinstance(self.fields, RustTypeTuple): + writer.write(f"{str(self.fields)};\n") + else: + writer.write(" {\n") + for field_ in self.fields: + field_.write_to(writer, depth + 1) + writer.write(",\n") + writer.write(f"{indent}}}\n") + + +@dataclass # ASSERT: Immutable class +class RustVariant: + """ + Represents a variant in a Rust enum. + """ + + ident: RustIdent + tuple: Optional[RustTypeTuple] = None + attrs: RustAttributes = dataclasses.field(default_factory=list) + + def __hash__(self) -> int: + return hash(self.ident) + + def write_to(self, writer: TextIO, depth: int = 0) -> None: + indent = " " * depth + + if self.attrs: + writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n") + + writer.write(f"{indent}{self.ident}") + if self.tuple: + writer.write(str(self.tuple)) + + # noinspection PyArgumentList + @classmethod + def from_path(cls, path: RustPath) -> "RustVariant": + ident = "".join( + map( + lambda p: p.segments[-1].ident, + itertools.chain( + filter(lambda g: isinstance(g, RustPath), path.segments[-1].generics), + itertools.repeat(path, 1), + ), + ) + ) + ident = ident.replace("StrValue", "String", 1) # HACK + return cls(ident=ident, tuple=RustTypeTuple([path])) + + +RustVariants = Sequence[RustVariant] # alias +RustVariantsMut = MutableSequence[RustVariant] # alias + + +@dataclass +class RustEnum(RustNamedType): + """ + Represents a Rust enum definition. + """ + + variants: RustVariants = dataclasses.field(default_factory=tuple) + + def write_to(self, writer: TextIO, depth: int = 0) -> None: + indent = " " * depth + + if self.attrs: + writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n") + + writer.write(f"{indent}{self.visibility} enum {self.ident} {{\n") + for variant in self.variants: + variant.write_to(writer, depth + 1) + writer.write(",\n") + writer.write(f"{indent}}}\n") + + +# Wrapper for the RustNamedType `write_to()` method call +def salad_macro_write_to(ty: RustNamedType, writer: TextIO, depth: int = 0) -> None: + """ + Writes a RustNamedType wrapping it in the Schema Salad macro + """ + indent = " " * depth + writer.write(indent + "salad_core::define_type! {\n") + ty.write_to(writer, 1) + writer.write(indent + "}\n\n") + + +# +# Rust Module Tree +# + + +@dataclass +class RustModuleTree: + """ + Represents a Rust module with submodules and named types + """ + + ident: RustIdent # ASSERT: Immutable field + parent: "RustModuleTree" # ASSERT: Immutable field + named_types: MutableMapping[RustIdent, RustNamedType] = dataclasses.field(default_factory=dict) + submodules: MutableMapping[RustIdent, "RustModuleTree"] = dataclasses.field( + default_factory=dict + ) + + def __hash__(self) -> int: + return hash((self.ident, self.parent)) + + def get_rust_path(self) -> RustPath: + """ + Returns the complete Rust path from root to this module. + """ + segments, current = [], self + while current: + segments.append(RustPathSegment(ident=current.ident)) + current = current.parent + return RustPath(segments=tuple(reversed(segments))) + + def add_submodule(self, path: Union[RustPath, str]) -> "RustModuleTree": + """ + Creates a new submodule or returns an existing one with the given path. + """ + if isinstance(path, str): + path = RustPath.from_str(path) + segments = iter(path.segments) + + # First segment, corner case + if (first := next(segments, None)) is None: + return self + + if first.ident == self.ident: + current = self + else: + current = self.submodules.setdefault( + first.ident, + RustModuleTree(ident=first.ident, parent=self), + ) + + # Subsequent segments + for segment in segments: + current = current.submodules.setdefault( + segment.ident, + RustModuleTree(ident=segment.ident, parent=current), + ) + return current + + # def get_submodule(self, path: Union[RustPath, str]) -> Optional["RustModuleTree"]: + # """ + # Returns a submodule from this module tree by its Rust path, if any. + # """ + # if isinstance(path, str): + # path = RustPath.from_str(path) + # current, last_segment_idx = self, len(path.segments) - 1 + # for idx, segment in enumerate(path.segments): + # if (idx == last_segment_idx) and (current.ident == segment.ident): + # return current + # current = current.submodules.get(segment.ident) + # if not current: + # return None + # return None + + def add_named_type(self, ty: RustNamedType) -> RustPath: + """ + Adds a named type to this module tree and returns its complete Rust path. + Raises `ValueError` if type with same name already exists + """ + module_rust_path = self.get_rust_path() + if ty.ident in self.named_types: + raise ValueError(f"Duplicate Rust type '{ty.ident}' in '{module_rust_path}'.") + self.named_types[ty.ident] = ty + return module_rust_path / RustPathSegment(ident=ty.ident) + + # def get_named_type(self, path: RustPath) -> Optional[RustNamedType]: + # if module := self.get_submodule(path.parent()): + # return module.named_types.get(path.segments[-1].ident) + # return None + + def write_to_fs(self, base_path: Path) -> None: + """ + Writes the module tree to the filesystem under the given base path. + """ + + # noinspection PyShadowingNames + def write_module_file(module: "RustModuleTree", path: Path, mode: str = "wt") -> None: + with open(path, mode=mode) as module_rs: + if module.submodules: + module_rs.write( + "\n".join([f"mod {mod.ident};" for mod in module.submodules.values()]) + + "\n\n" + ) + if module.named_types: + for ty in module.named_types.values(): + salad_macro_write_to(ty, module_rs, 0) + + # + # `write_to_fs(...)` method + path = base_path.resolve() + traversing_stack: MutableSequence[tuple[Path, RustModuleTree]] = [] + + # Write `lib.rs` module (corner case) + if not self.parent: + path.mkdir(mode=0o755, parents=True, exist_ok=True) + write_module_file(module=self, path=path / "lib.rs", mode="at") + traversing_stack.extend((path, sub_mod) for sub_mod in self.submodules.values()) + else: + traversing_stack.append((path, self)) + + # Generate module files + while traversing_stack: + path_parent, module = traversing_stack.pop() + + if not module.submodules: + path_parent.mkdir(mode=0o755, parents=True, exist_ok=True) + write_module_file(module=module, path=path_parent / f"{module.ident}.rs") + continue + + path_module = path_parent / module.ident + path_module.mkdir(mode=0o755, parents=True, exist_ok=True) + write_module_file(module=module, path=path_module / "mod.rs") + traversing_stack.extend( + (path_module, sub_mod) for sub_mod in module.submodules.values() + ) + + +# +# Salad Core Types +# + + +def rust_type_option(rust_ty: RustPath) -> RustPath: + # noinspection PyArgumentList + return RustPath([RustPathSegment(ident="Option", generics=[rust_ty])]) + + +def rust_type_list(rust_ty: RustPath) -> RustPath: + # noinspection PyArgumentList + return RustPath([ + RustPathSegment(ident="crate"), + RustPathSegment(ident="core"), + RustPathSegment(ident="List", generics=[rust_ty]), + ]) # fmt: skip + + +_AVRO_TO_RUST_PRESET = { + # Salad Types + "boolean": RustPath.from_str("crate::core::Bool"), + "int": RustPath.from_str("crate::core::Int"), + "long": RustPath.from_str("crate::core::Long"), + "float": RustPath.from_str("crate::core::Float"), + "double": RustPath.from_str("crate::core::Double"), + "string": RustPath.from_str("crate::core::StrValue"), + "org.w3id.cwl.salad.Any": RustPath.from_str("crate::core::Any"), + "org.w3id.cwl.salad.ArraySchema.type.Array_name": RustPath.from_str("crate::TypeArray"), + "org.w3id.cwl.salad.EnumSchema.type.Enum_name": RustPath.from_str("crate::TypeEnum"), + "org.w3id.cwl.salad.RecordSchema.type.Record_name": RustPath.from_str("crate::TypeRecord"), + # CWL Types + "org.w3id.cwl.cwl.Expression": RustPath.from_str("crate::core::StrValue"), +} + + +# +# Code generator +# + + +class RustCodeGen(CodeGenBase): + """ + Rust code generator for schema salad definitions. + """ + + # Static + PACKAGE_VERSION = "0.1.0" # Version of the generated crate + __TEMPLATE_DIR = Path(str(resource_files("schema_salad").joinpath("rust"))).resolve() + + # Parsing related + __avro_to_rust: MutableMapping[str, RustPath] + __document_root_paths: MutableSequence[RustPath] + __module_tree: RustModuleTree + __schema_stack: MutableSequence[NamedSchema] + + # noinspection PyMissingConstructor + def __init__( + self, + base_uri: str, + package: str, + salad_version: str, + target: Optional[str] = None, + ) -> None: + self.package = package + self.PACKAGE_VERSION = self.__generate_crate_version(salad_version) + self.output_dir = Path(target or ".").resolve() + self.document_root_attr = RustAttribute( + meta=RustMetaList( + path=RustPath.from_str("salad"), + metas=[ + RustPath.from_str("root"), + RustMetaNameValue( + path=RustPath.from_str("base_uri"), + value=base_uri, + ), + ], + ) + ) + + def parse(self, items: MutableSequence[JsonDataType]) -> None: + # Create output directory + self.__init_output_directory() + + # Generate Rust named types + self.__avro_to_rust = _AVRO_TO_RUST_PRESET.copy() + self.__document_root_paths = [] + self.__module_tree = RustModuleTree(ident="crate", parent=None) + self.__schema_stack = list(reversed(make_avro(items))) + + while self.__schema_stack: + schema = self.__schema_stack.pop() + + if not schema.name.startswith(self.package): + continue + if schema.name in self.__avro_to_rust: + _logger.warn(f"Skip parse step for schema: {schema.name}") + continue + + rust_path = self.__parse_named_schema(schema) + self.__avro_to_rust[schema.name] = rust_path + + # Generate `DocumentRoot` enum + self.__module_tree.add_named_type( + RustEnum( + ident="DocumentRoot", + attrs=[self.document_root_attr], + variants=list(map(RustVariant.from_path, self.__document_root_paths)), + ) + ) + + # Write named types to the "src" folder + self.__module_tree.write_to_fs(self.output_dir / "src") + + def __parse_named_schema(self, named: NamedSchema) -> RustPath: + if isinstance(named, RecordSchema): + return self.__parse_record_schema(named) + if isinstance(named, EnumSchema): + return self.__parse_enum_schema(named) + if isinstance(named, NamedUnionSchema): + return self.__parse_union_schema(named) + raise ValueError(f"Cannot parse schema of type {type(named).__name__}.") + + def __parse_record_schema(self, record: RecordSchema) -> RustPath: + ident = rust_sanitize_type_ident(avro_shortname(record.name)) + attrs, _ = self.__parse_named_schema_attrs(record) + fields = set(self.__parse_record_field(f, record) for f in record.fields) + + if record.get_prop("documentRoot"): + attrs = [*attrs, self.document_root_attr] + + rust_path = self.__module_tree \ + .add_submodule(self.__get_submodule_path(record)) \ + .add_named_type(RustStruct(ident=ident, attrs=attrs, fields=fields)) # fmt: skip + + if record.get_prop("documentRoot"): + self.__document_root_paths.append(rust_path) + return rust_path + + def __parse_record_field(self, field: SaladField, parent: RecordSchema) -> RustField: + def parse_field_type(schema: Schema) -> RustPath: + if isinstance(schema, UnionSchema): + filtered_schemas = [s for s in schema.schemas if s.type != "null"] + filtered_schemas_len = len(filtered_schemas) + + if filtered_schemas_len == 1: + rust_path = parse_field_type(filtered_schemas[0]) + if filtered_schemas_len < len(schema.schemas): + return rust_type_option(rust_path) + return rust_path + + union_name = f"{parent.name}.{field.name}" + if rust_path := self.__avro_to_rust.get(union_name): + if filtered_schemas_len < len(schema.schemas): + return rust_type_option(rust_path) + return rust_path + + named_union_schema = NamedUnionSchema.__new__(NamedUnionSchema) + setattr(named_union_schema, "_props", getattr(schema, "_props")) + setattr(named_union_schema, "_schemas", filtered_schemas) + named_union_schema.set_prop("name", union_name) + named_union_schema.set_prop("namespace", parent.name) + named_union_schema.set_prop("doc", field.get_prop("doc")) + + self.__schema_stack.append(named_union_schema) + type_path = self.__get_submodule_path(named_union_schema) / RustPathSegment( + rust_sanitize_type_ident(avro_shortname(union_name)) + ) + if filtered_schemas_len < len(schema.schemas): + return rust_type_option(type_path) + return type_path + + if isinstance(schema, (RecordSchema, EnumSchema)): + return self.__avro_to_rust.get( + schema.name, + self.__get_submodule_path(schema) + / RustPathSegment(ident=rust_sanitize_type_ident(avro_shortname(schema.name))), + ) + + if isinstance(schema, ArraySchema): + return rust_type_list(parse_field_type(schema.items)) + + if isinstance(schema, PrimitiveSchema): + return self.__avro_to_rust.get(schema.type) + + raise ValueError(f"Cannot parse schema with type: '{type(schema).__name__}'.") + + # + # `__parse_record_field(...)` method + ident = rust_sanitize_field_ident(field.name) + attrs, _ = self.__parse_field_schema_attrs(field) + ty = parse_field_type(field.type) + return RustField(ident=ident, attrs=attrs, type=ty) + + def __parse_union_schema(self, union: NamedUnionSchema) -> RustPath: + def parse_variant_array_subtype(schema: Schema) -> RustPath: + if isinstance(schema, UnionSchema): + filtered_schemas = [s for s in schema.schemas if s.type != "null"] + + item_name = f"{union.name}_item" + named_union_schema = NamedUnionSchema.__new__(NamedUnionSchema) + setattr(named_union_schema, "_props", getattr(schema, "_props")) + setattr(named_union_schema, "_schemas", filtered_schemas) + named_union_schema.set_prop("name", item_name) + named_union_schema.set_prop("namespace", union.name) + + self.__schema_stack.append(named_union_schema) + return self.__get_submodule_path(named_union_schema) / RustPathSegment( + rust_sanitize_type_ident(avro_shortname(item_name)) + ) + + if isinstance(schema, (RecordSchema, EnumSchema)): + return self.__avro_to_rust.get( + schema.name, + self.__get_submodule_path(schema) + / RustPathSegment(ident=rust_sanitize_type_ident(avro_shortname(schema.name))), + ) + + if isinstance(schema, PrimitiveSchema): + return self.__avro_to_rust.get(schema.type) + + def parse_variant_type(schema: Schema) -> RustVariant: + if isinstance(schema, (RecordSchema, EnumSchema)): + return RustVariant.from_path( + self.__avro_to_rust.get( + schema.name, + self.__get_submodule_path(schema) + / RustPathSegment( + ident=rust_sanitize_type_ident(avro_shortname(schema.name)) + ), + ) + ) + + if isinstance(schema, PrimitiveSchema): + return RustVariant.from_path(self.__avro_to_rust.get(schema.type)) + + if isinstance(schema, ArraySchema): + return RustVariant.from_path( + rust_type_list(parse_variant_array_subtype(schema.items)) + ) + + raise ValueError(f"Cannot parse schema with type: '{type(schema).__name__}'.") + + # + # `__parse_union_schema(...)` method + ident = rust_sanitize_type_ident(avro_shortname(union.name)) + attrs, _ = self.__parse_named_schema_attrs(union) + variants = set(map(parse_variant_type, union.schemas)) + + return self.__module_tree \ + .add_submodule(self.__get_submodule_path(union)) \ + .add_named_type(RustEnum(ident=ident, attrs=attrs, variants=variants)) # fmt: skip + + def __parse_enum_schema(self, enum: EnumSchema) -> RustPath: + ident = rust_sanitize_type_ident(avro_shortname(enum.name)) + attrs, docs_count = self.__parse_named_schema_attrs(enum) + attrs = [ + *attrs, + RustAttribute( + RustMetaList( + path=RustPath.from_str("derive"), + metas=[RustPath.from_str("Copy")], + ) + ), + ] + + if len(enum.symbols) == 1: + return self.__module_tree \ + .add_submodule(self.__get_submodule_path(enum)) \ + .add_named_type( + RustStruct( + ident=ident, + attrs=[ + *attrs[:docs_count], + RustAttribute( + RustMetaNameValue( + path=RustPath.from_str("doc"), + value=f"Matches constant value `{enum.symbols[0]}`.", + ) + ), + *attrs[docs_count:], + RustAttribute( + RustMetaList( + path=RustPath.from_str("salad"), + metas=[RustMetaNameValue( + path=RustPath.from_str("as_str"), + value=enum.symbols[0], + )], + ) + ), + ], + ) + ) # fmt: skip + else: + return self.__module_tree \ + .add_submodule(self.__get_submodule_path(enum)) \ + .add_named_type( + RustEnum( + ident=ident, + attrs=attrs, + variants=[ + RustVariant( + ident=rust_sanitize_type_ident(symbol), + attrs=[ + RustAttribute( + RustMetaNameValue( + path=RustPath.from_str("doc"), + value=f"Matches constant value `{symbol}`.", + ) + ), + RustAttribute( + RustMetaList( + path=RustPath.from_str("salad"), + metas=[RustMetaNameValue( + path=RustPath.from_str("as_str"), + value=symbol, + )], + ) + ), + ], + ) + for symbol in enum.symbols + ], + ) + ) # fmt: skip + + # End of named schemas parse block + # + @staticmethod + def __parse_named_schema_attrs(schema: NamedSchema) -> tuple[RustAttributes, int]: + attrs, docs_count = [], 0 + + if docs := schema.get_prop("doc"): + rust_path_doc = RustPath.from_str("doc") + attrs.extend( + RustAttribute(RustMetaNameValue(path=rust_path_doc, value=doc)) + for doc in rust_sanitize_doc_iter(docs) + ) + docs_count = len(attrs) + + attrs.append( + RustAttribute( + RustMetaList( + path=RustPath.from_str("derive"), + metas=[ + RustPath.from_str("Debug"), + RustPath.from_str("Clone"), + ], + ) + ) + ) + + return attrs, docs_count + + @staticmethod + def __parse_field_schema_attrs(schema: SaladField) -> tuple[RustAttributes, int]: + attrs, docs_count = [], 0 + + if docs := schema.get_prop("doc"): + rust_path_doc = RustPath.from_str("doc") + attrs.extend( + RustAttribute(RustMetaNameValue(path=rust_path_doc, value=doc)) + for doc in rust_sanitize_doc_iter(docs) + ) + docs_count = len(attrs) + + metas = [] + if default := schema.get_prop("default"): + metas.append(RustMetaNameValue(path=RustPath.from_str("default"), value=default)) + if jsonld_predicate := schema.get_prop("jsonldPredicate"): + if isinstance(jsonld_predicate, str) and jsonld_predicate == "@id": + metas.append(RustPath.from_str("identifier")) + elif isinstance(jsonld_predicate, MutableMapping): + metas.extend( + RustMetaNameValue(path=RustPath.from_str(rust_path), value=value) + for key, rust_path in [ + ("mapSubject", "map_key"), + ("mapPredicate", "map_predicate"), + ("subscope", "subscope"), + ] + if (value := jsonld_predicate.get(key)) + ) + if metas: + attrs.append(RustAttribute(RustMetaList(path=RustPath.from_str("salad"), metas=metas))) + + return attrs, docs_count + + # End of attributes parse block + # + def __get_submodule_path(self, schema: NamedSchema) -> RustPath: + segments = [RustPathSegment(ident="crate")] + if namespace_prop := schema.get_prop("namespace"): + if (namespace := namespace_prop.removeprefix(self.package)) not in ("", "."): + namespace_segment = namespace.split(".")[1].lower() + module_ident = rust_sanitize_field_ident(namespace_segment) + segments.append(RustPathSegment(ident=module_ident)) + return RustPath(segments=segments) + + def __init_output_directory(self) -> None: + """ + Initialize the output directory structure. + """ + if self.output_dir.is_file(): + raise ValueError(f"Output directory cannot be a file: {self.output_dir}") + if not self.output_dir.exists(): + _logger.info(f"Creating output directory: {self.output_dir}") + self.output_dir.mkdir(mode=0o755, parents=True) + elif any(self.output_dir.iterdir()): + _logger.warning( + f"Output directory is not empty: {self.output_dir}.\n" + "Wait for 3 seconds before proceeding..." + ) + sleep(3) + + def copy2_wrapper(src: str, dst: str) -> object: + if not src.endswith("rust/Cargo.toml"): + return shutil.copy2(src, dst) + + replace_dict = [ + ("{package_name}", self.output_dir.name), + ("{package_version}", self.PACKAGE_VERSION), + ] + + with open(src, "r") as src, open(dst, "w") as dst: + content = src.read() + for placeholder, value in replace_dict: + content = content.replace(placeholder, value) + dst.write(content) + + shutil.copytree( + RustCodeGen.__TEMPLATE_DIR, + self.output_dir, + dirs_exist_ok=True, + copy_function=copy2_wrapper, + ) + + @staticmethod + def __generate_crate_version(salad_version: str) -> str: + salad_version = salad_version.removeprefix("v") + return f"{RustCodeGen.PACKAGE_VERSION}+salad{salad_version}" diff --git a/setup.py b/setup.py index 6f3b4c86..0711b559 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ "schema_salad/cpp_codegen.py", "schema_salad/dlang_codegen.py", "schema_salad/dotnet_codegen.py", + "schema_salad/rust_codegen.py", # "schema_salad/exceptions.py", # leads to memory leaks "schema_salad/java_codegen.py", "schema_salad/jsonld_context.py", @@ -131,6 +132,7 @@ "typescript/*/*", "typescript/*/*/*", "typescript/.*", + "rust/**", ], "schema_salad.tests": [ "*.json",