From 57aeff2a87eb988994be8aaa5997c5365dbd0ea4 Mon Sep 17 00:00:00 2001 From: Grey Date: Fri, 13 Sep 2024 14:33:51 -0700 Subject: [PATCH 01/20] Add CI status badges to README (#25) Signed-off-by: Michael X. Grey --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index de067452..b9393cff 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ +[![style](https://github.com/open-rmf/bevy_impulse/actions/workflows/style.yaml/badge.svg)](https://github.com/open-rmf/bevy_impulse/actions/workflows/style.yaml) +[![ci_linux](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_linux.yaml/badge.svg)](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_linux.yaml) +[![ci_windows](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_windows.yaml/badge.svg)](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_windows.yaml) +[![ci_web](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_web.yaml/badge.svg)](https://github.com/open-rmf/bevy_impulse/actions/workflows/ci_web.yaml) + # Reactive Programming for Bevy This library provides sophisticated [reactive programming](https://en.wikipedia.org/wiki/Reactive_programming) for the [bevy](https://bevyengine.org/) ECS. In addition to supporting one-shot chains of async operations, it can support reusable workflows with parallel branches, synchronization, races, and cycles. These workflows can be hierarchical, so a workflow can be used as a building block by other workflows. From 72fcd726b9885971caba1b0d6d901bb01c5a86ed Mon Sep 17 00:00:00 2001 From: Grey Date: Thu, 14 Nov 2024 16:49:38 +0800 Subject: [PATCH 02/20] Add derive macro for DeliveryLabel (#30) Signed-off-by: Michael X. 
Grey --- Cargo.toml | 4 +- macros/Cargo.toml | 2 +- macros/src/lib.rs | 29 ++++++++++- src/impulse.rs | 126 +++++++++++++++++++++++++++++++++++----------- src/service.rs | 6 +++ 5 files changed, 133 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9a7a3b46..261c8742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bevy_impulse" -version = "0.0.1" +version = "0.0.2" edition = "2021" authors = ["Grey "] license = "Apache-2.0" @@ -11,7 +11,7 @@ keywords = ["reactive", "workflow", "behavior", "agent", "bevy"] categories = ["science::robotics", "asynchronous", "concurrency", "game-development"] [dependencies] -bevy_impulse_derive = { path = "macros", version = "0.0.1" } +bevy_impulse_derive = { path = "macros", version = "0.0.2" } bevy_ecs = "0.12" bevy_utils = "0.12" bevy_hierarchy = "0.12" diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 6073aab1..12fa2bb3 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bevy_impulse_derive" -version = "0.0.1" +version = "0.0.2" edition = "2021" authors = ["Grey "] license = "Apache-2.0" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index fa139b5e..d40c9309 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -24,12 +24,37 @@ pub fn simple_stream_macro(item: TokenStream) -> TokenStream { let ast: DeriveInput = syn::parse(item).unwrap(); let struct_name = &ast.ident; let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); - // let bevy_impulse_path: Path = quote! 
{ - impl #impl_generics Stream for #struct_name #type_generics #where_clause { + impl #impl_generics ::bevy_impulse::Stream for #struct_name #type_generics #where_clause { type Container = ::bevy_impulse::DefaultStreamContainer; } } .into() } + +#[proc_macro_derive(DeliveryLabel)] +pub fn delivery_label_macro(item: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(item).unwrap(); + let struct_name = &ast.ident; + let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); + + quote! { + impl #impl_generics ::bevy_impulse::DeliveryLabel for #struct_name #type_generics #where_clause { + fn dyn_clone(&self) -> Box { + ::std::boxed::Box::new(::std::clone::Clone::clone(self)) + } + + fn as_dyn_eq(&self) -> &dyn ::bevy_impulse::utils::DynEq { + self + } + + fn dyn_hash(&self, mut state: &mut dyn ::std::hash::Hasher) { + let ty_id = ::std::any::TypeId::of::(); + ::std::hash::Hash::hash(&ty_id, &mut state); + ::std::hash::Hash::hash(self, &mut state); + } + } + } + .into() +} diff --git a/src/impulse.rs b/src/impulse.rs index 6ebf835b..875740e3 100644 --- a/src/impulse.rs +++ b/src/impulse.rs @@ -485,11 +485,30 @@ mod tests { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] - struct TestLabel; + struct UnitLabel; - // TODO(luca) create a proc-macro for this, because library crates can't - // export proc macros we will need to create a new macros only crate - impl DeliveryLabel for TestLabel { + // TODO(@mxgrey) Figure out how to make the DeliveryLabel macro usable + // within the core bevy_impulse library + impl DeliveryLabel for UnitLabel { + fn dyn_clone(&self) -> Box { + Box::new(self.clone()) + } + + fn as_dyn_eq(&self) -> &dyn DynEq { + self + } + + fn dyn_hash(&self, mut state: &mut dyn std::hash::Hasher) { + let ty_id = std::any::TypeId::of::(); + std::hash::Hash::hash(&ty_id, &mut state); + std::hash::Hash::hash(self, &mut state); + } + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct StatefulLabel(u64); + + impl 
DeliveryLabel for StatefulLabel { fn dyn_clone(&self) -> Box { Box::new(self.clone()) } @@ -534,9 +553,38 @@ mod tests { service: Service>, ()>, context: &mut TestingContext, ) { - let queuing_service = service.instruct(TestLabel); - let preempting_service = service.instruct(TestLabel.preempt()); + // Test for a unit struct + verify_preemption_matrix( + service.instruct(UnitLabel), + service.instruct(UnitLabel.preempt()), + context, + ); + + // Test for a stateful struct + verify_preemption_matrix( + service.instruct(StatefulLabel(5)), + service.instruct(StatefulLabel(5).preempt()), + context, + ); + // Test for a unit struct + verify_queuing_matrix(service.instruct(UnitLabel), context); + + // Test for a stateful struct + verify_queuing_matrix(service.instruct(StatefulLabel(7)), context); + + // Test for a unit struct + verify_ensured_matrix(service, UnitLabel, context); + + // Test for a stateful struct + verify_ensured_matrix(service, StatefulLabel(2), context); + } + + fn verify_preemption_matrix( + queuing_service: Service>, ()>, + preempting_service: Service>, ()>, + context: &mut TestingContext, + ) { // Test by queuing up a bunch of requests before preempting them all at once. verify_preemption(1, queuing_service, preempting_service, context); verify_preemption(2, queuing_service, preempting_service, context); @@ -548,25 +596,6 @@ mod tests { verify_preemption(2, preempting_service, preempting_service, context); verify_preemption(3, preempting_service, preempting_service, context); verify_preemption(4, preempting_service, preempting_service, context); - - // Test by queuing up a bunch of requests and making sure they all get - // delivered. - verify_queuing(2, queuing_service, context); - verify_queuing(3, queuing_service, context); - verify_queuing(4, queuing_service, context); - verify_queuing(5, queuing_service, context); - - // Test by queuing up a mix of ensured and unensured requests, and then - // sending in one that preempts them all. 
The ensured requests should - // remain in the queue and execute despite the preempter. The unensured - // requests should all be cancelled. - verify_ensured([false, true, false, true], service, context); - verify_ensured([true, false, false, false], service, context); - verify_ensured([true, true, false, false], service, context); - verify_ensured([false, false, true, true], service, context); - verify_ensured([true, false, false, true], service, context); - verify_ensured([false, false, false, false], service, context); - verify_ensured([true, true, true, true], service, context); } fn verify_preemption( @@ -603,6 +632,18 @@ mod tests { assert!(context.no_unhandled_errors()); } + fn verify_queuing_matrix( + queuing_service: Service>, ()>, + context: &mut TestingContext, + ) { + // Test by queuing up a bunch of requests and making sure they all get + // delivered. + verify_queuing(2, queuing_service, context); + verify_queuing(3, queuing_service, context); + verify_queuing(4, queuing_service, context); + verify_queuing(5, queuing_service, context); + } + fn verify_queuing( queue_size: usize, queuing_service: Service>, ()>, @@ -628,9 +669,33 @@ mod tests { assert!(context.no_unhandled_errors()); } - fn verify_ensured( + fn verify_ensured_matrix( + service: Service>, ()>, + label: L, + context: &mut TestingContext, + ) { + // Test by queuing up a mix of ensured and unensured requests, and then + // sending in one that preempts them all. The ensured requests should + // remain in the queue and execute despite the preempter. The unensured + // requests should all be cancelled. 
+ verify_ensured([false, true, false, true], service, label.clone(), context); + verify_ensured([true, false, false, false], service, label.clone(), context); + verify_ensured([true, true, false, false], service, label.clone(), context); + verify_ensured([false, false, true, true], service, label.clone(), context); + verify_ensured([true, false, false, true], service, label.clone(), context); + verify_ensured( + [false, false, false, false], + service, + label.clone(), + context, + ); + verify_ensured([true, true, true, true], service, label.clone(), context); + } + + fn verify_ensured( queued: impl IntoIterator, service: Service>, ()>, + label: L, context: &mut TestingContext, ) { let counter = Arc::new(Mutex::new(0_u64)); @@ -640,9 +705,9 @@ mod tests { for ensured in queued { let srv = if ensured { expected_count += 1; - service.instruct(TestLabel.ensure()) + service.instruct(label.clone().ensure()) } else { - service.instruct(TestLabel) + service.instruct(label.clone()) }; let promise = context @@ -653,7 +718,10 @@ mod tests { let mut preempter = context.command(|commands| { commands - .request(Arc::clone(&counter), service.instruct(TestLabel.preempt())) + .request( + Arc::clone(&counter), + service.instruct(label.clone().preempt()), + ) .take_response() }); diff --git a/src/service.rs b/src/service.rs index 8f20fea9..706a9412 100644 --- a/src/service.rs +++ b/src/service.rs @@ -26,6 +26,7 @@ use bevy_ecs::{ prelude::{Commands, Component, Entity, Event, World}, schedule::ScheduleLabel, }; +pub use bevy_impulse_derive::DeliveryLabel; use bevy_utils::{define_label, intern::Interned}; use std::{any::TypeId, collections::HashSet}; use thiserror::Error as ThisError; @@ -196,6 +197,11 @@ define_label!( DELIVERY_LABEL_INTERNER ); +pub mod utils { + /// Used by the procedural macro for DeliveryLabel + pub use bevy_utils::label::DynEq; +} + /// When using a service, you can bundle in delivery instructions that affect /// how multiple requests to the same service may 
interact with each other. /// From 85501df008e425ba0e19e5499702d5350b75c9b9 Mon Sep 17 00:00:00 2001 From: Grey Date: Wed, 27 Nov 2024 12:18:58 +0800 Subject: [PATCH 03/20] Introduce Split operation (#33) Signed-off-by: Michael X. Grey --- Cargo.toml | 1 - src/buffer/bufferable.rs | 33 +- src/builder.rs | 27 +- src/chain.rs | 126 ++++- src/chain/split.rs | 939 +++++++++++++++++++++++++++++++++ src/disposal.rs | 30 ++ src/node.rs | 3 + src/operation.rs | 3 + src/operation/operate_split.rs | 223 ++++++++ 9 files changed, 1370 insertions(+), 15 deletions(-) create mode 100644 src/chain/split.rs create mode 100644 src/operation/operate_split.rs diff --git a/Cargo.toml b/Cargo.toml index 261c8742..08bcde41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ async-task = { version = "4.7.1", optional = true } # bevy_tasks::Task, so we're leaving it as a mandatory dependency for now. bevy_tasks = { version = "0.12", features = ["multi-threaded"] } -arrayvec = "0.7" itertools = "0.13" smallvec = "1.13" tokio = { version = "1.39", features = ["sync"]} diff --git a/src/buffer/bufferable.rs b/src/buffer/bufferable.rs index 06841be7..17f8b367 100644 --- a/src/buffer/bufferable.rs +++ b/src/buffer/bufferable.rs @@ -207,28 +207,28 @@ impl Bufferable for [T; N] { } pub trait IterBufferable { - type BufferType: Buffered; + type BufferElement: Buffered; /// Convert an iterable collection of bufferable workflow elements into /// buffers if they are not buffers already. fn into_buffer_vec( self, builder: &mut Builder, - ) -> SmallVec<[Self::BufferType; N]>; + ) -> SmallVec<[Self::BufferElement; N]>; /// Join an iterable collection of bufferable workflow elements. /// /// Performance is best if you can choose an `N` which is equal to the /// number of buffers inside the iterable, but this will work even if `N` /// does not match the number. 
- fn join_vec( + fn join_vec<'w, 's, 'a, 'b, const N: usize>( self, - builder: &mut Builder, - ) -> Output::Item; N]>> + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, SmallVec<[::Item; N]>> where Self: Sized, - Self::BufferType: 'static + Send + Sync, - ::Item: 'static + Send + Sync, + Self::BufferElement: 'static + Send + Sync, + ::Item: 'static + Send + Sync, { let buffers = self.into_buffer_vec::(builder); let join = builder.commands.spawn(()).id(); @@ -239,6 +239,23 @@ pub trait IterBufferable { Join::new(buffers, target), )); - Output::new(builder.scope, target) + Output::new(builder.scope, target).chain(builder) + } +} + +impl IterBufferable for T +where + T: IntoIterator, + T::Item: Bufferable, +{ + type BufferElement = ::BufferType; + + fn into_buffer_vec( + self, + builder: &mut Builder, + ) -> SmallVec<[Self::BufferElement; N]> { + SmallVec::<[Self::BufferElement; N]>::from_iter( + self.into_iter().map(|e| e.into_buffer(builder)), + ) } } diff --git a/src/builder.rs b/src/builder.rs index 702c6268..882b2a8a 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -26,9 +26,10 @@ use crate::{ AddOperation, AsMap, BeginCleanupWorkflow, Buffer, BufferItem, BufferKeys, BufferSettings, Bufferable, Buffered, Chain, Collect, ForkClone, ForkCloneOutput, ForkTargetStorage, Gate, GateRequest, Injection, InputSlot, IntoAsyncMap, IntoBlockingMap, Node, OperateBuffer, - OperateBufferAccess, OperateDynamicGate, OperateScope, OperateStaticGate, Output, Provider, - RequestOfMap, ResponseOfMap, Scope, ScopeEndpoints, ScopeSettings, ScopeSettingsStorage, - Sendish, Service, StreamPack, StreamTargetMap, StreamsOfMap, Trim, TrimBranch, UnusedTarget, + OperateBufferAccess, OperateDynamicGate, OperateScope, OperateSplit, OperateStaticGate, Output, + Provider, RequestOfMap, ResponseOfMap, Scope, ScopeEndpoints, ScopeSettings, + ScopeSettingsStorage, Sendish, Service, SplitOutputs, Splittable, StreamPack, StreamTargetMap, + StreamsOfMap, Trim, TrimBranch, 
UnusedTarget, }; pub(crate) mod connect; @@ -339,6 +340,26 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { self.create_collect(n, Some(n)) } + /// Create a new split operation in the workflow. The [`InputSlot`] can take + /// in values that you want to split, and [`SplitOutputs::build`] will let + /// you build connections to the split value. + pub fn create_split(&mut self) -> (InputSlot, SplitOutputs) + where + T: 'static + Send + Sync + Splittable, + { + let source = self.commands.spawn(()).id(); + self.commands.add(AddOperation::new( + Some(self.scope), + source, + OperateSplit::::default(), + )); + + ( + InputSlot::new(self.scope, source), + SplitOutputs::new(self.scope, source), + ) + } + /// This method allows you to define a cleanup workflow that branches off of /// this scope that will activate during the scope's cleanup phase. The /// input to the cleanup workflow will be a key to access to one or more diff --git a/src/chain.rs b/src/chain.rs index 8c97447d..e81fcd7e 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -27,9 +27,9 @@ use crate::{ make_option_branching, make_result_branching, AddOperation, AsMap, Buffer, BufferKey, BufferKeys, Bufferable, Buffered, Builder, Collect, CreateCancelFilter, CreateDisposalFilter, ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, IntoBlockingCallback, - IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, OperateStaticGate, - Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, Service, Spread, StreamOf, - StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, + IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, OperateSplit, + OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, Service, + Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, }; pub mod fork_clone_builder; @@ -38,6 +38,9 @@ pub use fork_clone_builder::*; pub(crate) mod premade; use premade::*; +pub mod split; +pub use split::*; + pub mod 
unzip; pub use unzip::*; @@ -601,6 +604,71 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { self.map_async(|r| r) } + /// If the chain's response implements the [`Splittable`] trait, then this + /// will insert a split operation and provide your `build` function with the + /// [`SplitBuilder`] for it. This returns the return value of your build + /// function. + pub fn split(self, build: impl FnOnce(SplitBuilder) -> U) -> U + where + T: Splittable, + { + let source = self.target; + self.builder.commands.add(AddOperation::new( + Some(self.builder.scope), + source, + OperateSplit::::default(), + )); + + build(SplitBuilder::new(source, self.builder)) + } + + /// If the chain's response implements the [`Splittable`] trait, then this + /// will insert a split and provide a container for its available outputs. + /// To build connections to these outputs later, use [`SplitOutputs::build`]. + /// + /// This is equivalent to + /// ```text + /// .split(|split| split.outputs()) + /// ``` + pub fn split_outputs(self) -> SplitOutputs + where + T: Splittable, + { + self.split(|b| b.outputs()) + } + + /// If the chain's response can be turned into an iterator with an appropriate + /// item type, this will allow it to be split in a list-like way. + /// + /// This is equivalent to + /// ```text + /// .map_block(SplitAsList::new).split(build) + /// ``` + pub fn split_as_list(self, build: impl FnOnce(SplitBuilder>) -> U) -> U + where + T: IntoIterator, + T::Item: 'static + Send + Sync, + { + self.map_block(SplitAsList::new).split(build) + } + + /// If the chain's response can be turned into an iterator with an appropriate + /// item type, this will insert a split and provide a container for its + /// available outputs. To build connections to these outputs later, use + /// [`SplitOutputs::build`]. 
+ /// + /// This is equivalent to + /// ```text + /// .split_as_list(|split| split.outputs()) + /// ``` + pub fn split_as_list_outputs(self) -> SplitOutputs> + where + T: IntoIterator, + T::Item: 'static + Send + Sync, + { + self.split_as_list(|b| b.outputs()) + } + /// Add a [no-op][1] to the current end of the chain. /// /// As the name suggests, a no-op will not actually do anything, but it adds @@ -633,10 +701,12 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { } } + /// The scope that the chain is building inside of. pub fn scope(&self) -> Entity { self.builder.scope } + /// The target where the chain will be sending its latest output. pub fn target(&self) -> Entity { self.target } @@ -942,6 +1012,38 @@ where } } +impl<'w, 's, 'a, 'b, K, V, T> Chain<'w, 's, 'a, 'b, T> +where + K: 'static + Send + Sync + Eq + std::hash::Hash + Clone + std::fmt::Debug, + V: 'static + Send + Sync, + T: 'static + Send + Sync + IntoIterator, +{ + /// If the chain's response type can be turned into an iterator that returns + /// `(key, value)` pairs, then this will split it in a map-like way, whether + /// or not it is a conventional map data structure. + /// + /// This is equivalent to + /// ```text + /// .map_block(SplitAsMap::new).split(build) + /// ``` + pub fn split_as_map(self, build: impl FnOnce(SplitBuilder>) -> U) -> U { + self.map_block(SplitAsMap::new).split(build) + } + + /// If the chain's response type can be turned into an iterator that returns + /// `(key, value)` pairs, then this will split it in a map-like way and + /// provide a container for its available outputs. To build connections to + /// these outputs later, use [`SplitOutputs::build`]. 
+ /// + /// This is equivalent to + /// ```text + /// .split_as_map(|split| split.outputs()) + /// ``` + pub fn split_as_map_outputs(self) -> SplitOutputs> { + self.split_as_map(|b| b.outputs()) + } +} + impl<'w, 's, 'a, 'b, Request, Response, Streams> Chain<'w, 's, 'a, 'b, (Request, Service)> where @@ -1030,6 +1132,24 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { } } +impl<'w, 's, 'a, 'b, K, V> Chain<'w, 's, 'a, 'b, (K, V)> +where + K: 'static + Send + Sync, + V: 'static + Send + Sync, +{ + /// If the chain's response contains a `(key, value)` pair, get the `key` + /// component from it (the first element of the tuple). + pub fn key(self) -> Chain<'w, 's, 'a, 'b, K> { + self.map_block(|(key, _)| key) + } + + /// If the chain's response contains a `(key, value)` pair, get the `value` + /// component from it (the second element of the tuple). + pub fn value(self) -> Chain<'w, 's, 'a, 'b, V> { + self.map_block(|(_, value)| value) + } +} + #[cfg(test)] mod tests { use crate::{prelude::*, testing::*}; diff --git a/src/chain/split.rs b/src/chain/split.rs new file mode 100644 index 00000000..c22046d7 --- /dev/null +++ b/src/chain/split.rs @@ -0,0 +1,939 @@ +/* + * Copyright (C) 2024 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use bevy_ecs::prelude::Entity; +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + fmt::Debug, + hash::Hash, +}; +use thiserror::Error as ThisError; + +use crate::{Builder, Chain, ConnectToSplit, OperationResult, Output, UnusedTarget}; + +/// Implementing this trait on a struct will allow the [`Chain::split`] operation +/// to be performed on [outputs][crate::Output] of that type. +pub trait Splittable: Sized { + /// The key used to identify different elements in the split + type Key: 'static + Send + Sync + Eq + Hash + Clone + Debug; + + /// The type that the value gets split into + type Item: 'static + Send + Sync; + + /// Return true if the key is feasible for this type of split, otherwise + /// return false. Returning false will cause the user to receive a + /// [`SplitConnectionError::KeyOutOfBounds`]. This will also cause iterating + /// to cease. + fn validate(key: &Self::Key) -> bool; + + /// Get the next key value that would follow the provided one. If [`None`] + /// is passed in then return the first key. If you return [`None`] then + /// the connections will stop iterating. + fn next(key: &Option) -> Option; + + /// Split the value into its parts + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult; +} + +/// This is returned by [`Chain::split`] and allows you to connect to the +/// split pieces. +#[must_use] +pub struct SplitBuilder<'w, 's, 'a, 'b, T: Splittable> { + outputs: SplitOutputs, + builder: &'b mut Builder<'w, 's, 'a>, +} + +impl<'w, 's, 'a, 'b, T: Splittable> std::fmt::Debug for SplitBuilder<'w, 's, 'a, 'b, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SplitBuilder") + .field("outputs", &self.outputs) + .finish() + } +} + +impl<'w, 's, 'a, 'b, T: 'static + Splittable> SplitBuilder<'w, 's, 'a, 'b, T> { + /// Get the state of connections for this split. 
You can resume building + /// more connections later by calling [`SplitOutputs::build`] on the + /// connections. + pub fn outputs(self) -> SplitOutputs { + self.outputs + } + + /// Unpack the split outputs from the builder so the two can be used + /// independently. + pub fn unpack(self) -> (SplitOutputs, &'b mut Builder<'w, 's, 'a>) { + (self.outputs, self.builder) + } + + /// Build a branch for one of the keys in the split. + pub fn branch_for( + mut self, + key: T::Key, + f: impl FnOnce(Chain), + ) -> SplitChainResult<'w, 's, 'a, 'b, T> { + let output = match self.output_for(key) { + Ok(output) => output, + Err(err) => return Err((self, err)), + }; + f(output.chain(self.builder)); + Ok(self) + } + + /// This is a convenience function for splits whose keys implement + /// [`FromSpecific`] to build a branch for a specific key in the split. This + /// can be used by [`MapSplitKey`]. + pub fn specific_branch( + self, + specific_key: ::SpecificKey, + f: impl FnOnce(Chain), + ) -> SplitChainResult<'w, 's, 'a, 'b, T> + where + T::Key: FromSpecific, + { + self.branch_for(T::Key::from_specific(specific_key), f) + } + + /// This is a convenience function for splits whose keys implement + /// [`FromSequential`] to build a branch for an anonymous sequential key in + /// the split. This can be used by [`ListSplitKey`] and [`MapSplitKey`]. + pub fn sequential_branch( + self, + sequence_number: usize, + f: impl FnOnce(Chain), + ) -> SplitChainResult<'w, 's, 'a, 'b, T> + where + T::Key: FromSequential, + { + self.branch_for(T::Key::from_sequential(sequence_number), f) + } + + /// This is a convenience function for splits whose keys implement + /// [`ForRemaining`] to build a branch that takes in all items that did not + /// have a more specific connection available. This can be used by + /// [`ListSplitKey`] and [`MapSplitKey`]. + /// + /// This can only be set once, so subsequent attempts to set the remaining + /// branch will return an error. 
It will also return an error after + /// [`Self::remaining_output`] has been used. + pub fn remaining_branch( + self, + f: impl FnOnce(Chain), + ) -> SplitChainResult<'w, 's, 'a, 'b, T> + where + T::Key: ForRemaining, + { + self.branch_for(T::Key::for_remaining(), f) + } + + /// Build a branch for the next key in the split, if such a key may + /// be available. The first argument tells you what the key would be, but + /// you can safely ignore this if it doesn't matter to you. + /// + /// This changes what the next element will be if you later use this + /// [`SplitBuilder`] as an iterator. + pub fn next_branch( + mut self, + f: impl FnOnce(T::Key, Chain), + ) -> SplitChainResult<'w, 's, 'a, 'b, T> { + let Some((key, output)) = self.next() else { + return Err((self, SplitConnectionError::KeyOutOfBounds)); + }; + + f(key, output.chain(self.builder)); + Ok(self) + } + + /// Get the output slot for an element in the split. + pub fn output_for(&mut self, key: T::Key) -> Result, SplitConnectionError> { + if !T::validate(&key) { + return Err(SplitConnectionError::KeyOutOfBounds); + } + + if !self.outputs.used.insert(key.clone()) { + return Err(SplitConnectionError::KeyAlreadyUsed); + } + + let target = self.builder.commands.spawn(UnusedTarget).id(); + self.builder.commands.add(ConnectToSplit:: { + source: self.outputs.source, + target, + key, + }); + Ok(Output::new(self.outputs.scope, target)) + } + + /// This is a convenience function for splits whose keys implement + /// [`FromSpecific`] to get the output for a specific key in the split. This + /// can be used by [`MapSplitKey`]. + pub fn specific_output( + &mut self, + specific_key: ::SpecificKey, + ) -> Result, SplitConnectionError> + where + T::Key: FromSpecific, + { + self.output_for(T::Key::from_specific(specific_key)) + } + + /// This is a convenience function for splits whose keys implement + /// [`FromSequential`] to get the output for an anonymous sequential key in + /// the split. 
This can be used by [`ListSplitKey`] and [`MapSplitKey`]. + pub fn sequential_output( + &mut self, + sequence_number: usize, + ) -> Result, SplitConnectionError> + where + T::Key: FromSequential, + { + self.output_for(T::Key::from_sequential(sequence_number)) + } + + /// This is a convenience function for splits whose keys implement + /// [`ForRemaining`] to get the output for all keys remaining without a + /// connection after all the more specific connections have been considered. + /// This can be used by [`ListSplitKey`] and [`MapSplitKey`]. + /// + /// This can only be used once, after which it will return an error. It will + /// also return an error after [`Self::remaining_branch`] has been used. + pub fn remaining_output(&mut self) -> Result, SplitConnectionError> + where + T::Key: ForRemaining, + { + self.output_for(T::Key::for_remaining()) + } + + /// Explicitly stop building the split by indicating that you it to remain + /// unused. + pub fn unused(self) { + // Do nothing + } + + /// Used internally to create a new split connector + pub(crate) fn new(source: Entity, builder: &'b mut Builder<'w, 's, 'a>) -> Self { + Self { + outputs: SplitOutputs::new(builder.scope, source), + builder, + } + } +} + +impl<'w, 's, 'a, 'b, T: 'static + Splittable> Iterator for SplitBuilder<'w, 's, 'a, 'b, T> { + type Item = (T::Key, Output); + fn next(&mut self) -> Option { + loop { + let next_key = T::next(&self.outputs.last_key)?; + self.outputs.last_key = Some(next_key.clone()); + + match self.output_for(next_key.clone()) { + Ok(output) => { + return Some((next_key, output)); + } + Err(SplitConnectionError::KeyAlreadyUsed) => { + // Restart the loop and get the next key which might not be + // used yet. + continue; + } + Err(SplitConnectionError::KeyOutOfBounds) => { + // We have reached the end of the valid range so just quit + // iterating. + return None; + } + } + } + } +} + +/// This tracks the connections that have been made to a split. 
This can be +/// retrieved from [`SplitBuilder`] by calling [`SplitBuilder::outputs`]. +/// You can then continue building connections by calling [`SplitBuilder::build`]. +#[must_use] +pub struct SplitOutputs { + scope: Entity, + source: Entity, + last_key: Option, + used: HashSet, +} + +impl std::fmt::Debug for SplitOutputs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(&format!("SplitOutputs<{}>", std::any::type_name::())) + .field("scope", &self.scope) + .field("source", &self.source) + .field("last_key", &self.last_key) + .field("used", &self.used) + .finish() + } +} + +impl SplitOutputs { + /// Resume building connections for this split. + pub fn build<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> SplitBuilder<'w, 's, 'a, 'b, T> { + assert_eq!(self.scope, builder.scope); + SplitBuilder { + outputs: self, + builder, + } + } + + pub(crate) fn new(scope: Entity, source: Entity) -> Self { + Self { + scope, + source, + last_key: None, + used: Default::default(), + } + } +} + +/// This is a type alias for the result returned by chainable [`SplitBuilder`] +/// functions. If the last connection succeeded, you will receive [`Ok`] with +/// the [`SplitBuilder`] which you can keep building off of. Otherwise if the +/// last connection failed, you will receive an [`Err`] with the [`SplitBuilder`] +/// bundled with [`SplitConnectionError`] to tell you what went wrong. You can +/// continue building with the [`SplitBuilder`] even if an error occurred. +/// +/// You can use `ignore_result` from the [`IgnoreSplitChainResult`] trait to +/// just keep chaining without checking whether the connection suceeded. +pub type SplitChainResult<'w, 's, 'a, 'b, T> = Result< + SplitBuilder<'w, 's, 'a, 'b, T>, + (SplitBuilder<'w, 's, 'a, 'b, T>, SplitConnectionError), +>; + +/// A helper trait that allows users to ignore any failures while chaining +/// connections to a split. 
+pub trait IgnoreSplitChainResult<'w, 's, 'a, 'b, T: Splittable> { + /// Ignore whether the result of the chaining connection was [`Ok`] or [`Err`] + /// and just keep chaining. + fn ignore_result(self) -> SplitBuilder<'w, 's, 'a, 'b, T>; +} + +impl<'w, 's, 'a, 'b, T: Splittable> IgnoreSplitChainResult<'w, 's, 'a, 'b, T> + for SplitChainResult<'w, 's, 'a, 'b, T> +{ + fn ignore_result(self) -> SplitBuilder<'w, 's, 'a, 'b, T> { + match self { + Ok(split) => split, + Err((split, _)) => split, + } + } +} + +/// Information about why a connection to a split failed +#[derive(ThisError, Debug, Clone)] +#[error("An error occurred while trying to connect to a split")] +pub enum SplitConnectionError { + /// The requested index was already connected to. + KeyAlreadyUsed, + /// The requested index is out of bounds. + KeyOutOfBounds, +} + +/// Used by implementers of the [`Splittable`] trait to help them send their +/// split values to the proper input slots. +pub struct SplitDispatcher<'a, Key, Item> { + pub(crate) connections: &'a HashMap, + pub(crate) outputs: &'a mut Vec>, +} + +impl<'a, Key, Item> SplitDispatcher<'a, Key, Item> +where + Key: 'static + Send + Sync + Eq + Hash + Clone + Debug, + Item: 'static + Send + Sync, +{ + /// Get the output buffer a certain key. If there are no connections for the + /// given key, then this will return [`None`]. + /// + /// Push items into the output buffer to send them to the input connected to + /// this key. + pub fn outputs_for<'o>(&'o mut self, key: &Key) -> Option<&'o mut Vec> { + let index = *self.connections.get(key)?; + + if self.outputs.len() <= index { + // We do this just in case something bad happened with the cache + // that reset its size. + self.outputs.resize_with(index + 1, Vec::new); + } + + self.outputs.get_mut(index) + } +} + +/// Turn a sequence index into a split key. Implemented by [`ListSplitKey`] and +/// [`MapSplitKey`]. +pub trait FromSequential { + /// Convert the sequence index into the split key type. 
+ fn from_sequential(seq: usize) -> Self; +} + +/// Get the key that represents all remaining/unspecified keys. Implemented by +/// [`ListSplitKey`] and [`MapSplitKey`]. +pub trait ForRemaining { + /// Get the key for remaining items. + fn for_remaining() -> Self; +} + +/// Turn a specific key into a split key. Implemented by [`MapSplitKey`]. +pub trait FromSpecific { + /// The specific key type + type SpecificKey; + + /// Convert the specific key into the split key type + fn from_specific(specific: Self::SpecificKey) -> Self; +} + +/// This enum allows users to key into splittable list-like structures based on +/// the sequence in which an item appears in the list. It also has an option for +/// keying into any items that were left over in the sequence. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ListSplitKey { + /// Key into an item at a specific point in the sequence. + Sequential(usize), + /// Key into any remaining items that were not covered by the sequential keys. + Remaining, +} + +impl FromSequential for ListSplitKey { + fn from_sequential(seq: usize) -> Self { + ListSplitKey::Sequential(seq) + } +} + +impl ForRemaining for ListSplitKey { + fn for_remaining() -> Self { + ListSplitKey::Remaining + } +} + +/// This is a newtype that implements [`Splittable`] for anything that can be +/// turned into an iterator, but always splits it as if the iterator is list-like. +/// +/// If used with a map-type (e.g. [`HashMap`] or [BTreeMap]), the items will be +/// the `(key, value)` pairs of the maps, and the key for the split will be the +/// order in which the map is iterated through. For [`BTreeMap`] this will always +/// be a sorted order (e.g. alphabetical or numerical) but for [`HashMap`] this +/// order can be completely arbitrary and may be unstable. 
+pub struct SplitAsList { + pub contents: T, +} + +impl SplitAsList { + pub fn new(contents: T) -> Self { + Self { contents } + } +} + +impl Splittable for SplitAsList +where + T: 'static + Send + Sync + IntoIterator, + T::Item: 'static + Send + Sync, +{ + type Key = ListSplitKey; + type Item = (usize, T::Item); + + fn validate(_: &Self::Key) -> bool { + // We don't know if there are any restrictions for the iterable, so just + // return say all keys are valid + true + } + + fn next(key: &Option) -> Option { + if let Some(key) = key { + match key { + ListSplitKey::Sequential(k) => Some(ListSplitKey::Sequential(*k + 1)), + ListSplitKey::Remaining => None, + } + } else { + Some(ListSplitKey::Sequential(0)) + } + } + + fn split(self, mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + for (index, value) in self.contents.into_iter().enumerate() { + match dispatcher.outputs_for(&ListSplitKey::Sequential(index)) { + Some(outputs) => { + outputs.push((index, value)); + } + None => { + if let Some(outputs) = dispatcher.outputs_for(&ListSplitKey::Remaining) { + outputs.push((index, value)); + } + } + } + } + + Ok(()) + } +} + +impl Splittable for Vec { + type Key = ListSplitKey; + type Item = (usize, T); + + fn validate(_: &Self::Key) -> bool { + // Vec has no restrictions on what index is valid + true + } + + fn next(key: &Option) -> Option { + SplitAsList::::next(key) + } + + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + SplitAsList::new(self).split(dispatcher) + } +} + +impl Splittable for smallvec::SmallVec<[T; N]> { + type Key = ListSplitKey; + type Item = (usize, T); + + fn validate(_: &Self::Key) -> bool { + // SmallVec has no restrictions on what index is valid + true + } + + fn next(key: &Option) -> Option { + SplitAsList::::next(key) + } + + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + SplitAsList::new(self).split(dispatcher) + } +} + 
+impl Splittable for [T; N] { + type Key = ListSplitKey; + type Item = (usize, T); + + fn validate(key: &Self::Key) -> bool { + // Static arrays have a firm limit of N + match key { + ListSplitKey::Sequential(s) => *s < N, + ListSplitKey::Remaining => true, + } + } + + fn next(key: &Option) -> Option { + // Static arrays have a firm limit of N + SplitAsList::::next(key).take_if(|key| Self::validate(key)) + } + + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + SplitAsList::new(self).split(dispatcher) + } +} + +/// This enum allows users to key into splittable map-like structures based on +/// the presence of a specific value or based on the sequence in which a value +/// is reached that wasn't associated with a specific key. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MapSplitKey { + /// Key into the item associated with this specific key. + Specific(K), + /// Key into an anonymous sequential item. Anonymous items are items for + /// which no specific connection was made to their key. + Sequential(usize), + /// Key into any items that were not covered by one of the other types. 
+ Remaining, +} + +impl MapSplitKey { + pub fn specific(self) -> Option { + match self { + MapSplitKey::Specific(key) => Some(key), + _ => None, + } + } +} + +impl From for MapSplitKey { + fn from(value: K) -> Self { + MapSplitKey::Specific(value) + } +} + +impl FromSpecific for MapSplitKey { + type SpecificKey = K; + fn from_specific(specific: Self::SpecificKey) -> Self { + Self::Specific(specific) + } +} + +impl FromSequential for MapSplitKey { + fn from_sequential(seq: usize) -> Self { + Self::Sequential(seq) + } +} + +impl ForRemaining for MapSplitKey { + fn for_remaining() -> Self { + Self::Remaining + } +} + +/// This is a newtype that implements [`Splittable`] for anything that can be +/// turned into an iterator whose items take the form of a `(key, value)` pair +/// where `key` meets all the bounds needed for a [`Splittable`] key. +/// +/// This is used to implement [`Splittable`] for map-like structures. +pub struct SplitAsMap +where + K: 'static + Send + Sync + Eq + Hash + Clone + Debug, + V: 'static + Send + Sync, + M: 'static + Send + Sync + IntoIterator, +{ + pub contents: M, + _ignore: std::marker::PhantomData<(K, V)>, +} + +impl SplitAsMap +where + K: 'static + Send + Sync + Eq + Hash + Clone + Debug, + V: 'static + Send + Sync, + M: 'static + Send + Sync + IntoIterator, +{ + pub fn new(contents: M) -> Self { + Self { + contents, + _ignore: Default::default(), + } + } +} + +impl Splittable for SplitAsMap +where + K: 'static + Send + Sync + Eq + Hash + Clone + Debug, + V: 'static + Send + Sync, + M: 'static + Send + Sync + IntoIterator, +{ + type Key = MapSplitKey; + type Item = (K, V); + + fn validate(_: &Self::Key) -> bool { + // We have no way of knowing what the key bounds are for an arbitrary map + true + } + + fn next(key: &Option) -> Option { + match key { + Some(key) => { + match key { + // Give the next key in the sequence + MapSplitKey::Sequential(index) => Some(MapSplitKey::Sequential(index + 1)), + // For an arbitrary map we don't know 
what would follow a specific key, so + // just stop iterating. This should never be reached in practice anyway. + MapSplitKey::Specific(_) => None, + MapSplitKey::Remaining => None, + } + } + None => Some(MapSplitKey::Sequential(0)), + } + } + + fn split(self, mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + let mut next_seq = 0; + for (specific_key, value) in self.contents.into_iter() { + let key = MapSplitKey::Specific(specific_key); + match dispatcher.outputs_for(&key) { + Some(outputs) => { + outputs.push((key.specific().unwrap(), value)); + } + None => { + // No connection to the specific key, so let's check for a + // sequential connection. + let seq = MapSplitKey::Sequential(next_seq); + next_seq += 1; + match dispatcher.outputs_for(&seq) { + Some(outputs) => { + outputs.push((key.specific().unwrap(), value)); + } + None => { + // No connection to this point in the sequence, so + // let's send it to any remaining connection. + let remaining = MapSplitKey::Remaining; + if let Some(outputs) = dispatcher.outputs_for(&remaining) { + outputs.push((key.specific().unwrap(), value)); + } + } + } + } + } + } + + Ok(()) + } +} + +impl Splittable for HashMap +where + K: 'static + Send + Sync + Eq + Hash + Clone + Debug, + V: 'static + Send + Sync, +{ + type Key = MapSplitKey; + type Item = (K, V); + + fn validate(_: &Self::Key) -> bool { + true + } + + fn next(key: &Option) -> Option { + SplitAsMap::::next(key) + } + + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + SplitAsMap::new(self).split(dispatcher) + } +} + +impl Splittable for BTreeMap +where + K: 'static + Send + Sync + Eq + Hash + Clone + Debug, + V: 'static + Send + Sync, +{ + type Key = MapSplitKey; + type Item = (K, V); + + fn validate(_: &Self::Key) -> bool { + true + } + + fn next(key: &Option) -> Option { + SplitAsMap::::next(key) + } + + fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> 
OperationResult { + SplitAsMap::new(self).split(dispatcher) + } +} + +#[cfg(test)] +mod tests { + use crate::{testing::*, *}; + use std::collections::{BTreeMap, HashMap}; + + #[test] + fn test_split_array() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + scope + .input + .chain(builder) + .split(|split| { + let mut outputs = Vec::new(); + split + .sequential_branch(0, |chain| { + outputs.push(chain.value().map_block(|v| v + 0.0).output()); + }) + .unwrap() + .sequential_branch(2, |chain| { + outputs + .push(chain.value().map_async(|v| async move { v + 2.0 }).output()); + }) + .unwrap() + .sequential_branch(4, |chain| { + outputs.push(chain.value().map_block(|v| v + 4.0).output()); + }) + .unwrap() + .unused(); + + outputs + }) + .join_vec::<5>(builder) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request([5.0, 4.0, 3.0, 2.0, 1.0], workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(30)); + assert!(context.no_unhandled_errors()); + let value = promise.take().available().unwrap(); + assert_eq!(value, [5.0, 5.0, 5.0].into()); + + let workflow = context.spawn_io_workflow(|scope: Scope<[f64; 3], f64>, builder| { + scope.input.chain(builder).split(|split| { + let split = split + .sequential_branch(0, |chain| { + chain + // Do some nonsense with the first element in the split + .fork_clone(( + |chain: Chain<_>| chain.unused(), + |chain: Chain<_>| chain.unused(), + |chain: Chain<_>| chain.unused(), + )); + }) + .unwrap(); + + // This is outside the valid range for the array, so it should + // fail + let err = split.sequential_branch(3, |chain| { + chain.value().connect(scope.terminate); + }); + assert!(matches!( + &err, + Err((_, SplitConnectionError::KeyOutOfBounds)) + )); + + let split = err + .ignore_result() + .sequential_branch(1, |chain| { + chain.unused(); + }) + .unwrap(); + + // We 
already connected to this key, so it should fail + let err = split.sequential_branch(0, |chain| { + chain.value().connect(scope.terminate); + }); + assert!(matches!( + &err, + Err((_, SplitConnectionError::KeyAlreadyUsed)) + )); + + // Connect the last element in the split to the termination node + err.ignore_result() + .sequential_branch(2, |chain| { + chain.value().connect(scope.terminate); + }) + .unwrap() + .unused(); + }); + }); + + let mut promise = + context.command(|commands| commands.request([1.0, 2.0, 3.0], workflow).take_response()); + + context.run_with_conditions(&mut promise, 1); + assert!(context.no_unhandled_errors()); + // Only the third element in the split gets connected to the workflow + // termination, the rest are discarded. This ensures that SplitBuilder + // connections still work after multiple failed connection attempts. + assert_eq!(promise.take().available().unwrap(), 3.0); + } + + #[test] + fn test_split_map() { + let mut context = TestingContext::minimal_plugins(); + + let km_to_miles = 0.621371; + let per_second_to_per_hour = 3600.0; + let convert_speed = move |v: f64| v * km_to_miles * per_second_to_per_hour; + let convert_distance = move |d: f64| d * km_to_miles; + + let workflow = + context.spawn_io_workflow(|scope: Scope, _>, builder| { + let collector = builder.create_collect_all::<_, 16>(); + + scope.input.chain(builder).split(|split| { + split + .specific_branch("speed".to_owned(), |chain| { + chain + .map_block(move |(k, v)| (k, convert_speed(v))) + .connect(collector.input); + }) + .ignore_result() + .specific_branch("velocity".to_owned(), |chain| { + chain + .map_async(move |(k, v)| async move { (k, convert_speed(v)) }) + .connect(collector.input); + }) + .unwrap() + .specific_branch("distance".to_owned(), |chain| { + chain + .map_block(move |(k, v)| (k, convert_distance(v))) + .connect(collector.input); + }) + .unwrap() + .sequential_branch(0, |chain| { + chain + .map_block(move |(k, v)| (k, 0.0 * v)) + 
.connect(collector.input); + }) + .unwrap() + .sequential_branch(1, |chain| { + chain + .map_async(move |(k, v)| async move { (k, 1.0 * v) }) + .connect(collector.input); + }) + .unwrap() + .sequential_branch(2, |chain| { + chain + .map_block(move |(k, v)| (k, 2.0 * v)) + .connect(collector.input); + }) + .unwrap() + .remaining_branch(|chain| { + chain.connect(collector.input); + }) + .unwrap() + .unused(); + }); + + collector + .output + .chain(builder) + .map_block(|v| HashMap::::from_iter(v)) + .connect(scope.terminate); + }); + + // We input a BTreeMap so we can ensure the first three sequence items + // are always the same: a, b, and c. Make sure that no other keys in the + // map come before c alphabetically. + let input_map: BTreeMap = [ + ("a", 3.14159), + ("b", 2.71828), + ("c", 4.0), + ("speed", 16.1), + ("velocity", -32.4), + ("distance", 4325.78), + ("foo", 42.0), + ("fib", 78.3), + ("dib", -22.1), + ] + .into_iter() + .map(|(k, v)| (k.to_owned(), v)) + .collect(); + + let mut promise = context.command(|commands| { + commands + .request(input_map.clone(), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(30)); + assert!(context.no_unhandled_errors()); + + let result = promise.take().available().unwrap(); + assert_eq!(result.len(), input_map.len()); + assert_eq!(result["a"], input_map["a"] * 0.0); + assert_eq!(result["b"], input_map["b"] * 1.0); + assert_eq!(result["c"], input_map["c"] * 2.0); + assert_eq!(result["speed"], convert_speed(input_map["speed"])); + assert_eq!(result["velocity"], convert_speed(input_map["velocity"])); + assert_eq!(result["distance"], convert_distance(input_map["distance"])); + assert_eq!(result["foo"], input_map["foo"]); + assert_eq!(result["fib"], input_map["fib"]); + assert_eq!(result["dib"], input_map["dib"]); + } +} diff --git a/src/disposal.rs b/src/disposal.rs index be8204d4..1bc453a5 100644 --- a/src/disposal.rs +++ b/src/disposal.rs @@ -122,6 +122,17 @@ impl Disposal { 
} .into() } + + pub fn incomplete_split( + split_node: Entity, + missing_keys: SmallVec<[Option>; 16]>, + ) -> Self { + IncompleteSplit { + split_node, + missing_keys, + } + .into() + } } #[derive(Debug)] @@ -176,6 +187,10 @@ pub enum DisposalCause { /// been sent out to indicate that the workflow is blocked up on the /// collection. DeficientCollection(DeficientCollection), + + /// A split operation took place, but not all connections to the split + /// received a value. + IncompleteSplit(IncompleteSplit), } /// A variant of [`DisposalCause`] @@ -387,6 +402,21 @@ impl From for DisposalCause { } } +/// A variant of [`DisposalCause`] +#[derive(Debug)] +pub struct IncompleteSplit { + /// The node that does the splitting + pub split_node: Entity, + /// The debug text of each key that was missing in the split + pub missing_keys: SmallVec<[Option>; 16]>, +} + +impl From for DisposalCause { + fn from(value: IncompleteSplit) -> Self { + Self::IncompleteSplit(value) + } +} + pub trait ManageDisposal { fn emit_disposal(&mut self, session: Entity, disposal: Disposal, roster: &mut OperationRoster); diff --git a/src/node.rs b/src/node.rs index 69a3665b..b3c054f5 100644 --- a/src/node.rs +++ b/src/node.rs @@ -87,6 +87,9 @@ impl InputSlot { /// `Response` parameter can be cloned then you can call [`Self::fork_clone`] to /// transform this into a [`ForkCloneOutput`] and then connect the output into /// any number of input slots. +/// +/// `Output` intentionally does not implement copy or clone because it must only +/// be consumed exactly once. 
#[must_use] pub struct Output { scope: Entity, diff --git a/src/operation.rs b/src/operation.rs index 8ffaa7e8..1d6fc47e 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -80,6 +80,9 @@ pub(crate) use operate_map::*; mod operate_service; pub(crate) use operate_service::*; +mod operate_split; +pub(crate) use operate_split::*; + mod operate_task; pub(crate) use operate_task::*; diff --git a/src/operation/operate_split.rs b/src/operation/operate_split.rs new file mode 100644 index 00000000..da2dead8 --- /dev/null +++ b/src/operation/operate_split.rs @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2024 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use bevy_ecs::{ + prelude::{Component, Entity, World}, + system::Command, +}; +use smallvec::SmallVec; +use std::{collections::HashMap, sync::Arc}; + +use crate::{ + Broken, Disposal, ForkTargetStorage, Input, InputBundle, ManageDisposal, ManageInput, + MiscellaneousFailure, Operation, OperationCleanup, OperationError, OperationReachability, + OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, + SingleInputStorage, SplitDispatcher, Splittable, UnhandledErrors, +}; + +#[derive(Component)] +pub(crate) struct OperateSplit { + /// The connections that lead out of this split operation. These only change + /// while the workflow is being built, afterwards they should be frozen. 
+ connections: HashMap, + /// A reverse map that keeps track of what key is at each index + index_to_key: Vec>, + /// A cache used to transfer the split values from the input to the outputs. + /// Every iteration this must be reset to all None values. If any one of them + /// is a None after the Splittable has filled it in, we must issue a disposal + /// notice because one of the outputs might not be receiving anything. + outputs_cache: Option>>, +} + +impl Default for OperateSplit { + fn default() -> Self { + Self { + connections: Default::default(), + index_to_key: Vec::new(), + outputs_cache: Some(Vec::new()), + } + } +} + +impl Operation for OperateSplit { + fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { + world.entity_mut(source).insert(( + self, + InputBundle::::new(), + ForkTargetStorage::default(), + )); + Ok(()) + } + + fn execute( + OperationRequest { + source, + world, + roster, + }: OperationRequest, + ) -> OperationResult { + let mut source_mut = world.get_entity_mut(source).or_broken()?; + let Input { session, data } = source_mut.take_input::()?; + let targets = source_mut.get::().or_broken()?.0.clone(); + + let mut split = source_mut.get_mut::>().or_broken()?; + let mut outputs = split.outputs_cache.take().unwrap_or(Vec::new()); + let dispatcher = SplitDispatcher { + connections: &split.connections, + outputs: &mut outputs, + }; + data.split(dispatcher)?; + + let mut missed_indices: SmallVec<[usize; 16]> = SmallVec::new(); + for (index, (items, target)) in outputs.iter_mut().zip(targets).enumerate() { + if items.is_empty() { + missed_indices.push(index); + } + + for output in items.drain(..) { + world + .get_entity_mut(target) + .or_broken()? 
+ .give_input(session, output, roster)?; + } + } + + let mut source_mut = world.get_entity_mut(source).or_broken()?; + + if !missed_indices.is_empty() { + let split = source_mut.get::>().or_broken()?; + let missing_keys: SmallVec<[Option>; 16]> = missed_indices + .into_iter() + .map(|index| split.index_to_key.get(index).cloned()) + .collect(); + + source_mut.emit_disposal( + session, + Disposal::incomplete_split(source, missing_keys), + roster, + ); + } + + // Return the cache into the component + source_mut + .get_mut::>() + .or_broken()? + .outputs_cache + .replace(outputs); + + Ok(()) + } + + fn cleanup(mut clean: OperationCleanup) -> OperationResult { + clean.cleanup_inputs::()?; + clean.notify_cleaned() + } + + fn is_reachable(mut reachability: OperationReachability) -> ReachabilityResult { + if reachability.has_input::()? { + return Ok(true); + } + + SingleInputStorage::is_reachable(&mut reachability) + } +} + +pub(crate) struct ConnectToSplit { + pub(crate) source: Entity, + pub(crate) target: Entity, + pub(crate) key: T::Key, +} + +impl Command for ConnectToSplit { + fn apply(self, world: &mut World) { + let node = self.source; + if let Err(OperationError::Broken(backtrace)) = self.connect(world) { + world + .get_resource_or_insert_with(UnhandledErrors::default) + .broken + .push(Broken { node, backtrace }); + } + } +} + +impl ConnectToSplit { + fn connect(self, world: &mut World) -> Result<(), OperationError> { + let mut target_storage = world + .get_mut::(self.source) + .or_broken()?; + let index = target_storage.0.len(); + target_storage.0.push(self.target); + + world + .get_entity_mut(self.target) + .or_broken()? + .insert(SingleInputStorage::new(self.source)); + + let mut split = world.get_mut::>(self.source).or_broken()?; + let previous_index = split.connections.insert(self.key.clone(), index); + split + .outputs_cache + .as_mut() + .or_broken()? 
+ .resize_with(index + 1, Vec::new); + if split.index_to_key.len() != index { + // If the next element of the reverse map does not match the new index + // then something has fallen out of sync. This doesn't really break + // the workflow because this reverse map is only used to generate + // disposal messages, but it does indicate a bug is present. + let reverse_map_size = split.index_to_key.len(); + world + .get_resource_or_insert_with(UnhandledErrors::default) + .miscellaneous + .push(MiscellaneousFailure { + error: Arc::new(anyhow::anyhow!( + "Mismatch between reverse map size [{}] and new connection index [{}]", + reverse_map_size, + index, + )), + backtrace: Some(backtrace::Backtrace::new()), + }); + } else { + split + .index_to_key + .push(format!("{:?}", self.key).as_str().into()); + } + + if let Some(previous_index) = previous_index { + // If something was already using this key then there is a flaw in + // the implementation of SplitBuilder and we should log it. + let target_storage = world.get::(self.source).or_broken()?; + let previous_target = *target_storage.0.get(previous_index).or_broken()?; + + world + .get_resource_or_insert_with(UnhandledErrors::default) + .miscellaneous + .push(MiscellaneousFailure { + error: Arc::new(anyhow::anyhow!( + "Double-connected key [{:?}] for split node {:?}. Original target: {:?}, new target: {:?}", + self.key, + self.source, + previous_target, + self.target, + )), + backtrace: Some(backtrace::Backtrace::new()), + }); + } + + Ok(()) + } +} From d074da8ddd2964da0335b27cb4575ab03ffb43dc Mon Sep 17 00:00:00 2001 From: Luca Della Vedova Date: Wed, 27 Nov 2024 19:40:04 +0800 Subject: [PATCH 04/20] Fix 1.75 build failure / add CI (#34) Signed-off-by: Luca Della Vedova Signed-off-by: Michael X. Grey Co-authored-by: Michael X. 
Grey --- .github/workflows/ci_linux.yaml | 8 +++++ src/chain/split.rs | 52 ++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index e5c5d83b..3f10af5c 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -31,3 +31,11 @@ jobs: run: cargo build --features single_threaded_async - name: Test single_threaded_async run: cargo test --features single_threaded_async + + # Build and test with 1.75 + - name: Install Rust 1.75 + run: rustup default 1.75 + - name: Build default features with Rust 1.75 + run: cargo build + - name: Test default features with Rust 1.75 + run: cargo test diff --git a/src/chain/split.rs b/src/chain/split.rs index c22046d7..e5fab968 100644 --- a/src/chain/split.rs +++ b/src/chain/split.rs @@ -532,7 +532,12 @@ impl Splittable for [T; N] { fn next(key: &Option) -> Option { // Static arrays have a firm limit of N - SplitAsList::::next(key).take_if(|key| Self::validate(key)) + let mut key = SplitAsList::::next(key); + if key.map_or(false, |key| Self::validate(&key)) { + key.take() + } else { + None + } } fn split(self, dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { @@ -936,4 +941,49 @@ mod tests { assert_eq!(result["fib"], input_map["fib"]); assert_eq!(result["dib"], input_map["dib"]); } + + #[test] + fn test_array_split_limit() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + scope.input.chain(builder).split(|split| { + let err = split + .next_branch(|_, chain| { + chain.value().connect(scope.terminate); + }) + .unwrap() + .next_branch(|_, chain| { + chain.value().connect(scope.terminate); + }) + .unwrap() + .next_branch(|_, chain| { + chain.value().connect(scope.terminate); + }) + .unwrap() + .next_branch(|_, chain| { + chain.value().connect(scope.terminate); + }) + .unwrap() + // This last one should fail 
because it should exceed the + // array limit + .next_branch(|_, chain| { + chain.value().connect(scope.terminate); + }); + + assert!(matches!(err, Err(_))); + }) + }); + + let mut promise = + context.command(|commands| commands.request([1, 2, 3, 4], workflow).take_response()); + + context.run_with_conditions(&mut promise, 1); + assert!(context.no_unhandled_errors()); + + let result = promise.take().available().unwrap(); + // All the values in the array are racing to finish, but the first value + // should finish first since it will naturally get queued first. + assert_eq!(result, 1); + } } From ca95c2da5ec8fb45338bf16e5e7dbfff9b37ee74 Mon Sep 17 00:00:00 2001 From: Grey Date: Wed, 4 Dec 2024 13:09:57 +0800 Subject: [PATCH 05/20] Expand lifetimes of SplitBuilder (#36) Signed-off-by: Michael X. Grey --- src/chain.rs | 12 +++++++++--- src/chain/split.rs | 30 +++++++++++++++++------------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/chain.rs b/src/chain.rs index e81fcd7e..7fa5ffa5 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -608,7 +608,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// will insert a split operation and provide your `build` function with the /// [`SplitBuilder`] for it. This returns the return value of your build /// function. 
- pub fn split(self, build: impl FnOnce(SplitBuilder) -> U) -> U + pub fn split(self, build: impl FnOnce(SplitBuilder<'w, 's, 'a, 'b, T>) -> U) -> U where T: Splittable, { @@ -644,7 +644,10 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// ```text /// .map_block(SplitAsList::new).split(build) /// ``` - pub fn split_as_list(self, build: impl FnOnce(SplitBuilder>) -> U) -> U + pub fn split_as_list( + self, + build: impl FnOnce(SplitBuilder<'w, 's, 'a, 'b, SplitAsList>) -> U, + ) -> U where T: IntoIterator, T::Item: 'static + Send + Sync, @@ -1026,7 +1029,10 @@ where /// ```text /// .map_block(SplitAsMap::new).split(build) /// ``` - pub fn split_as_map(self, build: impl FnOnce(SplitBuilder>) -> U) -> U { + pub fn split_as_map( + self, + build: impl FnOnce(SplitBuilder<'w, 's, 'a, 'b, SplitAsMap>) -> U, + ) -> U { self.map_block(SplitAsMap::new).split(build) } diff --git a/src/chain/split.rs b/src/chain/split.rs index e5fab968..3593f093 100644 --- a/src/chain/split.rs +++ b/src/chain/split.rs @@ -566,6 +566,22 @@ impl MapSplitKey { _ => None, } } + + pub fn next(this: &Option) -> Option { + match this { + Some(key) => { + match key { + // Give the next key in the sequence + MapSplitKey::Sequential(index) => Some(MapSplitKey::Sequential(index + 1)), + // For an arbitrary map we don't know what would follow a specific key, so + // just stop iterating. This should never be reached in practice anyway. + MapSplitKey::Specific(_) => None, + MapSplitKey::Remaining => None, + } + } + None => Some(MapSplitKey::Sequential(0)), + } + } } impl From for MapSplitKey { @@ -637,19 +653,7 @@ where } fn next(key: &Option) -> Option { - match key { - Some(key) => { - match key { - // Give the next key in the sequence - MapSplitKey::Sequential(index) => Some(MapSplitKey::Sequential(index + 1)), - // For an arbitrary map we don't know what would follow a specific key, so - // just stop iterating. This should never be reached in practice anyway. 
- MapSplitKey::Specific(_) => None, - MapSplitKey::Remaining => None, - } - } - None => Some(MapSplitKey::Sequential(0)), - } + MapSplitKey::next(key) } fn split(self, mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { From 0def9dbdb2a29540437a113e43c468f8086e6a11 Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Thu, 5 Dec 2024 17:51:11 +0800 Subject: [PATCH 06/20] use matrix testing (#39) Signed-off-by: Teo Koon Peng --- .github/workflows/ci_linux.yaml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index 3f10af5c..05813b39 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -16,12 +16,18 @@ env: jobs: build: + strategy: + matrix: + rust-version: [stable, 1.75] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Setup rust + run: rustup default ${{ matrix.rust-version }} + - name: Build default features run: cargo build - name: Test default features @@ -32,10 +38,3 @@ jobs: - name: Test single_threaded_async run: cargo test --features single_threaded_async - # Build and test with 1.75 - - name: Install Rust 1.75 - run: rustup default 1.75 - - name: Build default features with Rust 1.75 - run: cargo build - - name: Test default features with Rust 1.75 - run: cargo test From 3d5425de69f6e9ded7d80df317a0df91df9ebf91 Mon Sep 17 00:00:00 2001 From: Grey Date: Fri, 6 Dec 2024 16:58:11 +0800 Subject: [PATCH 07/20] Fix detachment (#40) Signed-off-by: Michael X. 
Grey --- src/builder.rs | 6 +- src/flush.rs | 6 ++ src/impulse.rs | 49 +++++++++++++++- src/input.rs | 54 ++++++++++++++++-- src/operation.rs | 19 +++++++ src/operation/injection.rs | 10 +++- src/operation/scope.rs | 112 +++++++++++++++++++------------------ src/service.rs | 23 +++++++- 8 files changed, 215 insertions(+), 64 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 882b2a8a..0661023a 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -959,11 +959,15 @@ mod tests { let mut promise = context.command(|commands| commands.request(5, workflow).take_response()); context.run_with_conditions(&mut promise, Duration::from_secs(2)); + assert!( + context.no_unhandled_errors(), + "{:#?}", + context.get_unhandled_errors(), + ); assert!(promise.peek().is_cancelled()); let channel_output = receiver.try_recv().unwrap(); assert_eq!(channel_output, 5); assert!(receiver.try_recv().is_err()); - assert!(context.no_unhandled_errors()); assert!(context.confirm_buffers_empty().is_ok()); let (cancel_sender, mut cancel_receiver) = unbounded_channel(); diff --git a/src/flush.rs b/src/flush.rs index 6e18e51d..622c3113 100644 --- a/src/flush.rs +++ b/src/flush.rs @@ -87,6 +87,12 @@ fn flush_impulses_impl( let mut loop_count = 0; while !roster.is_empty() { + for e in roster.deferred_despawn.drain(..) { + if let Some(e_mut) = world.get_entity_mut(e) { + e_mut.despawn_recursive(); + } + } + let parameters = world.get_resource_or_insert_with(FlushParameters::default); let flush_loop_limit = parameters.flush_loop_limit; let single_threaded_poll_limit = parameters.single_threaded_poll_limit; diff --git a/src/impulse.rs b/src/impulse.rs index 875740e3..77765ecd 100644 --- a/src/impulse.rs +++ b/src/impulse.rs @@ -131,7 +131,8 @@ where // this one is finished. 
self.commands .entity(source) - .insert(Cancellable::new(cancel_impulse)) + .insert((Cancellable::new(cancel_impulse), ImpulseMarker)) + .remove::() .set_parent(target); provider.connect(None, source, target, self.commands); Impulse { @@ -484,6 +485,38 @@ mod tests { assert!(context.no_unhandled_errors()); } + #[test] + fn test_detach() { + // This is a regression test that covers a bug which existed due to + // an incorrect handling of detached impulses when giving input. + let mut context = TestingContext::minimal_plugins(); + let service = context.spawn_delayed_map(Duration::from_millis(1), |n| n + 1); + + context.command(|commands| { + commands.provide(0).then(service).detach(); + }); + + let (sender, mut promise) = Promise::<()>::new(); + context.run_with_conditions(&mut promise, Duration::from_millis(5)); + assert!( + context.no_unhandled_errors(), + "Unhandled errors: {:#?}", + context.get_unhandled_errors(), + ); + + // The promise and sender only exist because run_with_conditions requires + // them. Moreover we need to make sure that sender does not get dropped + // prematurely by the compiler, otherwise the promise will have the run + // exit prematurely. Therefore we call .send(()) here to guarantee the + // compiler knows to keep it alive until the running is finished. + // + // We have observed that using `let (_, mut promise) = ` will cause the + // sender to drop prematurely, so we don't want to risk that there are + // other cases where that may happen. It is important for the run to + // last multiple cycles. 
+ sender.send(()).ok(); + } + #[derive(Clone, Debug, PartialEq, Eq, Hash)] struct UnitLabel; @@ -547,6 +580,20 @@ mod tests { ); verify_delivery_instruction_matrix(service, &mut context); + + let async_service = service; + let service = context.spawn_io_workflow(|scope, builder| { + scope + .input + .chain(builder) + .then(async_service) + .connect(scope.terminate); + }); + + verify_delivery_instruction_matrix(service, &mut context); + + // We don't test blocking services because blocking services are always + // serial no matter what, so delivery instructions have no effect for them. } fn verify_delivery_instruction_matrix( diff --git a/src/input.rs b/src/input.rs index f724c965..b5b1ba9a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -26,8 +26,9 @@ use smallvec::SmallVec; use backtrace::Backtrace; use crate::{ - Broken, BufferStorage, Cancel, Cancellation, CancellationCause, DeferredRoster, OperationError, - OperationRoster, OrBroken, SessionStatus, UnusedTarget, + Broken, BufferStorage, Cancel, Cancellation, CancellationCause, DeferredRoster, Detached, + MiscellaneousFailure, OperationError, OperationRoster, OrBroken, SessionStatus, + UnhandledErrors, UnusedTarget, }; /// This contains data that has been provided as input into an operation, along @@ -69,15 +70,31 @@ impl Default for InputStorage { } } +/// Used to keep track of the expected input type for an operation +#[derive(Component)] +pub(crate) struct InputTypeIndicator { + pub(crate) name: &'static str, +} + +impl InputTypeIndicator { + fn new() -> Self { + Self { + name: std::any::type_name::(), + } + } +} + #[derive(Bundle)] pub struct InputBundle { storage: InputStorage, + indicator: InputTypeIndicator, } impl InputBundle { pub fn new() -> Self { Self { storage: Default::default(), + indicator: InputTypeIndicator::new::(), } } } @@ -125,6 +142,7 @@ pub trait ManageInput { session: Entity, data: T, only_if_active: bool, + roster: &mut OperationRoster, ) -> Result; /// Get an input that is ready to be 
taken, or else produce an error. @@ -150,7 +168,7 @@ impl<'w> ManageInput for EntityWorldMut<'w> { data: T, roster: &mut OperationRoster, ) -> Result<(), OperationError> { - if unsafe { self.sneak_input(session, data, true)? } { + if unsafe { self.sneak_input(session, data, true, roster)? } { roster.queue(self.id()); } Ok(()) @@ -162,7 +180,7 @@ impl<'w> ManageInput for EntityWorldMut<'w> { data: T, roster: &mut OperationRoster, ) -> Result<(), OperationError> { - if unsafe { self.sneak_input(session, data, true)? } { + if unsafe { self.sneak_input(session, data, true, roster)? } { roster.defer(self.id()); } Ok(()) @@ -173,6 +191,7 @@ impl<'w> ManageInput for EntityWorldMut<'w> { session: Entity, data: T, only_if_active: bool, + roster: &mut OperationRoster, ) -> Result { if only_if_active { let active_session = @@ -193,6 +212,21 @@ impl<'w> ManageInput for EntityWorldMut<'w> { if let Some(mut storage) = self.get_mut::>() { storage.reverse_queue.insert(0, Input { session, data }); } else if !self.contains::() { + let id = self.id(); + if let Some(detached) = self.get::() { + if detached.is_detached() { + // The input is going to a detached impulse that will not + // react any further. We need to tell that detached impulse + // to despawn since it is no longer needed. + roster.defer_despawn(id); + + // No error occurred, but the caller should not queue the + // operation into the roster because it is being despawned. + return Ok(false); + } + } + + let expected = self.get::().map(|i| i.name); // If the input is being fed to an unused target then we can // generally ignore it, although it may indicate a bug in the user's // workflow because workflow branches that end in an unused target @@ -200,6 +234,18 @@ impl<'w> ManageInput for EntityWorldMut<'w> { // However in this case, the target is not unused but also does not // have the correct input storage type. 
This indicates + self.world_mut() + .get_resource_or_insert_with(|| UnhandledErrors::default()) + .miscellaneous + .push(MiscellaneousFailure { + error: std::sync::Arc::new(anyhow::anyhow!( + "Incorrect input type for operation [{:?}]: received [{}], expected [{}]", + id, + std::any::type_name::(), + expected.unwrap_or(""), + )), + backtrace: None, + }); None.or_broken()?; } Ok(true) diff --git a/src/operation.rs b/src/operation.rs index 1d6fc47e..373d30ec 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -223,6 +223,9 @@ pub struct OperationRoster { pub(crate) disposed: Vec, /// Tell a scope to attempt cleanup pub(crate) cleanup_finished: Vec, + /// Despawn these entities while no other operation is running. This is used + /// to cleanup detached impulses that receive no input. + pub(crate) deferred_despawn: Vec, } impl OperationRoster { @@ -262,6 +265,10 @@ impl OperationRoster { self.cleanup_finished.push(cleanup); } + pub fn defer_despawn(&mut self, source: Entity) { + self.deferred_despawn.push(source); + } + pub fn is_empty(&self) -> bool { self.queue.is_empty() && self.awake.is_empty() @@ -270,6 +277,7 @@ impl OperationRoster { && self.unblock.is_empty() && self.disposed.is_empty() && self.cleanup_finished.is_empty() + && self.deferred_despawn.is_empty() } pub fn append(&mut self, other: &mut Self) { @@ -319,6 +327,17 @@ pub(crate) struct Blocker { pub(crate) serve_next: fn(Blocker, &mut World, &mut OperationRoster), } +impl std::fmt::Debug for Blocker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Blocker") + .field("provider", &self.provider) + .field("source", &self.source) + .field("session", &self.session) + .field("label", &self.label) + .finish() + } +} + #[derive(Clone, Debug)] pub enum OperationError { Broken(Option), diff --git a/src/operation/injection.rs b/src/operation/injection.rs index 6e00f3c8..20a3f3a0 100644 --- a/src/operation/injection.rs +++ b/src/operation/injection.rs @@ -120,10 +120,16 
@@ where // roster to register the task as an operation. In fact it does not // implement Operation at all. It is just a temporary container for the // input and the stream targets. - unsafe { + let execute = unsafe { world .entity_mut(task) - .sneak_input(session, request, false)?; + .sneak_input(session, request, false, roster)? + }; + + if !execute { + // If giving the input failed then this workflow will not be able to + // proceed. Therefore we should report that this is broken. + None.or_broken()?; } let mut storage = world.get_mut::(source).or_broken()?; diff --git a/src/operation/scope.rs b/src/operation/scope.rs index 5e39f9eb..fe18821f 100644 --- a/src/operation/scope.rs +++ b/src/operation/scope.rs @@ -765,17 +765,6 @@ where let is_terminated = awaiting.info.status.is_terminated(); - unsafe { - // INVARIANT: We use sneak_input here to side-step the protection of - // only allowing inputs for Active sessions. This current session is - // not active because cleaning already started for it. It's okay to - // use this session as input despite not being active because we are - // passing it to an operation that will only use it to begin a - // cleanup workflow. - finish_cleanup_workflow_mut.sneak_input(scoped_session, CheckAwaitingSession, false)?; - roster.queue(finish_cleanup); - } - for begin in begin_cleanup_workflows { let run_this_workflow = (is_terminated && begin.on_terminate) || (!is_terminated && begin.on_cancelled); @@ -787,21 +776,32 @@ where // We execute the begin nodes immediately so that they can load up the // finish_cancel node with all their cancellation behavior IDs before // the finish_cancel node gets executed. - unsafe { + let execute = unsafe { // INVARIANT: We can use sneak_input here because we execute the // recipient node immediately after giving the input. world .get_entity_mut(begin.source) .or_broken()? - .sneak_input(scoped_session, (), false)?; + .sneak_input(scoped_session, (), false, roster)? 
+ }; + if execute { + execute_operation(OperationRequest { + source: begin.source, + world, + roster, + }); } - execute_operation(OperationRequest { - source: begin.source, - world, - roster, - }); } + // Check if there are any cleanup workflows waiting to be run. If not, + // the workflow can fully terminate. + FinishCleanup::::check_awaiting_session( + finish_cleanup, + scoped_session, + world, + roster, + )?; + Ok(()) } } @@ -1235,7 +1235,6 @@ impl Operation for FinishCleanup { world.entity_mut(source).insert(( CleanupForScope(self.from_scope), InputBundle::<()>::new(), - InputBundle::::new(), Cancellable::new(Self::receive_cancel), AwaitingCleanupStorage::default(), )); @@ -1250,43 +1249,15 @@ impl Operation for FinishCleanup { }: OperationRequest, ) -> OperationResult { let mut source_mut = world.get_entity_mut(source).or_broken()?; - if let Some(Input { - session: new_scoped_session, - .. - }) = source_mut.try_take_input::()? - { - let mut awaiting = source_mut.get_mut::().or_broken()?; - if let Some((index, a)) = awaiting - .0 - .iter_mut() - .enumerate() - .find(|(_, a)| a.scoped_session == new_scoped_session) - { - if a.cleanup_workflow_sessions - .as_ref() - .is_some_and(|s| s.is_empty()) - { - // No cancellation sessions were started for this scoped - // session so we can immediately clean it up. - Self::finalize_scoped_session( - index, - OperationRequest { - source, - world, - roster, - }, - )?; - } - } - } else if let Some(Input { + let Some(Input { session: cancellation_session, .. }) = source_mut.try_take_input::<()>()? 
- { - Self::deduct_finished_cleanup(source, cancellation_session, world, roster, None)?; - } + else { + return Ok(()); + }; - Ok(()) + Self::deduct_finished_cleanup(source, cancellation_session, world, roster, None) } fn cleanup(_: OperationCleanup) -> OperationResult { @@ -1361,6 +1332,40 @@ impl FinishCleanup { Ok(()) } + fn check_awaiting_session( + source: Entity, + new_scoped_session: Entity, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + let mut source_mut = world.get_entity_mut(source).or_broken()?; + let mut awaiting = source_mut.get_mut::().or_broken()?; + if let Some((index, a)) = awaiting + .0 + .iter_mut() + .enumerate() + .find(|(_, a)| a.scoped_session == new_scoped_session) + { + if a.cleanup_workflow_sessions + .as_ref() + .is_some_and(|s| s.is_empty()) + { + // No cancellation sessions were started for this scoped + // session so we can immediately clean it up. + Self::finalize_scoped_session( + index, + OperationRequest { + source, + world, + roster, + }, + )?; + } + } + + Ok(()) + } + fn deduct_finished_cleanup( source: Entity, cancellation_session: Entity, @@ -1605,14 +1610,13 @@ impl AwaitingCleanup { } } -struct CheckAwaitingSession; - #[derive(Component, Default)] pub(crate) struct ExitTargetStorage { /// Map from session value to the target pub(crate) map: HashMap, } +#[derive(Debug)] pub(crate) struct ExitTarget { pub(crate) target: Entity, pub(crate) source: Entity, diff --git a/src/service.rs b/src/service.rs index 706a9412..f4b0a72c 100644 --- a/src/service.rs +++ b/src/service.rs @@ -28,7 +28,7 @@ use bevy_ecs::{ }; pub use bevy_impulse_derive::DeliveryLabel; use bevy_utils::{define_label, intern::Interned}; -use std::{any::TypeId, collections::HashSet}; +use std::{any::TypeId, collections::HashSet, sync::OnceLock}; use thiserror::Error as ThisError; mod async_srv; @@ -76,13 +76,32 @@ pub(crate) use workflow::*; /// [App]: bevy_app::prelude::App /// [Commands]: bevy_ecs::prelude::Commands /// [World]: 
bevy_ecs::prelude::World -#[derive(Debug, PartialEq, Eq)] +#[derive(PartialEq, Eq)] pub struct Service { provider: Entity, instructions: Option, _ignore: std::marker::PhantomData, } +impl std::fmt::Debug for Service { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + static NAME: OnceLock = OnceLock::new(); + let name = NAME.get_or_init(|| { + format!( + "Service<{}, {}, {}>", + std::any::type_name::(), + std::any::type_name::(), + std::any::type_name::(), + ) + }); + + f.debug_struct(name.as_str()) + .field("provider", &self.provider) + .field("instructions", &self.instructions) + .finish() + } +} + impl Clone for Service { fn clone(&self) -> Self { *self From 1a0de67e3bbdba795a2ebfbf0c59c3e85b8754fe Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Fri, 13 Dec 2024 15:25:02 +0800 Subject: [PATCH 08/20] fix broken doc link (#41) Signed-off-by: Teo Koon Peng --- .cargo/config.toml | 2 ++ .github/workflows/ci_linux.yaml | 3 +++ src/chain/split.rs | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..ddbe09f3 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustdocflags = ["-D", "warnings"] diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index 05813b39..adebf68f 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -38,3 +38,6 @@ jobs: - name: Test single_threaded_async run: cargo test --features single_threaded_async + - name: Build docs + run: cargo doc + diff --git a/src/chain/split.rs b/src/chain/split.rs index 3593f093..beb1c83e 100644 --- a/src/chain/split.rs +++ b/src/chain/split.rs @@ -259,7 +259,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Splittable> Iterator for SplitBuilder<'w, 's, /// This tracks the connections that have been made to a split. This can be /// retrieved from [`SplitBuilder`] by calling [`SplitBuilder::outputs`]. 
-/// You can then continue building connections by calling [`SplitBuilder::build`]. +/// You can then continue building connections by calling [`SplitOutputs::build`]. #[must_use] pub struct SplitOutputs { scope: Entity, From 013dc62827a6c0c54971e325ae9fc61e9328be2f Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Thu, 9 Jan 2025 15:33:20 +0800 Subject: [PATCH 09/20] First draft of workflow diagrams (#27) Signed-off-by: Teo Koon Peng --- .github/workflows/ci_linux.yaml | 15 +- .github/workflows/ci_windows.yaml | 8 +- Cargo.toml | 37 +- diagram.schema.json | 300 ++++++++ examples/diagram/calculator/Cargo.toml | 16 + examples/diagram/calculator/multiply3.json | 13 + examples/diagram/calculator/src/main.rs | 61 ++ examples/diagram/calculator/tests/e2e.rs | 10 + scripts/patch-versions-msrv-1_75.sh | 4 + src/diagram.rs | 785 +++++++++++++++++++ src/diagram/fork_clone.rs | 134 ++++ src/diagram/fork_result.rs | 133 ++++ src/diagram/generate_schema.rs | 36 + src/diagram/impls.rs | 5 + src/diagram/join.rs | 321 ++++++++ src/diagram/node_registry.rs | 849 +++++++++++++++++++++ src/diagram/serialization.rs | 253 ++++++ src/diagram/split_serialized.rs | 622 +++++++++++++++ src/diagram/testing.rs | 137 ++++ src/diagram/transform.rs | 206 +++++ src/diagram/unzip.rs | 253 ++++++ src/diagram/workflow_builder.rs | 568 ++++++++++++++ src/lib.rs | 5 + 23 files changed, 4765 insertions(+), 6 deletions(-) create mode 100644 diagram.schema.json create mode 100644 examples/diagram/calculator/Cargo.toml create mode 100644 examples/diagram/calculator/multiply3.json create mode 100644 examples/diagram/calculator/src/main.rs create mode 100644 examples/diagram/calculator/tests/e2e.rs create mode 100755 scripts/patch-versions-msrv-1_75.sh create mode 100644 src/diagram.rs create mode 100644 src/diagram/fork_clone.rs create mode 100644 src/diagram/fork_result.rs create mode 100644 src/diagram/generate_schema.rs create mode 100644 src/diagram/impls.rs create mode 100644 src/diagram/join.rs 
create mode 100644 src/diagram/node_registry.rs create mode 100644 src/diagram/serialization.rs create mode 100644 src/diagram/split_serialized.rs create mode 100644 src/diagram/testing.rs create mode 100644 src/diagram/transform.rs create mode 100644 src/diagram/unzip.rs create mode 100644 src/diagram/workflow_builder.rs diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index adebf68f..116de865 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -28,10 +28,21 @@ jobs: - name: Setup rust run: rustup default ${{ matrix.rust-version }} + # As new versions of our dependencies come out, they might depend on newer + # versions of the Rust compiler. When that happens, we'll use this step to + # lock down the dependency to a version that is known to be compatible with + # compiler version 1.75. + - name: Patch dependencies + if: ${{ matrix.rust-version == 1.75 }} + run: ./scripts/patch-versions-msrv-1_75.sh + - name: Build default features - run: cargo build + run: cargo build --workspace - name: Test default features - run: cargo test + run: cargo test --workspace + + - name: Test diagram + run: cargo test --workspace -F=diagram - name: Build single_threaded_async run: cargo build --features single_threaded_async diff --git a/.github/workflows/ci_windows.yaml b/.github/workflows/ci_windows.yaml index 3a902b41..87b0767b 100644 --- a/.github/workflows/ci_windows.yaml +++ b/.github/workflows/ci_windows.yaml @@ -30,11 +30,15 @@ jobs: shell: powershell - name: Build default features - run: cargo build + run: cargo build --workspace shell: cmd - name: Test default features - run: cargo test + run: cargo test --workspace + shell: cmd + + - name: Test diagram + run: cargo test --workspace -F=diagram shell: cmd - name: Build single_threaded_async diff --git a/Cargo.toml b/Cargo.toml index 08bcde41..8164d7b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,12 @@ description = "Reactive programming and workflow 
execution for bevy" readme = "README.md" repository = "https://github.com/open-rmf/bevy_impulse" keywords = ["reactive", "workflow", "behavior", "agent", "bevy"] -categories = ["science::robotics", "asynchronous", "concurrency", "game-development"] +categories = [ + "science::robotics", + "asynchronous", + "concurrency", + "game-development", +] [dependencies] bevy_impulse_derive = { path = "macros", version = "0.0.2" } @@ -27,7 +32,7 @@ bevy_tasks = { version = "0.12", features = ["multi-threaded"] } itertools = "0.13" smallvec = "1.13" -tokio = { version = "1.39", features = ["sync"]} +tokio = { version = "1.39", features = ["sync"] } futures = "0.3" backtrace = "0.3" anyhow = "1.0" @@ -41,8 +46,36 @@ thiserror = "1.0" bevy_core = "0.12" bevy_time = "0.12" +schemars = { version = "0.8.21", optional = true } +serde = { version = "1.0.210", features = ["derive"], optional = true } +serde_json = { version = "1.0.128", optional = true } +cel-interpreter = { version = "0.9.0", features = ["json"], optional = true } +tracing = "0.1.41" +strum = { version = "0.26.3", optional = true, features = ["derive"] } +semver = { version = "1.0.24", optional = true } + [features] single_threaded_async = ["dep:async-task"] +diagram = [ + "dep:cel-interpreter", + "dep:schemars", + "dep:semver", + "dep:serde", + "dep:serde_json", + "dep:strum", +] [dev-dependencies] async-std = { version = "1.12" } +test-log = { version = "0.2.16", features = [ + "trace", +], default-features = false } + +[workspace] +members = ["examples/diagram/calculator"] + +[[bin]] +name = "generate_schema" +path = "src/diagram/generate_schema.rs" +required-features = ["diagram"] +doc = false diff --git a/diagram.schema.json b/diagram.schema.json new file mode 100644 index 00000000..8a8c9a9b --- /dev/null +++ b/diagram.schema.json @@ -0,0 +1,300 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Diagram", + "type": "object", + "required": [ + "ops", + "start", + "version" + ], + 
"properties": { + "ops": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/DiagramOperation" + } + }, + "start": { + "description": "Signifies the start of a workflow.", + "allOf": [ + { + "$ref": "#/definitions/NextOperation" + } + ] + }, + "version": { + "description": "Version of the diagram, should always be `0.1.0`.", + "type": "string" + } + }, + "definitions": { + "BuiltinSource": { + "type": "string", + "enum": [ + "start" + ] + }, + "BuiltinTarget": { + "oneOf": [ + { + "description": "Use the output to terminate the workflow. This will be the return value of the workflow.", + "type": "string", + "enum": [ + "terminate" + ] + }, + { + "description": "Dispose of the output.", + "type": "string", + "enum": [ + "dispose" + ] + } + ] + }, + "DiagramOperation": { + "oneOf": [ + { + "description": "Connect the request to a registered node.\n\n``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"node_op\", \"ops\": { \"node_op\": { \"type\": \"node\", \"builder\": \"my_node_builder\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "builder", + "next", + "type" + ], + "properties": { + "builder": { + "type": "string" + }, + "config": { + "default": null + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "type": { + "type": "string", + "enum": [ + "node" + ] + } + } + }, + { + "description": "If the request is cloneable, clone it into multiple responses.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"terminate\"] } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "next", + "type" + ], + "properties": { + "next": { + "type": "array", + "items": { + "$ref": "#/definitions/NextOperation" + } + }, + "type": { + "type": "string", + "enum": [ + 
"fork_clone" + ] + } + } + }, + { + "description": "If the request is a tuple of (T1, T2, T3, ...), unzip it into multiple responses of T1, T2, T3, ...\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"unzip\", \"ops\": { \"unzip\": { \"type\": \"unzip\", \"next\": [{ \"builtin\": \"terminate\" }] } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "next", + "type" + ], + "properties": { + "next": { + "type": "array", + "items": { + "$ref": "#/definitions/NextOperation" + } + }, + "type": { + "type": "string", + "enum": [ + "unzip" + ] + } + } + }, + { + "description": "If the request is a `Result<_, _>`, branch it to `Ok` and `Err`.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_result\", \"ops\": { \"fork_result\": { \"type\": \"fork_result\", \"ok\": { \"builtin\": \"terminate\" }, \"err\": { \"builtin\": \"dispose\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "err", + "ok", + "type" + ], + "properties": { + "err": { + "$ref": "#/definitions/NextOperation" + }, + "ok": { + "$ref": "#/definitions/NextOperation" + }, + "type": { + "type": "string", + "enum": [ + "fork_result" + ] + } + } + }, + { + "description": "If the request is a list-like or map-like object, split it into multiple responses. 
Note that the split output is a tuple of `(KeyOrIndex, Value)`, nodes receiving a split output should have request of that type instead of just the value type.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"split\", \"ops\": { \"split\": { \"type\": \"split\", \"index\": [{ \"builtin\": \"terminate\" }] } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "type": "object", + "required": [ + "type" + ], + "properties": { + "keyed": { + "default": {}, + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/NextOperation" + } + }, + "remaining": { + "anyOf": [ + { + "$ref": "#/definitions/NextOperation" + }, + { + "type": "null" + } + ] + }, + "sequential": { + "default": [], + "type": "array", + "items": { + "$ref": "#/definitions/NextOperation" + } + }, + "type": { + "type": "string", + "enum": [ + "split" + ] + } + } + }, + { + "description": "Wait for an item to be emitted from each of the inputs, then combined the oldest of each into an array.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"split\", \"ops\": { \"split\": { \"type\": \"split\", \"index\": [\"op1\", \"op2\"] }, \"op1\": { \"type\": \"node\", \"builder\": \"foo\", \"next\": \"join\" }, \"op2\": { \"type\": \"node\", \"builder\": \"bar\", \"next\": \"join\" }, \"join\": { \"type\": \"join\", \"inputs\": [\"op1\", \"op2\"], \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "type": "object", + "required": [ + "inputs", + "next", + "type" + ], + "properties": { + "inputs": { + "description": "Controls the order of the resulting join. Each item must be an operation id of one of the incoming outputs.", + "type": "array", + "items": { + "$ref": "#/definitions/SourceOperation" + } + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "no_serialize": { + "description": "Do not serialize before performing the join. 
If true, joins can only be done on outputs of the same type.", + "type": [ + "boolean", + "null" + ] + }, + "type": { + "type": "string", + "enum": [ + "join" + ] + } + } + }, + { + "description": "If the request is serializable, transform it by running it through a [CEL](https://cel.dev/) program. The context includes a \"request\" variable which contains the request.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"request.name\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```\n\nNote that due to how `serde_json` performs serialization, positive integers are always serialized as unsigned. In CEL, You can't do an operation between unsigned and signed so it is recommended to always perform explicit casts.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"int(request.score) * 3\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "type": "object", + "required": [ + "cel", + "next", + "type" + ], + "properties": { + "cel": { + "type": "string" + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "type": { + "type": "string", + "enum": [ + "transform" + ] + } + } + }, + { + "description": "Drop the request, equivalent to a no-op.", + "type": "object", + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "dispose" + ] + } + } + } + ] + }, + "NextOperation": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "object", + "required": [ + "builtin" + ], + "properties": { + "builtin": { + "$ref": "#/definitions/BuiltinTarget" + } + } + } + ] + }, + "SourceOperation": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "object", + "required": [ + "builtin" + ], + 
"properties": { + "builtin": { + "$ref": "#/definitions/BuiltinSource" + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/examples/diagram/calculator/Cargo.toml b/examples/diagram/calculator/Cargo.toml new file mode 100644 index 00000000..8e89632b --- /dev/null +++ b/examples/diagram/calculator/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "calculator" +version = "0.1.0" +edition = "2021" + +[dependencies] +bevy_app = "0.12" +bevy_core = "0.12" +bevy_impulse = { version = "0.0.2", path = "../../..", features = ["diagram"] } +bevy_time = "0.12" +clap = { version = "4.5.23", features = ["derive"] } +serde_json = "1.0.128" +tracing-subscriber = "0.3.19" + +[dev-dependencies] +assert_cmd = "2.0.16" diff --git a/examples/diagram/calculator/multiply3.json b/examples/diagram/calculator/multiply3.json new file mode 100644 index 00000000..e3421e46 --- /dev/null +++ b/examples/diagram/calculator/multiply3.json @@ -0,0 +1,13 @@ +{ + "$schema": "../../../diagram.schema.json", + "version": "0.1.0", + "start": "mul3", + "ops": { + "mul3": { + "type": "node", + "builder": "mul", + "config": 3, + "next": { "builtin": "terminate" } + } + } +} diff --git a/examples/diagram/calculator/src/main.rs b/examples/diagram/calculator/src/main.rs new file mode 100644 index 00000000..6214ab2b --- /dev/null +++ b/examples/diagram/calculator/src/main.rs @@ -0,0 +1,61 @@ +use std::{error::Error, fs::File, str::FromStr}; + +use bevy_impulse::{ + Diagram, DiagramError, ImpulsePlugin, NodeBuilderOptions, NodeRegistry, Promise, RequestExt, + RunCommandsOnWorldExt, +}; +use clap::Parser; + +#[derive(Parser, Debug)] +/// Example calculator app using diagrams. 
+struct Args { + #[arg(help = "path to the diagram to run")] + diagram: String, + + #[arg(help = "json containing the request to the diagram")] + request: String, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + + tracing_subscriber::fmt::init(); + + let mut registry = NodeRegistry::default(); + registry.register_node_builder( + NodeBuilderOptions::new("add").with_name("Add"), + |builder, config: f64| builder.create_map_block(move |req: f64| req + config), + ); + registry.register_node_builder( + NodeBuilderOptions::new("sub").with_name("Subtract"), + |builder, config: f64| builder.create_map_block(move |req: f64| req - config), + ); + registry.register_node_builder( + NodeBuilderOptions::new("mul").with_name("Multiply"), + |builder, config: f64| builder.create_map_block(move |req: f64| req * config), + ); + registry.register_node_builder( + NodeBuilderOptions::new("div").with_name("Divide"), + |builder, config: f64| builder.create_map_block(move |req: f64| req / config), + ); + + let mut app = bevy_app::App::new(); + app.add_plugins(ImpulsePlugin::default()); + let file = File::open(args.diagram).unwrap(); + let diagram = Diagram::from_reader(file)?; + + let request = serde_json::Value::from_str(&args.request)?; + let mut promise = + app.world + .command(|cmds| -> Result, DiagramError> { + let workflow = diagram.spawn_io_workflow(cmds, ®istry)?; + Ok(cmds.request(request, workflow).take_response()) + })?; + + while promise.peek().is_pending() { + app.update(); + } + + println!("{}", promise.take().available().unwrap()); + Ok(()) +} diff --git a/examples/diagram/calculator/tests/e2e.rs b/examples/diagram/calculator/tests/e2e.rs new file mode 100644 index 00000000..f0b0e223 --- /dev/null +++ b/examples/diagram/calculator/tests/e2e.rs @@ -0,0 +1,10 @@ +use assert_cmd::Command; + +#[test] +fn multiply3() { + Command::cargo_bin("calculator") + .unwrap() + .args(["multiply3.json", "4"]) + .assert() + .stdout("12.0\n"); +} diff --git 
a/scripts/patch-versions-msrv-1_75.sh b/scripts/patch-versions-msrv-1_75.sh new file mode 100755 index 00000000..292036de --- /dev/null +++ b/scripts/patch-versions-msrv-1_75.sh @@ -0,0 +1,4 @@ +# This script is useful for forcing dependencies to be compatible with Rust v1.75 +# Run this script in the root directory of the package. + +cargo add home@=0.5.9 diff --git a/src/diagram.rs b/src/diagram.rs new file mode 100644 index 00000000..7efea31b --- /dev/null +++ b/src/diagram.rs @@ -0,0 +1,785 @@ +mod fork_clone; +mod fork_result; +mod impls; +mod join; +mod node_registry; +mod serialization; +mod split_serialized; +mod transform; +mod unzip; +mod workflow_builder; + +use bevy_ecs::system::Commands; +use fork_clone::ForkCloneOp; +use fork_result::ForkResultOp; +use join::JoinOp; +pub use join::JoinOutput; +pub use node_registry::*; +pub use serialization::*; +pub use split_serialized::*; +use tracing::debug; +use transform::{TransformError, TransformOp}; +use unzip::UnzipOp; +use workflow_builder::create_workflow; + +// ---------- + +use std::{collections::HashMap, fmt::Display, io::Read}; + +use crate::{Builder, Scope, Service, SpawnWorkflowExt, SplitConnectionError, StreamPack}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +const SUPPORTED_DIAGRAM_VERSION: &str = ">=0.1.0, <0.2.0"; + +pub type BuilderId = String; +pub type OperationId = String; + +#[derive( + Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord, +)] +#[serde(untagged, rename_all = "snake_case")] +pub enum NextOperation { + Target(OperationId), + Builtin { builtin: BuiltinTarget }, +} + +impl Display for NextOperation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Target(operation_id) => f.write_str(operation_id), + Self::Builtin { builtin } => write!(f, "builtin:{}", builtin), + } + } +} + +#[derive( + Debug, + Clone, + Serialize, + Deserialize, + JsonSchema, + Hash, + PartialEq, + Eq, + 
PartialOrd, + Ord, + strum::Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum BuiltinTarget { + /// Use the output to terminate the workflow. This will be the return value + /// of the workflow. + Terminate, + + /// Dispose of the output. + Dispose, +} + +#[derive( + Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord, +)] +#[serde(untagged, rename_all = "snake_case")] +pub enum SourceOperation { + Source(OperationId), + Builtin { builtin: BuiltinSource }, +} + +impl From for SourceOperation { + fn from(value: OperationId) -> Self { + SourceOperation::Source(value) + } +} + +impl Display for SourceOperation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Source(operation_id) => f.write_str(operation_id), + Self::Builtin { builtin } => write!(f, "builtin:{}", builtin), + } + } +} + +#[derive( + Debug, + Clone, + Serialize, + Deserialize, + JsonSchema, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + strum::Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum BuiltinSource { + Start, +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct TerminateOp {} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct NodeOp { + builder: BuilderId, + #[serde(default)] + config: serde_json::Value, + next: NextOperation, +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum DiagramOperation { + /// Connect the request to a registered node. 
+ /// + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "node_op", + /// "ops": { + /// "node_op": { + /// "type": "node", + /// "builder": "my_node_builder", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + Node(NodeOp), + + /// If the request is cloneable, clone it into multiple responses. + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "fork_clone", + /// "ops": { + /// "fork_clone": { + /// "type": "fork_clone", + /// "next": ["terminate"] + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + ForkClone(ForkCloneOp), + + /// If the request is a tuple of (T1, T2, T3, ...), unzip it into multiple responses + /// of T1, T2, T3, ... + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "unzip", + /// "ops": { + /// "unzip": { + /// "type": "unzip", + /// "next": [{ "builtin": "terminate" }] + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + Unzip(UnzipOp), + + /// If the request is a `Result<_, _>`, branch it to `Ok` and `Err`. + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "fork_result", + /// "ops": { + /// "fork_result": { + /// "type": "fork_result", + /// "ok": { "builtin": "terminate" }, + /// "err": { "builtin": "dispose" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + ForkResult(ForkResultOp), + + /// If the request is a list-like or map-like object, split it into multiple responses. + /// Note that the split output is a tuple of `(KeyOrIndex, Value)`, nodes receiving a split + /// output should have request of that type instead of just the value type. 
+ /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "split", + /// "ops": { + /// "split": { + /// "type": "split", + /// "index": [{ "builtin": "terminate" }] + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + Split(SplitOp), + + /// Wait for an item to be emitted from each of the inputs, then combined the + /// oldest of each into an array. + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "split", + /// "ops": { + /// "split": { + /// "type": "split", + /// "index": ["op1", "op2"] + /// }, + /// "op1": { + /// "type": "node", + /// "builder": "foo", + /// "next": "join" + /// }, + /// "op2": { + /// "type": "node", + /// "builder": "bar", + /// "next": "join" + /// }, + /// "join": { + /// "type": "join", + /// "inputs": ["op1", "op2"], + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + Join(JoinOp), + + /// If the request is serializable, transform it by running it through a [CEL](https://cel.dev/) program. + /// The context includes a "request" variable which contains the request. + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "transform", + /// "ops": { + /// "transform": { + /// "type": "transform", + /// "cel": "request.name", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + /// + /// Note that due to how `serde_json` performs serialization, positive integers are always + /// serialized as unsigned. In CEL, You can't do an operation between unsigned and signed so + /// it is recommended to always perform explicit casts. 
+ /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "transform", + /// "ops": { + /// "transform": { + /// "type": "transform", + /// "cel": "int(request.score) * 3", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + Transform(TransformOp), + + /// Drop the request, equivalent to a no-op. + Dispose, +} + +type DiagramStart = serde_json::Value; +type DiagramTerminate = serde_json::Value; +type DiagramScope = Scope; + +/// Returns the schema for [`String`] +fn schema_with_string(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::() +} + +/// deserialize semver and validate that it has a supported version +fn deserialize_semver<'de, D>(de: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s = String::deserialize(de)?; + // SAFETY: `SUPPORTED_DIAGRAM_VERSION` is a const, this will never fail. + let ver_req = semver::VersionReq::parse(SUPPORTED_DIAGRAM_VERSION).unwrap(); + let ver = semver::Version::parse(&s).map_err(|_| { + serde::de::Error::invalid_value(serde::de::Unexpected::Str(&s), &SUPPORTED_DIAGRAM_VERSION) + })?; + if !ver_req.matches(&ver) { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(&s), + &SUPPORTED_DIAGRAM_VERSION, + )); + } + Ok(ver) +} + +/// serialize semver as a string +fn serialize_semver(o: &semver::Version, ser: S) -> Result +where + S: serde::ser::Serializer, +{ + o.to_string().serialize(ser) +} + +#[derive(JsonSchema, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Diagram { + /// Version of the diagram, should always be `0.1.0`. + #[serde( + deserialize_with = "deserialize_semver", + serialize_with = "serialize_semver" + )] + #[schemars(schema_with = "schema_with_string")] + version: semver::Version, + + /// Signifies the start of a workflow. 
+ start: NextOperation, + + ops: HashMap, +} + +impl Diagram { + /// Spawns a workflow from this diagram. + /// + /// # Examples + /// + /// ``` + /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, NodeRegistry, RunCommandsOnWorldExt}; + /// + /// let mut app = bevy_app::App::new(); + /// let mut registry = NodeRegistry::default(); + /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { + /// builder.create_map_block(|msg: String| msg) + /// }); + /// + /// let json_str = r#" + /// { + /// "version": "0.1.0", + /// "start": "echo", + /// "ops": { + /// "echo": { + /// "type": "node", + /// "builder": "echo", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// "#; + /// + /// let diagram = Diagram::from_json_str(json_str)?; + /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; + /// # Ok::<_, DiagramError>(()) + /// ``` + // TODO(koonpeng): Support streams other than `()` #43. + /* pub */ + fn spawn_workflow( + &self, + cmds: &mut Commands, + registry: &NodeRegistry, + ) -> Result, DiagramError> + where + Streams: StreamPack, + { + let mut err: Option = None; + + macro_rules! unwrap_or_return { + ($v:expr) => { + match $v { + Ok(v) => v, + Err(e) => { + err = Some(e); + return; + } + } + }; + } + + let w = cmds.spawn_workflow(|scope: DiagramScope, builder: &mut Builder| { + debug!( + "spawn workflow, scope input: {:?}, terminate: {:?}", + scope.input.id(), + scope.terminate.id() + ); + + unwrap_or_return!(create_workflow(scope, builder, registry, self)); + }); + + if let Some(err) = err { + return Err(err); + } + + Ok(w) + } + + /// Wrapper to [spawn_workflow::<()>](Self::spawn_workflow). 
+ pub fn spawn_io_workflow( + &self, + cmds: &mut Commands, + registry: &NodeRegistry, + ) -> Result, DiagramError> { + self.spawn_workflow::<()>(cmds, registry) + } + + pub fn from_json(value: serde_json::Value) -> Result { + serde_json::from_value(value) + } + + pub fn from_json_str(s: &str) -> Result { + serde_json::from_str(s) + } + + pub fn from_reader(r: R) -> Result + where + R: Read, + { + serde_json::from_reader(r) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum DiagramError { + #[error("node builder [{0}] is not registered")] + BuilderNotFound(BuilderId), + + #[error("operation [{0}] not found")] + OperationNotFound(OperationId), + + #[error("output type does not match input type")] + TypeMismatch, + + #[error("missing start or terminate")] + MissingStartOrTerminate, + + #[error("cannot connect to start")] + CannotConnectStart, + + #[error("request or response cannot be serialized or deserialized")] + NotSerializable, + + #[error("response cannot be cloned")] + NotCloneable, + + #[error("the number of unzip slots in response does not match the number of inputs")] + NotUnzippable, + + #[error( + "node must be registered with \"with_fork_result()\" to be able to perform fork result" + )] + CannotForkResult, + + #[error("response cannot be split")] + NotSplittable, + + #[error("empty join is not allowed")] + EmptyJoin, + + #[error(transparent)] + CannotTransform(#[from] TransformError), + + #[error("an interconnect like fork_clone cannot connect to another interconnect")] + BadInterconnectChain, + + #[error(transparent)] + JsonError(#[from] serde_json::Error), + + #[error(transparent)] + ConnectionError(#[from] SplitConnectionError), + + /// Use this only for errors that *should* never happen because of some preconditions. + /// If this error ever comes up, then it likely means that there is some logical flaws + /// in the algorithm. 
+ #[error("an unknown error occurred while building the diagram, {0}")] + UnknownError(String), +} + +#[macro_export] +macro_rules! unknown_diagram_error { + () => { + DiagramError::UnknownError(format!("{}:{}", file!(), line!())) + }; +} + +#[cfg(test)] +mod testing; + +#[cfg(test)] +mod tests { + use crate::{Cancellation, CancellationCause}; + use serde_json::json; + use test_log::test; + use testing::DiagramTestFixture; + + use super::*; + + #[test] + fn test_no_terminate() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "dispose" }, + }, + }, + })) + .unwrap(); + + let err = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap_err(); + assert!(matches!( + *err.downcast_ref::().unwrap().cause, + CancellationCause::Unreachable(_) + )); + } + + #[test] + fn test_unserializable_start() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "opaque_request", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::NotSerializable), "{:?}", err); + } + + #[test] + fn test_unserializable_terminate() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "opaque_response", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::NotSerializable), "{:?}", err); + } + + #[test] + fn test_mismatch_types() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = 
Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "op2", + }, + "op2": { + "type": "node", + "builder": "opaque_request", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::TypeMismatch), "{:?}", err); + } + + #[test] + fn test_disconnected() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "op2", + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "op1", + }, + }, + })) + .unwrap(); + + let err = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap_err(); + assert!(matches!( + *err.downcast_ref::().unwrap().cause, + CancellationCause::Unreachable(_) + )); + } + + #[test] + fn test_looping_diagram() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3", + "next": "fork_clone", + }, + "fork_clone": { + "type": "fork_clone", + "next": ["op1", "op2"], + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 36); + } + + #[test] + fn test_noop_diagram() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": { "builtin": "terminate" }, + "ops": {}, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 4); + } + + #[test] + fn 
test_serialized_diagram() { + let mut fixture = DiagramTestFixture::new(); + + let json_str = r#" + { + "version": "0.1.0", + "start": "multiply3_uncloneable", + "ops": { + "multiply3_uncloneable": { + "type": "node", + "builder": "multiplyBy", + "config": 7, + "next": { "builtin": "terminate" } + } + } + } + "#; + + let result = fixture + .spawn_and_run( + &Diagram::from_json_str(json_str).unwrap(), + serde_json::Value::from(4), + ) + .unwrap(); + assert_eq!(result, 28); + } + + /// Test that we can transform on a slot of a unzipped response. Operations which changes + /// the output type has extra serialization logic. + #[test] + fn test_transform_unzip() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "unzip", + }, + "unzip": { + "type": "unzip", + "next": ["transform"], + }, + "transform": { + "type": "transform", + "cel": "777", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 777); + } +} diff --git a/src/diagram/fork_clone.rs b/src/diagram/fork_clone.rs new file mode 100644 index 00000000..d6b8dc34 --- /dev/null +++ b/src/diagram/fork_clone.rs @@ -0,0 +1,134 @@ +use std::any::TypeId; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tracing::debug; + +use crate::Builder; + +use super::{ + impls::{DefaultImpl, NotSupported}, + DiagramError, DynOutput, NextOperation, +}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct ForkCloneOp { + pub(super) next: Vec, +} + +pub trait DynForkClone { + const CLONEABLE: bool; + + fn dyn_fork_clone( + builder: &mut Builder, + output: DynOutput, + amount: usize, + ) -> Result, DiagramError>; +} + +impl DynForkClone for NotSupported { + const CLONEABLE: bool = 
false; + + fn dyn_fork_clone( + _builder: &mut Builder, + _output: DynOutput, + _amount: usize, + ) -> Result, DiagramError> { + Err(DiagramError::NotCloneable) + } +} + +impl DynForkClone for DefaultImpl +where + T: Send + Sync + 'static + Clone, +{ + const CLONEABLE: bool = true; + + fn dyn_fork_clone( + builder: &mut Builder, + output: DynOutput, + amount: usize, + ) -> Result, DiagramError> { + debug!("fork clone: {:?}", output); + assert_eq!(output.type_id, TypeId::of::()); + + let fork_clone = output.into_output::()?.fork_clone(builder); + let outputs = (0..amount) + .map(|_| fork_clone.clone_output(builder).into()) + .collect(); + debug!("forked outputs: {:?}", outputs); + Ok(outputs) + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Diagram}; + + use super::*; + + #[test] + fn test_fork_clone_uncloneable() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "fork_clone" + }, + "fork_clone": { + "type": "fork_clone", + "next": ["op2"] + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::NotCloneable), "{:?}", err); + } + + #[test] + fn test_fork_clone() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3", + "next": "fork_clone" + }, + "fork_clone": { + "type": "fork_clone", + "next": ["op2"] + }, + "op2": { + "type": "node", + "builder": "multiply3", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, 
serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 36); + } +} diff --git a/src/diagram/fork_result.rs b/src/diagram/fork_result.rs new file mode 100644 index 00000000..6decaddf --- /dev/null +++ b/src/diagram/fork_result.rs @@ -0,0 +1,133 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tracing::debug; + +use crate::Builder; + +use super::{ + impls::{DefaultImpl, NotSupported}, + DiagramError, DynOutput, NextOperation, +}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct ForkResultOp { + pub(super) ok: NextOperation, + pub(super) err: NextOperation, +} + +pub trait DynForkResult { + const SUPPORTED: bool; + + fn dyn_fork_result( + builder: &mut Builder, + output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError>; +} + +impl DynForkResult for NotSupported { + const SUPPORTED: bool = false; + + fn dyn_fork_result( + _builder: &mut Builder, + _output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError> { + Err(DiagramError::CannotForkResult) + } +} + +impl DynForkResult> for DefaultImpl +where + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + const SUPPORTED: bool = true; + + fn dyn_fork_result( + builder: &mut Builder, + output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError> { + debug!("fork result: {:?}", output); + + let chain = output.into_output::>()?.chain(builder); + let outputs = chain.fork_result(|c| c.output().into(), |c| c.output().into()); + debug!("forked outputs: {:?}", outputs); + Ok(outputs) + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Builder, Diagram, NodeBuilderOptions}; + + #[test] + fn test_fork_result() { + let mut fixture = DiagramTestFixture::new(); + + fn check_even(v: i64) -> Result { + if v % 2 == 0 { + Ok("even".to_string()) + } else { + Err("odd".to_string()) + } + } + + fixture + .registry + 
.register_node_builder( + NodeBuilderOptions::new("check_even".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&check_even), + ) + .with_fork_result(); + + fn echo(s: String) -> String { + s + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("echo".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&echo), + ); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "check_even", + "next": "fork_result", + }, + "fork_result": { + "type": "fork_result", + "ok": "op2", + "err": "op3", + }, + "op2": { + "type": "node", + "builder": "echo", + "next": { "builtin": "terminate" }, + }, + "op3": { + "type": "node", + "builder": "echo", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, "even"); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(3)) + .unwrap(); + assert_eq!(result, "odd"); + } +} diff --git a/src/diagram/generate_schema.rs b/src/diagram/generate_schema.rs new file mode 100644 index 00000000..22dae04a --- /dev/null +++ b/src/diagram/generate_schema.rs @@ -0,0 +1,36 @@ +use bevy_impulse::Diagram; + +fn main() -> Result<(), Box> { + let schema = schemars::schema_for!(Diagram); + let f = std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open("diagram.schema.json") + .unwrap(); + serde_json::to_writer_pretty(f, &schema)?; + Ok(()) +} + +#[cfg(test)] +mod diagram { + mod test { + use super::super::*; + use std::iter::zip; + + #[cfg(not(target_os = "windows"))] + #[test] + fn check_schema_changes() -> Result<(), String> { + let cur_schema_json = std::fs::read("diagram.schema.json").unwrap(); + let schema = schemars::schema_for!(Diagram); + let new_schema_json = serde_json::to_vec_pretty(&schema).unwrap(); + + if 
cur_schema_json.len() != new_schema_json.len() + || zip(cur_schema_json, new_schema_json).any(|(a, b)| a != b) + { + return Err(String::from("There are changes in the json schema, please run `cargo run -F=diagram generate_schema` to regenerate it")); + } + Ok(()) + } + } +} diff --git a/src/diagram/impls.rs b/src/diagram/impls.rs new file mode 100644 index 00000000..b2116505 --- /dev/null +++ b/src/diagram/impls.rs @@ -0,0 +1,5 @@ +/// A struct to provide the default implementation for various operations. +pub struct DefaultImpl; + +/// A struct to provide "not supported" implementations for various operations. +pub struct NotSupported; diff --git a/src/diagram/join.rs b/src/diagram/join.rs new file mode 100644 index 00000000..da5fbe75 --- /dev/null +++ b/src/diagram/join.rs @@ -0,0 +1,321 @@ +use std::any::TypeId; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; +use tracing::debug; + +use crate::{Builder, IterBufferable, Output}; + +use super::{ + DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, SourceOperation, +}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct JoinOp { + pub(super) next: NextOperation, + + /// Controls the order of the resulting join. Each item must be an operation id of one of the + /// incoming outputs. + pub(super) inputs: Vec, + + /// Do not serialize before performing the join. If true, joins can only be done + /// on outputs of the same type. + pub(super) no_serialize: Option, +} + +pub(super) fn register_join_impl(registry: &mut MessageRegistry) +where + T: Send + Sync + 'static, + Serializer: SerializeMessage>, +{ + if registry.join_impls.contains_key(&TypeId::of::()) { + return; + } + + registry + .join_impls + .insert(TypeId::of::(), Box::new(join_impl::)); +} + +/// Serialize the outputs before joining them, and convert the resulting joined output into a +/// [`serde_json::Value`]. 
+pub(super) fn serialize_and_join( + builder: &mut Builder, + registry: &MessageRegistry, + outputs: Vec, +) -> Result, DiagramError> { + debug!("serialize and join outputs {:?}", outputs); + + if outputs.is_empty() { + // do not allow empty joins + return Err(DiagramError::EmptyJoin); + } + + let outputs = outputs + .into_iter() + .map(|o| { + let serialize_impl = registry + .serialize_impls + .get(&o.type_id) + .ok_or(DiagramError::NotSerializable)?; + let serialized_output = serialize_impl(builder, o)?; + Ok(serialized_output) + }) + .collect::, DiagramError>>()?; + + // we need to convert the joined output to [`serde_json::Value`] in order for it to be + // serializable. + let joined_output = outputs.join_vec::<4>(builder).output(); + let json_output = joined_output + .chain(builder) + .map_block(|o| serde_json::to_value(o)) + .cancel_on_err() + .output(); + Ok(json_output) +} + +fn join_impl(builder: &mut Builder, outputs: Vec) -> Result +where + T: Send + Sync + 'static, +{ + debug!("join outputs {:?}", outputs); + + if outputs.is_empty() { + // do a empty join, in practice, this branch is never ran because [`WorkflowBuilder`] + // should error out if there is an empty join. + return Err(DiagramError::EmptyJoin); + } + + let first_type = outputs[0].type_id; + + let outputs = outputs + .into_iter() + .map(|o| { + if o.type_id != first_type { + Err(DiagramError::TypeMismatch) + } else { + Ok(o.into_output::()?) + } + }) + .collect::, _>>()?; + + // we don't know the number of items at compile time, so we just use a sensible number. + // NOTE: Be sure to update `JoinOutput` if this changes. + Ok(outputs.join_vec::<4>(builder).output().into()) +} + +/// The resulting type of a `join` operation. Nodes receiving a join output must have request +/// of this type. Note that the join output is NOT serializable. If you would like to serialize it, +/// convert it to a `Vec` first. 
+pub type JoinOutput = SmallVec<[T; 4]>; + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use super::*; + use crate::{ + diagram::testing::DiagramTestFixture, Diagram, DiagramError, JsonPosition, + NodeBuilderOptions, + }; + + #[test] + fn test_join() { + let mut fixture = DiagramTestFixture::new(); + + fn get_split_value(pair: (JsonPosition, serde_json::Value)) -> serde_json::Value { + pair.1 + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("get_split_value".to_string()), + |builder, _config: ()| builder.create_map_block(get_split_value), + ); + + fn serialize_join_output(join_output: JoinOutput) -> serde_json::Value { + serde_json::to_value(join_output).unwrap() + } + + fixture + .registry + .opt_out() + .no_request_deserializing() + .register_node_builder( + NodeBuilderOptions::new("serialize_join_output".to_string()), + |builder, _config: ()| builder.create_map_block(serialize_join_output), + ); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "split", + "ops": { + "split": { + "type": "split", + "sequential": ["get_split_value1", "get_split_value2"] + }, + "get_split_value1": { + "type": "node", + "builder": "get_split_value", + "next": "op1", + }, + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "join", + }, + "get_split_value2": { + "type": "node", + "builder": "get_split_value", + "next": "op2", + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "join", + }, + "join": { + "type": "join", + "inputs": ["op1", "op2"], + "next": "serialize_join_output", + "no_serialize": true, + }, + "serialize_join_output": { + "type": "node", + "builder": "serialize_join_output", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from([1, 2])) + .unwrap(); + assert_eq!(result.as_array().unwrap().len(), 2); + assert_eq!(result[0], 3); + 
assert_eq!(result[1], 6); + } + + /// This test is to ensure that the order of split and join operations are stable. + #[test] + fn test_join_stress() { + for _ in 1..20 { + test_join(); + } + } + + #[test] + fn test_empty_join() { + let mut fixture = DiagramTestFixture::new(); + + fn get_split_value(pair: (JsonPosition, serde_json::Value)) -> serde_json::Value { + pair.1 + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("get_split_value".to_string()), + |builder, _config: ()| builder.create_map_block(get_split_value), + ); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "split", + "ops": { + "split": { + "type": "split", + "sequential": ["get_split_value1", "get_split_value2"] + }, + "get_split_value1": { + "type": "node", + "builder": "get_split_value", + "next": "op1", + }, + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + "get_split_value2": { + "type": "node", + "builder": "get_split_value", + "next": "op2", + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + "join": { + "type": "join", + "inputs": [], + "next": { "builtin": "terminate" }, + "no_serialize": true, + }, + } + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::EmptyJoin)); + } + + #[test] + fn test_serialize_and_join() { + let mut fixture = DiagramTestFixture::new(); + + fn num_output(_: serde_json::Value) -> i64 { + 1 + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("num_output".to_string()), + |builder, _config: ()| builder.create_map_block(num_output), + ); + + fn string_output(_: serde_json::Value) -> String { + "hello".to_string() + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("string_output".to_string()), + |builder, _config: ()| builder.create_map_block(string_output), + ); + + let diagram = 
Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["op1", "op2"] + }, + "op1": { + "type": "node", + "builder": "num_output", + "next": "join", + }, + "op2": { + "type": "node", + "builder": "string_output", + "next": "join", + }, + "join": { + "type": "join", + "inputs": ["op1", "op2"], + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert_eq!(result.as_array().unwrap().len(), 2); + assert_eq!(result[0], 1); + assert_eq!(result[1], "hello"); + } +} diff --git a/src/diagram/node_registry.rs b/src/diagram/node_registry.rs new file mode 100644 index 00000000..09a832d1 --- /dev/null +++ b/src/diagram/node_registry.rs @@ -0,0 +1,849 @@ +use std::{ + any::{Any, TypeId}, + borrow::Borrow, + cell::RefCell, + collections::HashMap, + fmt::Debug, + marker::PhantomData, +}; + +use crate::{Builder, InputSlot, Node, Output, StreamPack}; +use bevy_ecs::entity::Entity; +use schemars::{ + gen::{SchemaGenerator, SchemaSettings}, + schema::Schema, + JsonSchema, +}; +use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize}; +use tracing::debug; + +use crate::{RequestMetadata, SerializeMessage}; + +use super::{ + fork_clone::DynForkClone, + fork_result::DynForkResult, + impls::{DefaultImpl, NotSupported}, + register_deserialize, register_serialize, + unzip::DynUnzip, + BuilderId, DefaultDeserializer, DefaultSerializer, DeserializeMessage, DiagramError, DynSplit, + DynSplitOutputs, DynType, OpaqueMessageDeserializer, OpaqueMessageSerializer, ResponseMetadata, + SplitOp, +}; + +/// A type erased [`crate::InputSlot`] +#[derive(Copy, Clone, Debug)] +pub struct DynInputSlot { + scope: Entity, + source: Entity, + pub(super) type_id: TypeId, +} + +impl DynInputSlot { + pub(super) fn scope(&self) -> Entity { + self.scope + } + + pub(super) fn id(&self) -> Entity { + self.source + } +} + 
+impl From> for DynInputSlot { + fn from(input: InputSlot) -> Self { + Self { + scope: input.scope(), + source: input.id(), + type_id: TypeId::of::(), + } + } +} + +#[derive(Debug)] +/// A type erased [`crate::Output`] +pub struct DynOutput { + scope: Entity, + target: Entity, + pub(super) type_id: TypeId, +} + +impl DynOutput { + pub(super) fn into_output(self) -> Result, DiagramError> + where + T: Send + Sync + 'static + Any, + { + if self.type_id != TypeId::of::() { + Err(DiagramError::TypeMismatch) + } else { + Ok(Output::::new(self.scope, self.target)) + } + } + + pub(super) fn scope(&self) -> Entity { + self.scope + } + + pub(super) fn id(&self) -> Entity { + self.target + } +} + +impl From> for DynOutput +where + T: Send + Sync + 'static, +{ + fn from(output: Output) -> Self { + Self { + scope: output.scope(), + target: output.id(), + type_id: TypeId::of::(), + } + } +} + +#[derive(Clone, Serialize)] +pub(super) struct NodeMetadata { + pub(super) id: BuilderId, + pub(super) name: String, + pub(super) request: RequestMetadata, + pub(super) response: ResponseMetadata, + pub(super) config_schema: Schema, +} + +/// A type erased [`bevy_impulse::Node`] +pub(super) struct DynNode { + pub(super) input: DynInputSlot, + pub(super) output: DynOutput, +} + +impl DynNode { + fn new(output: Output, input: InputSlot) -> Self + where + Request: 'static, + Response: Send + Sync + 'static, + { + Self { + input: input.into(), + output: output.into(), + } + } +} + +impl From> for DynNode +where + Request: 'static, + Response: Send + Sync + 'static, + Streams: StreamPack, +{ + fn from(node: Node) -> Self { + Self { + input: node.input.into(), + output: node.output.into(), + } + } +} + +pub struct NodeRegistration { + pub(super) metadata: NodeMetadata, + + /// Creates an instance of the registered node. 
+ create_node_impl: CreateNodeFn, + + fork_clone_impl: Option, + + unzip_impl: + Option Result, DiagramError>>>, + + fork_result_impl: Option< + Box Result<(DynOutput, DynOutput), DiagramError>>, + >, + + split_impl: Option< + Box< + dyn for<'a> Fn( + &mut Builder, + DynOutput, + &'a SplitOp, + ) -> Result, DiagramError>, + >, + >, +} + +impl NodeRegistration { + fn new( + metadata: NodeMetadata, + create_node_impl: CreateNodeFn, + fork_clone_impl: Option, + ) -> NodeRegistration { + NodeRegistration { + metadata, + create_node_impl, + fork_clone_impl, + unzip_impl: None, + fork_result_impl: None, + split_impl: None, + } + } + + pub(super) fn create_node( + &self, + builder: &mut Builder, + config: serde_json::Value, + ) -> Result { + let n = (self.create_node_impl.borrow_mut())(builder, config)?; + debug!( + "created node of {}, output: {:?}, input: {:?}", + self.metadata.id, n.output, n.input + ); + Ok(n) + } + + pub(super) fn fork_clone( + &self, + builder: &mut Builder, + output: DynOutput, + amount: usize, + ) -> Result, DiagramError> { + let f = self + .fork_clone_impl + .as_ref() + .ok_or(DiagramError::NotCloneable)?; + f(builder, output, amount) + } + + pub(super) fn unzip( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError> { + let f = self + .unzip_impl + .as_ref() + .ok_or(DiagramError::NotUnzippable)?; + f(builder, output) + } + + pub(super) fn fork_result( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError> { + let f = self + .fork_result_impl + .as_ref() + .ok_or(DiagramError::CannotForkResult)?; + f(builder, output) + } + + pub(super) fn split<'a>( + &self, + builder: &mut Builder, + output: DynOutput, + split_op: &'a SplitOp, + ) -> Result, DiagramError> { + let f = self + .split_impl + .as_ref() + .ok_or(DiagramError::NotSplittable)?; + f(builder, output, split_op) + } +} + +type CreateNodeFn = + RefCell Result>>; +type ForkCloneFn = + Box Result, DiagramError>>; 
+ +pub struct CommonOperations<'a, Deserialize, Serialize, Cloneable> { + registry: &'a mut NodeRegistry, + _ignore: PhantomData<(Deserialize, Serialize, Cloneable)>, +} + +impl<'a, DeserializeImpl, SerializeImpl, Cloneable> + CommonOperations<'a, DeserializeImpl, SerializeImpl, Cloneable> +{ + /// Register a node builder with the specified common operations. + /// + /// # Arguments + /// + /// * `id` - Id of the builder, this must be unique. + /// * `name` - Friendly name for the builder, this is only used for display purposes. + /// * `f` - The node builder to register. + pub fn register_node_builder( + self, + options: NodeBuilderOptions, + mut f: impl FnMut(&mut Builder, Config) -> Node + 'static, + ) -> RegistrationBuilder<'a, Request, Response, Streams> + where + Config: JsonSchema + DeserializeOwned, + Request: Send + Sync + 'static, + Response: Send + Sync + 'static, + Streams: StreamPack, + DeserializeImpl: DeserializeMessage, + SerializeImpl: SerializeMessage, + Cloneable: DynForkClone, + { + register_deserialize::(&mut self.registry.data); + register_serialize::(&mut self.registry.data); + + let registration = NodeRegistration::new( + NodeMetadata { + id: options.id.clone(), + name: options.name.unwrap_or(options.id.clone()), + request: RequestMetadata { + schema: DeserializeImpl::json_schema(&mut self.registry.data.schema_generator) + .unwrap_or_else(|| { + self.registry.data.schema_generator.subschema_for::<()>() + }), + deserializable: DeserializeImpl::deserializable(), + }, + response: ResponseMetadata::new( + SerializeImpl::json_schema(&mut self.registry.data.schema_generator) + .unwrap_or_else(|| { + self.registry.data.schema_generator.subschema_for::<()>() + }), + SerializeImpl::serializable(), + Cloneable::CLONEABLE, + ), + config_schema: self + .registry + .data + .schema_generator + .subschema_for::(), + }, + RefCell::new(Box::new(move |builder, config| { + let config = serde_json::from_value(config)?; + let n = f(builder, config); + 
Ok(DynNode::new(n.output, n.input)) + })), + if Cloneable::CLONEABLE { + Some(Box::new(|builder, output, amount| { + Cloneable::dyn_fork_clone(builder, output, amount) + })) + } else { + None + }, + ); + self.registry.nodes.insert(options.id.clone(), registration); + + // SAFETY: We inserted an entry at this ID a moment ago + let node = self.registry.nodes.get_mut(&options.id).unwrap(); + + RegistrationBuilder:: { + node, + data: &mut self.registry.data, + _ignore: Default::default(), + } + } + + /// Opt out of deserializing the request of the node. Use this to build a + /// node whose request type is not deserializable. + pub fn no_request_deserializing( + self, + ) -> CommonOperations<'a, OpaqueMessageDeserializer, SerializeImpl, Cloneable> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } + + /// Opt out of serializing the response of the node. Use this to build a + /// node whose response type is not serializable. + pub fn no_response_serializing( + self, + ) -> CommonOperations<'a, DeserializeImpl, OpaqueMessageSerializer, Cloneable> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } + + /// Opt out of cloning the response of the node. Use this to build a node + /// whose response type is not cloneable. + pub fn no_response_cloning( + self, + ) -> CommonOperations<'a, DeserializeImpl, SerializeImpl, NotSupported> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } +} + +pub struct RegistrationBuilder<'a, Request, Response, Streams> { + node: &'a mut NodeRegistration, + data: &'a mut MessageRegistry, + _ignore: PhantomData<(Request, Response, Streams)>, +} + +impl<'a, Request, Response, Streams> RegistrationBuilder<'a, Request, Response, Streams> { + /// Mark the node as having a unzippable response. This is required in order for the node + /// to be able to be connected to a "Unzip" operation. 
+ pub fn with_unzip(self) -> Self + where + DefaultImpl: DynUnzip, + { + self.with_unzip_impl::() + } + + /// Mark the node as having an unzippable response whose elements are not serializable. + pub fn with_unzip_unserializable(self) -> Self + where + DefaultImpl: DynUnzip, + { + self.with_unzip_impl::() + } + + fn with_unzip_impl(self) -> Self + where + UnzipImpl: DynUnzip, + { + self.node.metadata.response.unzip_slots = UnzipImpl::UNZIP_SLOTS; + self.node.unzip_impl = if UnzipImpl::UNZIP_SLOTS > 0 { + Some(Box::new(|builder, output| { + UnzipImpl::dyn_unzip(builder, output) + })) + } else { + None + }; + + UnzipImpl::on_register(self.data); + + self + } + + /// Mark the node as having a [`Result<_, _>`] response. This is required in order for the node + /// to be able to be connected to a "Fork Result" operation. + pub fn with_fork_result(self) -> Self + where + DefaultImpl: DynForkResult, + { + self.node.metadata.response.fork_result = true; + self.node.fork_result_impl = Some(Box::new(|builder, output| { + >::dyn_fork_result(builder, output) + })); + + self + } + + /// Mark the node as having a splittable response. This is required in order + /// for the node to be able to be connected to a "Split" operation. + pub fn with_split(self) -> Self + where + DefaultImpl: DynSplit, + { + self.with_split_impl::() + } + + /// Mark the node as having a splittable response but the items from the split + /// are unserializable. 
+ pub fn with_split_unserializable(self) -> Self + where + DefaultImpl: DynSplit, + { + self.with_split_impl::() + } + + pub fn with_split_impl(self) -> Self + where + SplitImpl: DynSplit, + { + self.node.metadata.response.splittable = true; + self.node.split_impl = Some(Box::new(|builder, output, split_op| { + SplitImpl::dyn_split(builder, output, split_op) + })); + + SplitImpl::on_register(self.data); + + self + } +} + +pub trait IntoNodeRegistration { + fn into_node_registration( + self, + id: BuilderId, + name: String, + schema_generator: &mut SchemaGenerator, + ) -> NodeRegistration; +} + +pub struct NodeRegistry { + nodes: HashMap, + pub(super) data: MessageRegistry, +} + +pub struct MessageRegistry { + /// List of all request and response types used in all registered nodes, this only + /// contains serializable types, non serializable types are opaque and is only compatible + /// with itself. + schema_generator: SchemaGenerator, + + pub(super) deserialize_impls: HashMap< + TypeId, + Box) -> Result>, + >, + + pub(super) serialize_impls: HashMap< + TypeId, + Box Result, DiagramError>>, + >, + + pub(super) join_impls: HashMap< + TypeId, + Box) -> Result>, + >, +} + +impl Default for NodeRegistry { + fn default() -> Self { + let mut settings = SchemaSettings::default(); + settings.definitions_path = "#/types/".to_string(); + NodeRegistry { + nodes: Default::default(), + data: MessageRegistry { + schema_generator: SchemaGenerator::new(settings), + deserialize_impls: HashMap::new(), + serialize_impls: HashMap::new(), + join_impls: HashMap::new(), + }, + } + } +} + +impl NodeRegistry { + pub fn new() -> Self { + Self::default() + } + + /// Register a node builder with all the common operations (deserialize the + /// request, serialize the response, and clone the response) enabled. + /// + /// You will receive a [`RegistrationBuilder`] which you can then use to + /// enable more operations around your node, such as fork result, split, + /// or unzip. 
The data types of your node need to be suitable for those + /// operations or else the compiler will not allow you to enable them. + /// + /// ``` + /// use bevy_impulse::{NodeBuilderOptions, NodeRegistry}; + /// + /// let mut registry = NodeRegistry::default(); + /// registry.register_node_builder( + /// NodeBuilderOptions::new("echo".to_string()), + /// |builder, _config: ()| builder.create_map_block(|msg: String| msg) + /// ); + /// ``` + /// + /// # Arguments + /// + /// * `id` - Id of the builder, this must be unique. + /// * `name` - Friendly name for the builder, this is only used for display purposes. + /// * `f` - The node builder to register. + pub fn register_node_builder( + &mut self, + options: NodeBuilderOptions, + builder: impl FnMut(&mut Builder, Config) -> Node + 'static, + ) -> RegistrationBuilder + where + Config: JsonSchema + DeserializeOwned, + Request: Send + Sync + 'static + DynType + DeserializeOwned, + Response: Send + Sync + 'static + DynType + Serialize + Clone, + { + self.opt_out().register_node_builder(options, builder) + } + + /// In some cases the common operations of deserialization, serialization, + /// and cloning cannot be performed for the request or response of a node. + /// When that happens you can still register your node builder by calling + /// this function and explicitly disabling the common operations that your + /// node cannot support. + /// + /// + /// In order for the request to be deserializable, it must implement [`schemars::JsonSchema`] and [`serde::de::DeserializeOwned`]. + /// In order for the response to be serializable, it must implement [`schemars::JsonSchema`] and [`serde::Serialize`]. 
+ /// + /// ``` + /// use schemars::JsonSchema; + /// use serde::{Deserialize, Serialize}; + /// + /// #[derive(JsonSchema, Deserialize)] + /// struct DeserializableRequest {} + /// + /// #[derive(JsonSchema, Serialize)] + /// struct SerializableResponse {} + /// ``` + /// + /// If your node have a request or response that is not serializable, there is still + /// a way to register it. + /// + /// ``` + /// use bevy_impulse::{NodeBuilderOptions, NodeRegistry}; + /// + /// struct NonSerializable { + /// data: String + /// } + /// + /// let mut registry = NodeRegistry::default(); + /// registry + /// .opt_out() + /// .no_request_deserializing() + /// .no_response_serializing() + /// .no_response_cloning() + /// .register_node_builder( + /// NodeBuilderOptions::new("echo"), + /// |builder, _config: ()| { + /// builder.create_map_block(|msg: NonSerializable| msg) + /// } + /// ); + /// ``` + /// + /// Note that nodes registered without deserialization cannot be connected + /// to the workflow start, and nodes registered without serialization cannot + /// be connected to the workflow termination. + pub fn opt_out( + &mut self, + ) -> CommonOperations { + CommonOperations { + registry: self, + _ignore: Default::default(), + } + } + + pub(super) fn get_registration(&self, id: &Q) -> Result<&NodeRegistration, DiagramError> + where + Q: Borrow + ?Sized, + { + let k = id.borrow(); + self.nodes + .get(k) + .ok_or(DiagramError::BuilderNotFound(k.to_string())) + } +} + +impl Serialize for NodeRegistry { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_struct("NodeRegistry", 2)?; + // serialize only the nodes metadata + s.serialize_field( + "nodes", + // Since the serializer methods are consuming, we can't call `serialize_struct` and `collect_map`. + // This genius solution of creating an inline struct and impl `Serialize` on it is based on + // the code that `#[derive(Serialize)]` generates. 
+ { + struct SerializeWith<'a> { + value: &'a NodeRegistry, + } + impl<'a> Serialize for SerializeWith<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_map( + self.value + .nodes + .iter() + .map(|(k, v)| (k.clone(), &v.metadata)), + ) + } + } + &SerializeWith { value: self } + }, + )?; + s.serialize_field("types", self.data.schema_generator.definitions())?; + s.end() + } +} + +#[non_exhaustive] +pub struct NodeBuilderOptions { + pub id: BuilderId, + pub name: Option, +} + +impl NodeBuilderOptions { + pub fn new(id: impl ToString) -> Self { + Self { + id: id.to_string(), + name: None, + } + } + + pub fn with_name(mut self, name: impl ToString) -> Self { + self.name = Some(name.to_string()); + self + } +} + +#[cfg(test)] +mod tests { + use schemars::JsonSchema; + use serde::Deserialize; + + use super::*; + + fn multiply3(i: i64) -> i64 { + i * 3 + } + + #[test] + fn test_register_node_builder() { + let mut registry = NodeRegistry::default(); + registry + .opt_out() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), + |builder, _config: ()| builder.create_map_block(multiply3), + ); + let registration = registry.get_registration("multiply3_uncloneable").unwrap(); + assert!(registration.metadata.request.deserializable); + assert!(registration.metadata.response.serializable); + assert!(!registration.metadata.response.cloneable); + assert_eq!(registration.metadata.response.unzip_slots, 0); + } + + #[test] + fn test_register_cloneable_node() { + let mut registry = NodeRegistry::default(); + registry.register_node_builder( + NodeBuilderOptions::new("multiply3").with_name("Test Name"), + |builder, _config: ()| builder.create_map_block(multiply3), + ); + let registration = registry.get_registration("multiply3").unwrap(); + assert!(registration.metadata.request.deserializable); + assert!(registration.metadata.response.serializable); + 
assert!(registration.metadata.response.cloneable); + assert_eq!(registration.metadata.response.unzip_slots, 0); + } + + #[test] + fn test_register_unzippable_node() { + let mut registry = NodeRegistry::default(); + let tuple_resp = |_: ()| -> (i64,) { (1,) }; + registry + .opt_out() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(tuple_resp), + ) + .with_unzip(); + let registration = registry.get_registration("multiply3_uncloneable").unwrap(); + assert!(registration.metadata.request.deserializable); + assert!(registration.metadata.response.serializable); + assert!(!registration.metadata.response.cloneable); + assert_eq!(registration.metadata.response.unzip_slots, 1); + } + + #[test] + fn test_register_splittable_node() { + let mut registry = NodeRegistry::default(); + let vec_resp = |_: ()| -> Vec { vec![1, 2] }; + registry + .register_node_builder( + NodeBuilderOptions::new("vec_resp").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(vec_resp), + ) + .with_split(); + let registration = registry.get_registration("vec_resp").unwrap(); + assert!(registration.metadata.response.splittable); + + let map_resp = |_: ()| -> HashMap { HashMap::new() }; + registry + .register_node_builder( + NodeBuilderOptions::new("map_resp").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), + ) + .with_split(); + + let registration = registry.get_registration("map_resp").unwrap(); + assert!(registration.metadata.response.splittable); + + registry.register_node_builder( + NodeBuilderOptions::new("not_splittable").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), + ); + let registration = registry.get_registration("not_splittable").unwrap(); + assert!(!registration.metadata.response.splittable); + 
} + + #[test] + fn test_register_with_config() { + let mut registry = NodeRegistry::default(); + + #[derive(Deserialize, JsonSchema)] + struct TestConfig { + by: i64, + } + + registry.register_node_builder( + NodeBuilderOptions::new("multiply").with_name("Test Name"), + move |builder: &mut Builder, config: TestConfig| { + builder.create_map_block(move |operand: i64| operand * config.by) + }, + ); + assert!(registry.get_registration("multiply").is_ok()); + } + + struct NonSerializableRequest {} + + #[test] + fn test_register_opaque_node() { + let opaque_request_map = |_: NonSerializableRequest| {}; + + let mut registry = NodeRegistry::default(); + registry + .opt_out() + .no_request_deserializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_request_map").with_name("Test Name"), + move |builder, _config: ()| builder.create_map_block(opaque_request_map), + ); + assert!(registry.get_registration("opaque_request_map").is_ok()); + let registration = registry.get_registration("opaque_request_map").unwrap(); + assert!(!registration.metadata.request.deserializable); + assert!(registration.metadata.response.serializable); + assert!(!registration.metadata.response.cloneable); + assert_eq!(registration.metadata.response.unzip_slots, 0); + + let opaque_response_map = |_: ()| NonSerializableRequest {}; + registry + .opt_out() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_response_map").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| { + builder.create_map_block(opaque_response_map) + }, + ); + assert!(registry.get_registration("opaque_response_map").is_ok()); + let registration = registry.get_registration("opaque_response_map").unwrap(); + assert!(registration.metadata.request.deserializable); + assert!(!registration.metadata.response.serializable); + assert!(!registration.metadata.response.cloneable); + 
assert_eq!(registration.metadata.response.unzip_slots, 0); + + let opaque_req_resp_map = |_: NonSerializableRequest| NonSerializableRequest {}; + registry + .opt_out() + .no_request_deserializing() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_req_resp_map").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| { + builder.create_map_block(opaque_req_resp_map) + }, + ); + assert!(registry.get_registration("opaque_req_resp_map").is_ok()); + let registration = registry.get_registration("opaque_req_resp_map").unwrap(); + assert!(!registration.metadata.request.deserializable); + assert!(!registration.metadata.response.serializable); + assert!(!registration.metadata.response.cloneable); + assert_eq!(registration.metadata.response.unzip_slots, 0); + } +} diff --git a/src/diagram/serialization.rs b/src/diagram/serialization.rs new file mode 100644 index 00000000..a3ba4b22 --- /dev/null +++ b/src/diagram/serialization.rs @@ -0,0 +1,253 @@ +use std::any::TypeId; + +use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema}; +use serde::{de::DeserializeOwned, Serialize}; +use tracing::debug; + +use super::MessageRegistry; + +#[derive(thiserror::Error, Debug)] +pub enum SerializationError { + #[error("not supported")] + NotSupported, + + #[error(transparent)] + JsonError(#[from] serde_json::Error), +} + +pub trait DynType { + /// Returns the type name of the request. Note that the type name must be unique. + fn type_name() -> String; + + fn json_schema(gen: &mut SchemaGenerator) -> schemars::schema::Schema; +} + +impl DynType for T +where + T: JsonSchema, +{ + fn type_name() -> String { + ::schema_name() + } + + fn json_schema(gen: &mut SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::() + } +} + +#[derive(Clone, Debug, Serialize)] +pub struct RequestMetadata { + /// The JSON Schema of the request. 
+ pub(super) schema: Schema, + + /// Indicates if the request is deserializable. + pub(super) deserializable: bool, +} + +#[derive(Clone, Debug, Serialize)] +pub struct ResponseMetadata { + /// The JSON Schema of the response. + pub(super) schema: Schema, + + /// Indicates if the response is serializable. + pub(super) serializable: bool, + + /// Indicates if the response is cloneable, a node must have a cloneable response + /// in order to connect it to a "fork clone" operation. + pub(super) cloneable: bool, + + /// The number of unzip slots that a response have, a value of 0 means that the response + /// cannot be unzipped. This should be > 0 only if the response is a tuple. + pub(super) unzip_slots: usize, + + /// Indicates if the response can fork result + pub(super) fork_result: bool, + + /// Indiciates if the response can be split + pub(super) splittable: bool, +} + +impl ResponseMetadata { + pub(super) fn new(schema: Schema, serializable: bool, cloneable: bool) -> ResponseMetadata { + ResponseMetadata { + schema, + serializable, + cloneable, + unzip_slots: 0, + fork_result: false, + splittable: false, + } + } +} + +pub trait SerializeMessage { + fn type_name() -> String; + + fn json_schema(gen: &mut SchemaGenerator) -> Option; + + fn to_json(v: &T) -> Result; + + fn serializable() -> bool; +} + +#[derive(Default)] +pub struct DefaultSerializer; + +impl SerializeMessage for DefaultSerializer +where + T: Serialize + DynType, +{ + fn type_name() -> String { + T::type_name() + } + + fn json_schema(gen: &mut SchemaGenerator) -> Option { + Some(T::json_schema(gen)) + } + + fn to_json(v: &T) -> Result { + serde_json::to_value(v).map_err(|err| SerializationError::from(err)) + } + + fn serializable() -> bool { + true + } +} + +pub trait DeserializeMessage { + fn type_name() -> String; + + fn json_schema(gen: &mut SchemaGenerator) -> Option; + + fn from_json(json: serde_json::Value) -> Result; + + fn deserializable() -> bool; +} + +#[derive(Default)] +pub struct 
DefaultDeserializer; + +impl DeserializeMessage for DefaultDeserializer +where + T: DeserializeOwned + DynType, +{ + fn type_name() -> String { + T::type_name() + } + + fn json_schema(gen: &mut SchemaGenerator) -> Option { + Some(T::json_schema(gen)) + } + + fn from_json(json: serde_json::Value) -> Result { + serde_json::from_value::(json).map_err(|err| SerializationError::from(err)) + } + + fn deserializable() -> bool { + true + } +} + +#[derive(Default)] +pub struct OpaqueMessageSerializer; + +impl SerializeMessage for OpaqueMessageSerializer { + fn type_name() -> String { + std::any::type_name::().to_string() + } + + fn json_schema(_gen: &mut SchemaGenerator) -> Option { + None + } + + fn to_json(_v: &T) -> Result { + Err(SerializationError::NotSupported) + } + + fn serializable() -> bool { + false + } +} + +#[derive(Default)] +pub struct OpaqueMessageDeserializer; + +impl DeserializeMessage for OpaqueMessageDeserializer { + fn type_name() -> String { + std::any::type_name::().to_string() + } + + fn json_schema(_gen: &mut SchemaGenerator) -> Option { + None + } + + fn from_json(_json: serde_json::Value) -> Result { + Err(SerializationError::NotSupported) + } + + fn deserializable() -> bool { + false + } +} + +pub(super) fn register_deserialize(registry: &mut MessageRegistry) +where + Deserializer: DeserializeMessage, + T: Send + Sync + 'static, +{ + if registry.deserialize_impls.contains_key(&TypeId::of::()) + || !Deserializer::deserializable() + { + return; + } + + debug!( + "register deserialize for type: {}, with deserializer: {}", + std::any::type_name::(), + std::any::type_name::() + ); + registry.deserialize_impls.insert( + TypeId::of::(), + Box::new(|builder, output| { + debug!("deserialize output: {:?}", output); + let receiver = + builder.create_map_block(|json: serde_json::Value| Deserializer::from_json(json)); + builder.connect(output, receiver.input); + let deserialized_output = receiver + .output + .chain(builder) + .cancel_on_err() + .output() + 
.into(); + debug!("deserialized output: {:?}", deserialized_output); + Ok(deserialized_output) + }), + ); +} + +pub(super) fn register_serialize(registry: &mut MessageRegistry) +where + Serializer: SerializeMessage, + T: Send + Sync + 'static, +{ + if registry.serialize_impls.contains_key(&TypeId::of::()) || !Serializer::serializable() { + return; + } + + debug!( + "register serialize for type: {}, with serializer: {}", + std::any::type_name::(), + std::any::type_name::() + ); + registry.serialize_impls.insert( + TypeId::of::(), + Box::new(|builder, output| { + debug!("serialize output: {:?}", output); + let n = builder.create_map_block(|resp: T| Serializer::to_json(&resp)); + builder.connect(output.into_output()?, n.input); + let serialized_output = n.output.chain(builder).cancel_on_err().output(); + debug!("serialized output: {:?}", serialized_output); + Ok(serialized_output) + }), + ); +} diff --git a/src/diagram/split_serialized.rs b/src/diagram/split_serialized.rs new file mode 100644 index 00000000..585ffe16 --- /dev/null +++ b/src/diagram/split_serialized.rs @@ -0,0 +1,622 @@ +/* + * Copyright (C) 2024 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use std::{any::TypeId, collections::HashMap, usize}; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tracing::debug; + +use crate::{ + Builder, Chain, ForRemaining, FromSequential, FromSpecific, ListSplitKey, MapSplitKey, + OperationResult, SplitDispatcher, Splittable, +}; + +use super::{ + impls::{DefaultImpl, NotSupported}, + join::register_join_impl, + register_serialize, DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, +}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct SplitOp { + #[serde(default)] + pub(super) sequential: Vec, + + #[serde(default)] + pub(super) keyed: HashMap, + + pub(super) remaining: Option, +} + +impl Splittable for Value { + type Key = MapSplitKey; + type Item = (JsonPosition, Value); + + fn validate(_: &Self::Key) -> bool { + true + } + + fn next(key: &Option) -> Option { + MapSplitKey::next(key) + } + + fn split(self, mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Item>) -> OperationResult { + match self { + Value::Array(array) => { + for (index, value) in array.into_iter().enumerate() { + let position = JsonPosition::ArrayElement(index); + match dispatcher.outputs_for(&MapSplitKey::Sequential(index)) { + Some(outputs) => { + outputs.push((position, value)); + } + None => { + if let Some(outputs) = dispatcher.outputs_for(&MapSplitKey::Remaining) { + outputs.push((position, value)); + } + } + } + } + } + Value::Object(map) => { + let mut next_seq = 0; + for (name, value) in map.into_iter() { + let key = MapSplitKey::Specific(name); + match dispatcher.outputs_for(&key) { + Some(outputs) => { + // SAFETY: This key was initialized as MapSplitKey::Specific earlier + // in the function and is immutable, so this method is guaranteed to + // return `Some` + let position = JsonPosition::ObjectField(key.specific().unwrap()); + outputs.push((position, value)); + } + None => { + // No connection to the 
specific field name, so let's + // check for a sequential connection. + let seq = MapSplitKey::Sequential(next_seq); + next_seq += 1; + + // SAFETY: This key was initialized as MapSplitKey::Specific earlier + // in the function and is immutable, so this method is guaranteed to + // return `Some` + let position = JsonPosition::ObjectField(key.specific().unwrap()); + match dispatcher.outputs_for(&seq) { + Some(outputs) => outputs.push((position, value)), + None => { + // No connection to this point in the sequence + // so let's send it to any remaining connection. + let remaining = MapSplitKey::Remaining; + if let Some(outputs) = dispatcher.outputs_for(&remaining) { + outputs.push((position, value)); + } + } + } + } + } + } + } + singular => { + // This is a singular value, so it cannot be split. We will + // send it to the first sequential connection or else to the + // remaining connection. + let position = JsonPosition::Singular; + match dispatcher.outputs_for(&MapSplitKey::Sequential(0)) { + Some(outputs) => outputs.push((position, singular)), + None => { + let remaining = MapSplitKey::Remaining; + if let Some(outputs) = dispatcher.outputs_for(&remaining) { + outputs.push((position, singular)); + } + } + } + } + } + + Ok(()) + } +} + +/// Where was this positioned within the JSON structure. +#[derive( + Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema, +)] +pub enum JsonPosition { + /// This was the only item, e.g. the [`Value`] was a [`Null`][Value::Null], + /// [`Bool`][Value::Bool], [`Number`][Value::Number], or [`String`][Value::String]. + Singular, + /// The item came from an array. + ArrayElement(usize), + /// The item was a field of an object. 
+ ObjectField(String), +} + +impl FromSpecific for ListSplitKey { + type SpecificKey = String; + + fn from_specific(specific: Self::SpecificKey) -> Self { + match specific.parse::() { + Ok(seq) => Self::Sequential(seq), + Err(_) => Self::Remaining, + } + } +} + +#[derive(Debug)] +pub struct DynSplitOutputs<'a> { + pub(super) outputs: HashMap<&'a NextOperation, DynOutput>, + pub(super) remaining: DynOutput, +} + +pub(super) fn split_chain<'a, T>( + chain: Chain, + split_op: &'a SplitOp, +) -> Result, DiagramError> +where + T: Send + Sync + 'static + Splittable, + T::Key: FromSequential + FromSpecific + ForRemaining, +{ + debug!( + "split chain of type: {:?}, op: {:?}", + TypeId::of::(), + split_op + ); + + enum SeqOrKey<'inner> { + Seq(usize), + Key(&'inner String), + } + + chain.split(|mut sb| -> Result { + let outputs: HashMap<&NextOperation, DynOutput> = split_op + .sequential + .iter() + .enumerate() + .map(|(i, op_id)| (SeqOrKey::Seq(i), op_id)) + .chain( + split_op + .keyed + .iter() + .map(|(k, op_id)| (SeqOrKey::Key(k), op_id)), + ) + .map( + |(ki, op_id)| -> Result<(&NextOperation, DynOutput), DiagramError> { + match ki { + SeqOrKey::Seq(i) => Ok((op_id, sb.sequential_output(i)?.into())), + SeqOrKey::Key(k) => Ok((op_id, sb.specific_output(k.clone())?.into())), + } + }, + ) + .collect::, _>>()?; + let split_outputs = DynSplitOutputs { + outputs, + remaining: sb.remaining_output()?.into(), + }; + debug!("splitted outputs: {:?}", split_outputs); + Ok(split_outputs) + }) +} + +pub trait DynSplit { + const SUPPORTED: bool; + + fn dyn_split<'a>( + builder: &mut Builder, + output: DynOutput, + split_op: &'a SplitOp, + ) -> Result, DiagramError>; + + fn on_register(registry: &mut MessageRegistry); +} + +impl DynSplit for NotSupported { + const SUPPORTED: bool = false; + + fn dyn_split<'a>( + _builder: &mut Builder, + _output: DynOutput, + _split_op: &'a SplitOp, + ) -> Result, DiagramError> { + Err(DiagramError::NotSplittable) + } + + fn on_register(_registry: 
&mut MessageRegistry) {} +} + +impl DynSplit for DefaultImpl +where + T: Send + Sync + 'static + Splittable, + T::Key: FromSequential + FromSpecific + ForRemaining, + Serializer: SerializeMessage + SerializeMessage>, +{ + const SUPPORTED: bool = true; + + fn dyn_split<'a>( + builder: &mut Builder, + output: DynOutput, + split_op: &'a SplitOp, + ) -> Result, DiagramError> { + let chain = output.into_output::()?.chain(builder); + split_chain(chain, split_op) + } + + fn on_register(registry: &mut MessageRegistry) { + register_serialize::(registry); + register_join_impl::(registry); + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{testing::*, *}; + use diagram::testing::DiagramTestFixture; + use serde::{Deserialize, Serialize}; + use serde_json::json; + use test_log::test; + + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] + struct Person { + name: String, + age: i8, + } + + impl Person { + fn new(name: impl Into, age: i8) -> Self { + Self { + name: name.into(), + age, + } + } + } + + #[test] + fn test_json_value_split() { + let mut context = TestingContext::minimal_plugins(); + + let value = json!( + { + "foo": 10, + "bar": "hello", + "jobs": { + "engineer": { + "name": "Alice", + "age": 28, + }, + "designer": { + "name": "Bob", + "age": 67, + } + } + } + ); + + // Test multiple layers of splitting + let workflow = context.spawn_io_workflow(|scope, builder| { + scope.input.chain(builder).split(|split| { + split + // Get only the jobs data from the json + .specific_branch("jobs".to_owned(), |chain| { + chain.value().split(|jobs| { + jobs + // Grab the "first" job in the list, which should be + // alphabetical by default, so we should get the + // "designer" job. 
+ .next_branch(|_, person| { + person + .value() + .map_block(serde_json::from_value) + .cancel_on_err() + .connect(scope.terminate); + }) + .unwrap() + .unused(); + }); + }) + .unwrap() + .unused(); + }); + }); + + let mut promise = + context.command(|commands| commands.request(value, workflow).take_response()); + + context.run_with_conditions(&mut promise, 1); + assert!(context.no_unhandled_errors()); + + let result: Person = promise.take().available().unwrap(); + assert_eq!(result, Person::new("Bob", 67)); + + // Test serializing and splitting a tuple, then deserializing the split item + let workflow = context.spawn_io_workflow(|scope, builder| { + scope + .input + .chain(builder) + .map_block(serde_json::to_value) + .cancel_on_err() + .split(|split| { + split + // The second branch of our test input should have + // seralized Person data + .sequential_branch(1, |chain| { + chain + .value() + .map_block(serde_json::from_value) + .cancel_on_err() + .connect(scope.terminate); + }) + .unwrap() + .unused(); + }); + }); + + let mut promise = context.command(|commands| { + commands + .request((3.14159, Person::new("Charlie", 42)), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, 1); + assert!(context.no_unhandled_errors()); + + let result: Person = promise.take().available().unwrap(); + assert_eq!(result, Person::new("Charlie", 42)); + } + + #[test] + fn test_split_list() { + let mut fixture = DiagramTestFixture::new(); + + fn split_list(_: i64) -> Vec { + vec![1, 2, 3] + } + + fixture + .registry + .register_node_builder( + NodeBuilderOptions::new("split_list".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&split_list), + ) + .with_split(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "split_list", + "next": "split", + }, + "split": { + "type": "split", + "sequential": [{ "builtin": "terminate" }], + }, + }, + })) 
+ .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result[1], 1); + } + + #[test] + fn test_split_list_with_key() { + let mut fixture = DiagramTestFixture::new(); + + fn split_list(_: i64) -> Vec { + vec![1, 2, 3] + } + + fixture + .registry + .register_node_builder( + NodeBuilderOptions::new("split_list".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&split_list), + ) + .with_split(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "split_list", + "next": "split", + }, + "split": { + "type": "split", + "keyed": { "1": { "builtin": "terminate" } }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result[1], 2); + } + + #[test] + fn test_split_map() { + let mut fixture = DiagramTestFixture::new(); + + fn split_map(_: i64) -> HashMap { + HashMap::from([ + ("a".to_string(), 1), + ("b".to_string(), 2), + ("c".to_string(), 3), + ]) + } + + fixture + .registry + .register_node_builder( + NodeBuilderOptions::new("split_map".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&split_map), + ) + .with_split(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "split_map", + "next": "split", + }, + "split": { + "type": "split", + "keyed": { "b": { "builtin": "terminate" } }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result[1], 2); + } + + #[test] + fn test_split_dual_key_seq() { + let mut fixture = DiagramTestFixture::new(); + + fn split_map(_: i64) -> HashMap { + HashMap::from([("a".to_string(), 1), ("b".to_string(), 2)]) + } + + fixture + .registry + .register_node_builder( + 
NodeBuilderOptions::new("split_map".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&split_map), + ) + .with_split(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "split_map", + "next": "split", + }, + "split": { + "type": "split", + "keyed": { "a": { "builtin": "dispose" } }, + "sequential": [{ "builtin": "terminate" }], + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + // "a" is "eaten" up by the keyed path, so we should be the result of "b". + assert_eq!(result[1], 2); + } + + #[test] + fn test_split_remaining() { + let mut fixture = DiagramTestFixture::new(); + + fn split_list(_: i64) -> Vec { + vec![1, 2, 3] + } + + fixture + .registry + .register_node_builder( + NodeBuilderOptions::new("split_list".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&split_list), + ) + .with_split(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "split_list", + "next": "split", + }, + "split": { + "type": "split", + "sequential": [{ "builtin": "dispose" }], + "remaining": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result[1], 2); + } + + #[test] + fn test_split_start() { + let mut fixture = DiagramTestFixture::new(); + + fn get_split_value(pair: (JsonPosition, serde_json::Value)) -> serde_json::Value { + pair.1 + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("get_split_value".to_string()), + |builder, _config: ()| builder.create_map_block(get_split_value), + ); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "split", + "ops": { + "split": { + "type": "split", + "sequential": ["getSplitValue"], + 
}, + "getSplitValue": { + "type": "node", + "builder": "get_split_value", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run( + &diagram, + serde_json::to_value(HashMap::from([("test".to_string(), 1)])).unwrap(), + ) + .unwrap(); + assert_eq!(result, 1); + } +} diff --git a/src/diagram/testing.rs b/src/diagram/testing.rs new file mode 100644 index 00000000..fc0a553a --- /dev/null +++ b/src/diagram/testing.rs @@ -0,0 +1,137 @@ +use std::error::Error; + +use crate::{ + testing::TestingContext, Builder, RequestExt, RunCommandsOnWorldExt, Service, StreamPack, +}; + +use super::{ + Diagram, DiagramError, DiagramStart, DiagramTerminate, NodeBuilderOptions, NodeRegistry, +}; + +pub(super) struct DiagramTestFixture { + pub(super) context: TestingContext, + pub(super) registry: NodeRegistry, +} + +impl DiagramTestFixture { + pub(super) fn new() -> Self { + Self { + context: TestingContext::minimal_plugins(), + registry: new_registry_with_basic_nodes(), + } + } + + pub(super) fn spawn_workflow( + &mut self, + diagram: &Diagram, + ) -> Result, DiagramError> { + self.context + .app + .world + .command(|cmds| diagram.spawn_workflow(cmds, &self.registry)) + } + + /// Equivalent to `self.spawn_workflow::<()>(diagram)` + pub(super) fn spawn_io_workflow( + &mut self, + diagram: &Diagram, + ) -> Result, DiagramError> { + self.spawn_workflow::<()>(diagram) + } + + /// Spawns a workflow from a diagram then run the workflow until completion. + /// Returns the result of the workflow. 
+ pub(super) fn spawn_and_run( + &mut self, + diagram: &Diagram, + request: serde_json::Value, + ) -> Result> { + let workflow = self.spawn_workflow::<()>(diagram)?; + let mut promise = self + .context + .command(|cmds| cmds.request(request, workflow).take_response()); + self.context.run_while_pending(&mut promise); + let taken = promise.take(); + if taken.is_available() { + Ok(taken.available().unwrap()) + } else if taken.is_cancelled() { + let cancellation = taken.cancellation().unwrap(); + Err(cancellation.clone().into()) + } else { + Err(String::from("promise is in invalid state").into()) + } + } +} + +fn multiply3(i: i64) -> i64 { + i * 3 +} + +fn multiply3_5(x: i64) -> (i64, i64) { + (x * 3, x * 5) +} + +struct Unserializable; + +fn opaque(_: Unserializable) -> Unserializable { + Unserializable {} +} + +fn opaque_request(_: Unserializable) {} + +fn opaque_response(_: i64) -> Unserializable { + Unserializable {} +} + +/// create a new node registry with some basic nodes registered +fn new_registry_with_basic_nodes() -> NodeRegistry { + let mut registry = NodeRegistry::default(); + registry + .opt_out() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("multiply3_uncloneable"), + |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3), + ); + registry.register_node_builder( + NodeBuilderOptions::new("multiply3"), + |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3), + ); + registry + .register_node_builder( + NodeBuilderOptions::new("multiply3_5"), + |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3_5), + ) + .with_unzip(); + + registry.register_node_builder( + NodeBuilderOptions::new("multiplyBy"), + |builder: &mut Builder, config: i64| builder.create_map_block(move |a: i64| a * config), + ); + + registry + .opt_out() + .no_request_deserializing() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque"), + |builder: 
&mut Builder, _config: ()| builder.create_map_block(opaque), + ); + registry + .opt_out() + .no_request_deserializing() + .register_node_builder( + NodeBuilderOptions::new("opaque_request"), + |builder: &mut Builder, _config: ()| builder.create_map_block(opaque_request), + ); + registry + .opt_out() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_response"), + |builder: &mut Builder, _config: ()| builder.create_map_block(opaque_response), + ); + registry +} diff --git a/src/diagram/transform.rs b/src/diagram/transform.rs new file mode 100644 index 00000000..9ad17a67 --- /dev/null +++ b/src/diagram/transform.rs @@ -0,0 +1,206 @@ +use std::{any::TypeId, error::Error}; + +use cel_interpreter::{Context, ExecutionError, ParseError, Program}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tracing::debug; + +use crate::{Builder, Output}; + +use super::{DiagramError, DynOutput, NextOperation, NodeRegistry}; + +#[derive(Error, Debug)] +pub enum TransformError { + #[error(transparent)] + Parse(#[from] ParseError), + + #[error(transparent)] + Execution(#[from] ExecutionError), + + #[error(transparent)] + Other(#[from] Box), +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct TransformOp { + pub(super) cel: String, + pub(super) next: NextOperation, +} + +pub(super) fn transform_output( + builder: &mut Builder, + registry: &NodeRegistry, + output: DynOutput, + transform_op: &TransformOp, +) -> Result, DiagramError> { + debug!("transform output: {:?}, op: {:?}", output, transform_op); + + let json_output = if output.type_id == TypeId::of::() { + output.into_output() + } else { + let serialize = registry + .data + .serialize_impls + .get(&output.type_id) + .ok_or(DiagramError::NotSerializable)?; + serialize(builder, output) + }?; + + let program = Program::compile(&transform_op.cel).map_err(|err| 
TransformError::Parse(err))?; + let transform_node = builder.create_map_block( + move |req: serde_json::Value| -> Result { + let mut context = Context::default(); + context + .add_variable("request", req) + // cannot keep the original error because it is not Send + Sync + .map_err(|err| TransformError::Other(err.to_string().into()))?; + program + .execute(&context)? + .json() + // cel_interpreter::json is private so we have to type erase ConvertToJsonError + .map_err(|err| TransformError::Other(err.to_string().into())) + }, + ); + builder.connect(json_output, transform_node.input); + let transformed_output = transform_node + .output + .chain(builder) + .cancel_on_err() + .output(); + debug!("transformed output: {:?}", transformed_output); + Ok(transformed_output) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Diagram}; + + #[test] + fn test_transform_node_response() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "transform", + }, + "transform": { + "type": "transform", + "cel": "777", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 777); + } + + #[test] + fn test_transform_scope_start() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "transform", + "ops": { + "transform": { + "type": "transform", + "cel": "777", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 777); + } + + #[test] + fn test_cel_multiply() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = 
Diagram::from_json(json!({ + "version": "0.1.0", + "start": "transform", + "ops": { + "transform": { + "type": "transform", + "cel": "int(request) * 3", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 12); + } + + #[test] + fn test_cel_compose() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "transform", + "ops": { + "transform": { + "type": "transform", + "cel": "{ \"request\": request, \"seven\": 7 }", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result["request"], 4); + assert_eq!(result["seven"], 7); + } + + #[test] + fn test_cel_destructure() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "transform", + "ops": { + "transform": { + "type": "transform", + "cel": "request.age", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let request = json!({ + "name": "John", + "age": 40, + }); + + let result = fixture.spawn_and_run(&diagram, request).unwrap(); + assert_eq!(result, 40); + } +} diff --git a/src/diagram/unzip.rs b/src/diagram/unzip.rs new file mode 100644 index 00000000..648aca53 --- /dev/null +++ b/src/diagram/unzip.rs @@ -0,0 +1,253 @@ +use bevy_utils::all_tuples_with_size; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tracing::debug; + +use crate::Builder; + +use super::{ + impls::{DefaultImpl, NotSupported}, + join::register_join_impl, + register_serialize as register_serialize_impl, DiagramError, DynOutput, MessageRegistry, + NextOperation, SerializeMessage, +}; + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct UnzipOp { + pub(super) 
next: Vec, +} + +pub trait DynUnzip { + const UNZIP_SLOTS: usize; + + fn dyn_unzip(builder: &mut Builder, output: DynOutput) -> Result, DiagramError>; + + /// Called when a node is registered. + fn on_register(registry: &mut MessageRegistry); +} + +impl DynUnzip for NotSupported { + const UNZIP_SLOTS: usize = 0; + + fn dyn_unzip( + _builder: &mut Builder, + _output: DynOutput, + ) -> Result, DiagramError> { + Err(DiagramError::NotUnzippable) + } + + fn on_register(_registry: &mut MessageRegistry) {} +} + +macro_rules! dyn_unzip_impl { + ($len:literal, $(($P:ident, $o:ident)),*) => { + impl<$($P),*, Serializer> DynUnzip<($($P,)*), Serializer> for DefaultImpl + where + $($P: Send + Sync + 'static),*, + Serializer: $(SerializeMessage<$P> +)* $(SerializeMessage> +)*, + { + const UNZIP_SLOTS: usize = $len; + + fn dyn_unzip( + builder: &mut Builder, + output: DynOutput + ) -> Result, DiagramError> { + debug!("unzip output: {:?}", output); + let mut outputs: Vec = Vec::with_capacity($len); + let chain = output.into_output::<($($P,)*)>()?.chain(builder); + let ($($o,)*) = chain.unzip(); + + $({ + outputs.push($o.into()); + })* + + debug!("unzipped outputs: {:?}", outputs); + Ok(outputs) + } + + fn on_register(registry: &mut MessageRegistry) + { + // Register serialize functions for all items in the tuple. + // For a tuple of (T1, T2, T3), registers serialize for T1, T2 and T3. + $( + register_serialize_impl::<$P, Serializer>(registry); + )* + + // Register join impls for T1, T2, T3... 
+ $( + register_join_impl::<$P, Serializer>(registry); + )* + } + } + }; +} + +all_tuples_with_size!(dyn_unzip_impl, 1, 12, R, o); + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Diagram, DiagramError}; + + #[test] + fn test_unzip_not_unzippable() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "unzip" + }, + "unzip": { + "type": "unzip", + "next": [{ "builtin": "terminate" }], + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::NotUnzippable), "{}", err); + } + + #[test] + fn test_unzip_to_too_many_slots() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "unzip" + }, + "unzip": { + "type": "unzip", + "next": ["op2", "op3", "op4"], + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + "op3": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + "op4": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err, DiagramError::NotUnzippable)); + } + + #[test] + fn test_unzip_to_terminate() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "unzip" + }, + "unzip": { + "type": "unzip", + "next": [{ "builtin": "dispose" }, { "builtin": "terminate" }], + }, + }, + 
})) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 20); + } + + #[test] + fn test_unzip() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "unzip", + }, + "unzip": { + "type": "unzip", + "next": ["op2"], + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 36); + } + + #[test] + fn test_unzip_with_dispose() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "unzip", + }, + "unzip": { + "type": "unzip", + "next": ["dispose", "op2"], + }, + "dispose": { + "type": "dispose", + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert_eq!(result, 60); + } +} diff --git a/src/diagram/workflow_builder.rs b/src/diagram/workflow_builder.rs new file mode 100644 index 00000000..777853cd --- /dev/null +++ b/src/diagram/workflow_builder.rs @@ -0,0 +1,568 @@ +use std::{any::TypeId, collections::HashMap}; + +use tracing::{debug, warn}; + +use crate::{ + diagram::join::serialize_and_join, unknown_diagram_error, Builder, InputSlot, Output, + StreamPack, +}; + +use super::{ + fork_clone::DynForkClone, impls::DefaultImpl, split_chain, transform::transform_output, + BuiltinTarget, Diagram, DiagramError, DiagramOperation, DiagramScope, DynInputSlot, DynOutput, + NextOperation, NodeOp, NodeRegistry, OperationId, 
SourceOperation, +}; + +struct Vertex<'a> { + op_id: &'a OperationId, + op: &'a DiagramOperation, + in_edges: Vec, + out_edges: Vec, +} + +struct Edge<'a> { + source: SourceOperation, + target: &'a NextOperation, + state: EdgeState<'a>, +} + +enum EdgeState<'a> { + Ready { + output: DynOutput, + /// The node that initially produces the output, may be `None` if there is no origin. + /// e.g. The entry point, or if the output passes through a `join` operation which + /// results in multiple origins. + origin: Option<&'a NodeOp>, + }, + Pending, +} + +pub(super) fn create_workflow<'a, Streams: StreamPack>( + scope: DiagramScope, + builder: &mut Builder, + registry: &NodeRegistry, + diagram: &'a Diagram, +) -> Result<(), DiagramError> { + // first create all the vertices + let mut vertices: HashMap<&OperationId, Vertex> = diagram + .ops + .iter() + .map(|(op_id, op)| { + ( + op_id, + Vertex { + op_id, + op, + in_edges: Vec::new(), + out_edges: Vec::new(), + }, + ) + }) + .collect(); + + // init with some capacity to reduce resizing. HashMap for faster removal. + // NOTE: There are many `unknown_diagram_errors!()` used when accessing this. + // In theory these accesses should never fail because the keys come from + // `vertices` which are built using the same data as `edges`. But we do modify + // `edges` while we are building the workflow so if an unknown error occurs, it is + // likely due to some logic issue in the algorithm. + let mut edges: HashMap = HashMap::with_capacity(diagram.ops.len() * 2); + + // process start separately because we need to consume the scope input + match &diagram.start { + NextOperation::Builtin { builtin } => match builtin { + BuiltinTarget::Terminate => { + // such a workflow is equivalent to an no-op. + builder.connect(scope.input, scope.terminate); + return Ok(()); + } + BuiltinTarget::Dispose => { + // bevy_impulse will immediate stop with an `CancellationCause::Unreachable` error + // if trying to run such a workflow. 
+ return Ok(()); + } + }, + NextOperation::Target(op_id) => { + edges.insert( + edges.len(), + Edge { + source: SourceOperation::Builtin { + builtin: super::BuiltinSource::Start, + }, + target: &diagram.start, + state: EdgeState::Ready { + output: scope.input.into(), + origin: None, + }, + }, + ); + vertices + .get_mut(&op_id) + .ok_or(DiagramError::OperationNotFound(op_id.clone()))? + .in_edges + .push(0); + } + }; + + let mut inputs: HashMap<&OperationId, DynInputSlot> = HashMap::with_capacity(diagram.ops.len()); + + let mut terminate_edges: Vec = Vec::new(); + + let mut add_edge = |source: SourceOperation, + target: &'a NextOperation, + state: EdgeState<'a>| + -> Result<(), DiagramError> { + let source_id = if let SourceOperation::Source(source) = &source { + Some(source.clone()) + } else { + None + }; + + edges.insert( + edges.len(), + Edge { + source, + target, + state, + }, + ); + let new_edge_id = edges.len() - 1; + + if let Some(source_id) = source_id { + let source_vertex = vertices + .get_mut(&source_id) + .ok_or_else(|| DiagramError::OperationNotFound(source_id.clone()))?; + source_vertex.out_edges.push(new_edge_id); + } + + match target { + NextOperation::Target(target) => { + let target_vertex = vertices + .get_mut(target) + .ok_or_else(|| DiagramError::OperationNotFound(target.clone()))?; + target_vertex.in_edges.push(new_edge_id); + } + NextOperation::Builtin { builtin } => match builtin { + BuiltinTarget::Terminate => { + terminate_edges.push(new_edge_id); + } + BuiltinTarget::Dispose => {} + }, + } + Ok(()) + }; + + // create all the edges + for (op_id, op) in &diagram.ops { + match op { + DiagramOperation::Node(node_op) => { + let reg = registry.get_registration(&node_op.builder)?; + let n = reg.create_node(builder, node_op.config.clone())?; + inputs.insert(op_id, n.input); + add_edge( + op_id.clone().into(), + &node_op.next, + EdgeState::Ready { + output: n.output.into(), + origin: Some(node_op), + }, + )?; + } + 
DiagramOperation::ForkClone(fork_clone_op) => { + for next_op_id in fork_clone_op.next.iter() { + add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; + } + } + DiagramOperation::Unzip(unzip_op) => { + for next_op_id in unzip_op.next.iter() { + add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; + } + } + DiagramOperation::ForkResult(fork_result_op) => { + add_edge(op_id.clone().into(), &fork_result_op.ok, EdgeState::Pending)?; + add_edge( + op_id.clone().into(), + &fork_result_op.err, + EdgeState::Pending, + )?; + } + DiagramOperation::Split(split_op) => { + let next_op_ids: Vec<&NextOperation> = split_op + .sequential + .iter() + .chain(split_op.keyed.values()) + .collect(); + for next_op_id in next_op_ids { + add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; + } + if let Some(remaining) = &split_op.remaining { + add_edge(op_id.clone().into(), &remaining, EdgeState::Pending)?; + } + } + DiagramOperation::Join(join_op) => { + add_edge(op_id.clone().into(), &join_op.next, EdgeState::Pending)?; + } + DiagramOperation::Transform(transform_op) => { + add_edge(op_id.clone().into(), &transform_op.next, EdgeState::Pending)?; + } + DiagramOperation::Dispose => {} + } + } + + let mut unconnected_vertices: Vec<&Vertex> = vertices.values().collect(); + while unconnected_vertices.len() > 0 { + let ws = unconnected_vertices.clone(); + let ws_length = ws.len(); + unconnected_vertices.clear(); + + for v in ws { + let in_edges: Vec<&Edge> = v.in_edges.iter().map(|idx| &edges[idx]).collect(); + if in_edges + .iter() + .any(|e| matches!(e.state, EdgeState::Pending)) + { + // not all inputs are ready + debug!( + "defer connecting [{}] until all incoming edges are ready", + v.op_id + ); + unconnected_vertices.push(v); + continue; + } + + connect_vertex(builder, registry, &mut edges, &inputs, v)?; + } + + // can't connect anything and there are still remaining vertices + if unconnected_vertices.len() > 0 && ws_length == 
unconnected_vertices.len() { + warn!( + "the following operations are not connected {:?}", + unconnected_vertices + .iter() + .map(|v| v.op_id) + .collect::>() + ); + return Err(DiagramError::BadInterconnectChain); + } + } + + // connect terminate + for edge_id in terminate_edges { + let edge = edges.remove(&edge_id).ok_or(unknown_diagram_error!())?; + match edge.state { + EdgeState::Ready { output, origin } => { + let serialized_output = serialize(builder, registry, output, origin)?; + builder.connect(serialized_output, scope.terminate); + } + EdgeState::Pending => return Err(DiagramError::BadInterconnectChain), + } + } + + Ok(()) +} + +fn connect_vertex<'a>( + builder: &mut Builder, + registry: &NodeRegistry, + edges: &mut HashMap>, + inputs: &HashMap<&OperationId, DynInputSlot>, + target: &'a Vertex, +) -> Result<(), DiagramError> { + debug!("connecting [{}]", target.op_id); + match target.op { + // join needs all incoming edges to be connected at once so it is done at the vertex level + // instead of per edge level. + DiagramOperation::Join(join_op) => { + if target.in_edges.is_empty() { + return Err(DiagramError::EmptyJoin); + } + let mut outputs: HashMap = target + .in_edges + .iter() + .map(|e| { + let edge = edges.remove(e).ok_or(unknown_diagram_error!())?; + match edge.state { + EdgeState::Ready { output, origin: _ } => Ok((edge.source, output)), + // "expected all incoming edges to be ready" + _ => Err(unknown_diagram_error!()), + } + }) + .collect::, _>>()?; + + let mut ordered_outputs: Vec = Vec::with_capacity(target.in_edges.len()); + for source_id in join_op.inputs.iter() { + let o = outputs + .remove(source_id) + .ok_or(DiagramError::OperationNotFound(source_id.to_string()))?; + ordered_outputs.push(o); + } + + let joined_output = if join_op.no_serialize.unwrap_or(false) { + let join_impl = ®istry.data.join_impls[&ordered_outputs[0].type_id]; + join_impl(builder, ordered_outputs)? 
+ } else { + serialize_and_join(builder, ®istry.data, ordered_outputs)?.into() + }; + + let out_edge = edges + .get_mut(&target.out_edges[0]) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { + output: joined_output, + origin: None, + }; + Ok(()) + } + // for other operations, each edge is independent, so we can connect at the edge level. + _ => { + for edge_id in target.in_edges.iter() { + connect_edge(builder, registry, edges, inputs, *edge_id, target)?; + } + Ok(()) + } + } +} + +fn connect_edge<'a>( + builder: &mut Builder, + registry: &NodeRegistry, + edges: &mut HashMap>, + inputs: &HashMap<&OperationId, DynInputSlot>, + edge_id: usize, + target: &Vertex, +) -> Result<(), DiagramError> { + let edge = edges.remove(&edge_id).ok_or(unknown_diagram_error!())?; + debug!( + "connect edge {:?}, source: {:?}, target: {:?}", + edge_id, edge.source, edge.target + ); + let (output, origin) = match edge.state { + EdgeState::Ready { + output, + origin: origin_node, + } => { + if let Some(origin_node) = origin_node { + (output, Some(origin_node)) + } else { + (output, None) + } + } + EdgeState::Pending => panic!("can only connect ready edges"), + }; + + match target.op { + DiagramOperation::Node(_) => { + let input = inputs[target.op_id]; + let deserialized_output = + deserialize(builder, registry, output, target, input.type_id)?; + dyn_connect(builder, deserialized_output, input)?; + } + DiagramOperation::ForkClone(fork_clone_op) => { + let amount = fork_clone_op.next.len(); + let outputs = if output.type_id == TypeId::of::() { + >::dyn_fork_clone( + builder, output, amount, + ) + } else { + let origin = if let Some(origin_node) = origin { + origin_node + } else { + return Err(DiagramError::NotCloneable); + }; + + let reg = registry.get_registration(&origin.builder)?; + reg.fork_clone(builder, output, amount) + }?; + for (o, e) in outputs.into_iter().zip(target.out_edges.iter()) { + let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; 
+ out_edge.state = EdgeState::Ready { output: o, origin }; + } + } + DiagramOperation::Unzip(unzip_op) => { + let outputs = if output.type_id == TypeId::of::() { + Err(DiagramError::NotUnzippable) + } else { + let origin = if let Some(origin_node) = origin { + origin_node + } else { + return Err(DiagramError::NotUnzippable); + }; + + let reg = registry.get_registration(&origin.builder)?; + reg.unzip(builder, output) + }?; + if outputs.len() < unzip_op.next.len() { + return Err(DiagramError::NotUnzippable); + } + for (o, e) in outputs.into_iter().zip(target.out_edges.iter()) { + let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { output: o, origin }; + } + } + DiagramOperation::ForkResult(_) => { + let (ok, err) = if output.type_id == TypeId::of::() { + Err(DiagramError::CannotForkResult) + } else { + let origin = if let Some(origin_node) = origin { + origin_node + } else { + return Err(DiagramError::CannotForkResult); + }; + + let reg = registry.get_registration(&origin.builder)?; + reg.fork_result(builder, output) + }?; + { + let out_edge = edges + .get_mut(&target.out_edges[0]) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { output: ok, origin }; + } + { + let out_edge = edges + .get_mut(&target.out_edges[1]) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { + output: err, + origin, + }; + } + } + DiagramOperation::Split(split_op) => { + let mut outputs = if output.type_id == TypeId::of::() { + let chain = output.into_output::()?.chain(builder); + split_chain(chain, split_op) + } else { + let origin = if let Some(origin_node) = origin { + origin_node + } else { + return Err(DiagramError::NotSplittable); + }; + + let reg = registry.get_registration(&origin.builder)?; + reg.split(builder, output, split_op) + }?; + + // Because of how we build `out_edges`, if the split op uses the `remaining` slot, + // then the last item will always be the remaining edge. 
+ let remaining_edge_id = if split_op.remaining.is_some() { + Some(target.out_edges.last().ok_or(unknown_diagram_error!())?) + } else { + None + }; + let other_edge_ids = if split_op.remaining.is_some() { + &target.out_edges[..(target.out_edges.len() - 1)] + } else { + &target.out_edges[..] + }; + + for e in other_edge_ids { + let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; + let output = outputs + .outputs + .remove(out_edge.target) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { output, origin }; + } + if let Some(remaining_edge_id) = remaining_edge_id { + let out_edge = edges + .get_mut(remaining_edge_id) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { + output: outputs.remaining, + origin, + }; + } + } + DiagramOperation::Join(_) => { + // join is connected at the vertex level + } + DiagramOperation::Transform(transform_op) => { + let transformed_output = transform_output(builder, registry, output, transform_op)?; + let out_edge = edges + .get_mut(&target.out_edges[0]) + .ok_or(unknown_diagram_error!())?; + out_edge.state = EdgeState::Ready { + output: transformed_output.into(), + origin, + } + } + DiagramOperation::Dispose => {} + } + Ok(()) +} + +/// Connect a [`DynOutput`] to a [`DynInputSlot`]. Use this only when both the output and input +/// are type erased. To connect an [`Output`] to a [`DynInputSlot`] or vice versa, prefer converting +/// the type erased output/input slot to the typed equivalent. 
+/// +/// ```text +/// builder.connect(output.into_output::()?, dyn_input)?; +/// ``` +fn dyn_connect( + builder: &mut Builder, + output: DynOutput, + input: DynInputSlot, +) -> Result<(), DiagramError> { + if output.type_id != input.type_id { + return Err(DiagramError::TypeMismatch); + } + struct TypeErased {} + let typed_output = Output::::new(output.scope(), output.id()); + let typed_input = InputSlot::::new(input.scope(), input.id()); + builder.connect(typed_output, typed_input); + Ok(()) +} + +/// Try to deserialize `output` into `input_type`. If `output` is not `serde_json::Value`, this does nothing. +fn deserialize( + builder: &mut Builder, + registry: &NodeRegistry, + output: DynOutput, + target: &Vertex, + input_type: TypeId, +) -> Result { + if output.type_id != TypeId::of::() || output.type_id == input_type { + Ok(output) + } else { + let serialized = output.into_output::()?; + match target.op { + DiagramOperation::Node(node_op) => { + let reg = registry.get_registration(&node_op.builder)?; + if reg.metadata.request.deserializable { + let deserialize_impl = ®istry.data.deserialize_impls[&input_type]; + deserialize_impl(builder, serialized) + } else { + Err(DiagramError::NotSerializable) + } + } + _ => Err(DiagramError::NotSerializable), + } + } +} + +fn serialize( + builder: &mut Builder, + registry: &NodeRegistry, + output: DynOutput, + origin: Option<&NodeOp>, +) -> Result, DiagramError> { + if output.type_id == TypeId::of::() { + output.into_output() + } else { + // Cannot serialize if we don't know the origin, as we need it to know which serialize impl to use. 
+ let origin = if let Some(origin) = origin { + origin + } else { + return Err(DiagramError::NotSerializable); + }; + + let reg = registry.get_registration(&origin.builder)?; + if reg.metadata.response.serializable { + let serialize_impl = ®istry.data.serialize_impls[&output.type_id]; + serialize_impl(builder, output) + } else { + Err(DiagramError::NotSerializable) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index f1e30384..d7e27ab8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,6 +87,11 @@ pub use chain::*; pub mod channel; pub use channel::*; +#[cfg(feature = "diagram")] +pub mod diagram; +#[cfg(feature = "diagram")] +pub use diagram::*; + pub mod disposal; pub use disposal::*; From 9175367e4ae9a940fc670fe7960a7338af0ca8f3 Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Mon, 20 Jan 2025 22:03:07 +0800 Subject: [PATCH 10/20] Message registry (#50) Signed-off-by: Teo Koon Peng --- examples/diagram/calculator/src/main.rs | 6 +- src/diagram.rs | 31 +- src/diagram/impls.rs | 20 + src/diagram/join.rs | 27 +- src/diagram/node_registry.rs | 849 -------------- src/diagram/registration.rs | 1419 +++++++++++++++++++++++ src/diagram/serialization.rs | 112 -- src/diagram/split_serialized.rs | 4 +- src/diagram/testing.rs | 23 +- src/diagram/transform.rs | 13 +- src/diagram/unzip.rs | 52 +- src/diagram/workflow_builder.rs | 114 +- 12 files changed, 1537 insertions(+), 1133 deletions(-) delete mode 100644 src/diagram/node_registry.rs create mode 100644 src/diagram/registration.rs diff --git a/examples/diagram/calculator/src/main.rs b/examples/diagram/calculator/src/main.rs index 6214ab2b..2321f0cf 100644 --- a/examples/diagram/calculator/src/main.rs +++ b/examples/diagram/calculator/src/main.rs @@ -1,8 +1,8 @@ use std::{error::Error, fs::File, str::FromStr}; use bevy_impulse::{ - Diagram, DiagramError, ImpulsePlugin, NodeBuilderOptions, NodeRegistry, Promise, RequestExt, - RunCommandsOnWorldExt, + Diagram, DiagramElementRegistry, DiagramError, ImpulsePlugin, 
NodeBuilderOptions, Promise, + RequestExt, RunCommandsOnWorldExt, }; use clap::Parser; @@ -21,7 +21,7 @@ fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); - let mut registry = NodeRegistry::default(); + let mut registry = DiagramElementRegistry::new(); registry.register_node_builder( NodeBuilderOptions::new("add").with_name("Add"), |builder, config: f64| builder.create_map_block(move |req: f64| req + config), diff --git a/src/diagram.rs b/src/diagram.rs index 7efea31b..f6b70203 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -2,7 +2,7 @@ mod fork_clone; mod fork_result; mod impls; mod join; -mod node_registry; +mod registration; mod serialization; mod split_serialized; mod transform; @@ -14,7 +14,7 @@ use fork_clone::ForkCloneOp; use fork_result::ForkResultOp; use join::JoinOp; pub use join::JoinOutput; -pub use node_registry::*; +pub use registration::*; pub use serialization::*; pub use split_serialized::*; use tracing::debug; @@ -380,10 +380,10 @@ impl Diagram { /// # Examples /// /// ``` - /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, NodeRegistry, RunCommandsOnWorldExt}; + /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; /// /// let mut app = bevy_app::App::new(); - /// let mut registry = NodeRegistry::default(); + /// let mut registry = DiagramElementRegistry::new(); /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { /// builder.create_map_block(|msg: String| msg) /// }); @@ -411,7 +411,7 @@ impl Diagram { fn spawn_workflow( &self, cmds: &mut Commands, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, ) -> Result, DiagramError> where Streams: StreamPack, @@ -451,7 +451,7 @@ impl Diagram { pub fn spawn_io_workflow( &self, cmds: &mut Commands, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, ) -> Result, DiagramError> { self.spawn_workflow::<()>(cmds, registry) } @@ -506,6 
+506,9 @@ pub enum DiagramError { #[error("response cannot be split")] NotSplittable, + #[error("responses cannot be joined")] + NotJoinable, + #[error("empty join is not allowed")] EmptyJoin, @@ -557,7 +560,7 @@ mod tests { "ops": { "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "dispose" }, }, }, @@ -625,7 +628,7 @@ mod tests { "ops": { "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "op2", }, "op2": { @@ -651,12 +654,12 @@ mod tests { "ops": { "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "op2", }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "op1", }, }, @@ -691,7 +694,7 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, }, @@ -728,11 +731,11 @@ mod tests { let json_str = r#" { "version": "0.1.0", - "start": "multiply3_uncloneable", + "start": "multiply3", "ops": { - "multiply3_uncloneable": { + "multiply3": { "type": "node", - "builder": "multiplyBy", + "builder": "multiply_by", "config": 7, "next": { "builtin": "terminate" } } diff --git a/src/diagram/impls.rs b/src/diagram/impls.rs index b2116505..9412fbb5 100644 --- a/src/diagram/impls.rs +++ b/src/diagram/impls.rs @@ -1,5 +1,25 @@ +use std::marker::PhantomData; + /// A struct to provide the default implementation for various operations. pub struct DefaultImpl; +/// A struct to provide the default implementation for various operations. +pub struct DefaultImplMarker { + _unused: PhantomData, +} + +impl DefaultImplMarker { + pub(super) fn new() -> Self { + Self { + _unused: Default::default(), + } + } +} + /// A struct to provide "not supported" implementations for various operations. pub struct NotSupported; + +/// A struct to provide "not supported" implementations for various operations. 
+pub struct NotSupportedMarker { + _unused: PhantomData, +} diff --git a/src/diagram/join.rs b/src/diagram/join.rs index da5fbe75..cdba36be 100644 --- a/src/diagram/join.rs +++ b/src/diagram/join.rs @@ -1,5 +1,3 @@ -use std::any::TypeId; - use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use smallvec::SmallVec; @@ -30,13 +28,7 @@ where T: Send + Sync + 'static, Serializer: SerializeMessage>, { - if registry.join_impls.contains_key(&TypeId::of::()) { - return; - } - - registry - .join_impls - .insert(TypeId::of::(), Box::new(join_impl::)); + registry.register_join::(Box::new(join_impl::)); } /// Serialize the outputs before joining them, and convert the resulting joined output into a @@ -55,14 +47,7 @@ pub(super) fn serialize_and_join( let outputs = outputs .into_iter() - .map(|o| { - let serialize_impl = registry - .serialize_impls - .get(&o.type_id) - .ok_or(DiagramError::NotSerializable)?; - let serialized_output = serialize_impl(builder, o)?; - Ok(serialized_output) - }) + .map(|o| registry.serialize(builder, o)) .collect::, DiagramError>>()?; // we need to convert the joined output to [`serde_json::Value`] in order for it to be @@ -163,7 +148,7 @@ mod tests { }, "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "join", }, "get_split_value2": { @@ -173,7 +158,7 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "join", }, "join": { @@ -235,7 +220,7 @@ mod tests { }, "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, "get_split_value2": { @@ -245,7 +230,7 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, "join": { diff --git a/src/diagram/node_registry.rs b/src/diagram/node_registry.rs deleted file mode 100644 index 09a832d1..00000000 --- a/src/diagram/node_registry.rs 
+++ /dev/null @@ -1,849 +0,0 @@ -use std::{ - any::{Any, TypeId}, - borrow::Borrow, - cell::RefCell, - collections::HashMap, - fmt::Debug, - marker::PhantomData, -}; - -use crate::{Builder, InputSlot, Node, Output, StreamPack}; -use bevy_ecs::entity::Entity; -use schemars::{ - gen::{SchemaGenerator, SchemaSettings}, - schema::Schema, - JsonSchema, -}; -use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize}; -use tracing::debug; - -use crate::{RequestMetadata, SerializeMessage}; - -use super::{ - fork_clone::DynForkClone, - fork_result::DynForkResult, - impls::{DefaultImpl, NotSupported}, - register_deserialize, register_serialize, - unzip::DynUnzip, - BuilderId, DefaultDeserializer, DefaultSerializer, DeserializeMessage, DiagramError, DynSplit, - DynSplitOutputs, DynType, OpaqueMessageDeserializer, OpaqueMessageSerializer, ResponseMetadata, - SplitOp, -}; - -/// A type erased [`crate::InputSlot`] -#[derive(Copy, Clone, Debug)] -pub struct DynInputSlot { - scope: Entity, - source: Entity, - pub(super) type_id: TypeId, -} - -impl DynInputSlot { - pub(super) fn scope(&self) -> Entity { - self.scope - } - - pub(super) fn id(&self) -> Entity { - self.source - } -} - -impl From> for DynInputSlot { - fn from(input: InputSlot) -> Self { - Self { - scope: input.scope(), - source: input.id(), - type_id: TypeId::of::(), - } - } -} - -#[derive(Debug)] -/// A type erased [`crate::Output`] -pub struct DynOutput { - scope: Entity, - target: Entity, - pub(super) type_id: TypeId, -} - -impl DynOutput { - pub(super) fn into_output(self) -> Result, DiagramError> - where - T: Send + Sync + 'static + Any, - { - if self.type_id != TypeId::of::() { - Err(DiagramError::TypeMismatch) - } else { - Ok(Output::::new(self.scope, self.target)) - } - } - - pub(super) fn scope(&self) -> Entity { - self.scope - } - - pub(super) fn id(&self) -> Entity { - self.target - } -} - -impl From> for DynOutput -where - T: Send + Sync + 'static, -{ - fn from(output: Output) -> Self { - Self { - 
scope: output.scope(), - target: output.id(), - type_id: TypeId::of::(), - } - } -} - -#[derive(Clone, Serialize)] -pub(super) struct NodeMetadata { - pub(super) id: BuilderId, - pub(super) name: String, - pub(super) request: RequestMetadata, - pub(super) response: ResponseMetadata, - pub(super) config_schema: Schema, -} - -/// A type erased [`bevy_impulse::Node`] -pub(super) struct DynNode { - pub(super) input: DynInputSlot, - pub(super) output: DynOutput, -} - -impl DynNode { - fn new(output: Output, input: InputSlot) -> Self - where - Request: 'static, - Response: Send + Sync + 'static, - { - Self { - input: input.into(), - output: output.into(), - } - } -} - -impl From> for DynNode -where - Request: 'static, - Response: Send + Sync + 'static, - Streams: StreamPack, -{ - fn from(node: Node) -> Self { - Self { - input: node.input.into(), - output: node.output.into(), - } - } -} - -pub struct NodeRegistration { - pub(super) metadata: NodeMetadata, - - /// Creates an instance of the registered node. 
- create_node_impl: CreateNodeFn, - - fork_clone_impl: Option, - - unzip_impl: - Option Result, DiagramError>>>, - - fork_result_impl: Option< - Box Result<(DynOutput, DynOutput), DiagramError>>, - >, - - split_impl: Option< - Box< - dyn for<'a> Fn( - &mut Builder, - DynOutput, - &'a SplitOp, - ) -> Result, DiagramError>, - >, - >, -} - -impl NodeRegistration { - fn new( - metadata: NodeMetadata, - create_node_impl: CreateNodeFn, - fork_clone_impl: Option, - ) -> NodeRegistration { - NodeRegistration { - metadata, - create_node_impl, - fork_clone_impl, - unzip_impl: None, - fork_result_impl: None, - split_impl: None, - } - } - - pub(super) fn create_node( - &self, - builder: &mut Builder, - config: serde_json::Value, - ) -> Result { - let n = (self.create_node_impl.borrow_mut())(builder, config)?; - debug!( - "created node of {}, output: {:?}, input: {:?}", - self.metadata.id, n.output, n.input - ); - Ok(n) - } - - pub(super) fn fork_clone( - &self, - builder: &mut Builder, - output: DynOutput, - amount: usize, - ) -> Result, DiagramError> { - let f = self - .fork_clone_impl - .as_ref() - .ok_or(DiagramError::NotCloneable)?; - f(builder, output, amount) - } - - pub(super) fn unzip( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError> { - let f = self - .unzip_impl - .as_ref() - .ok_or(DiagramError::NotUnzippable)?; - f(builder, output) - } - - pub(super) fn fork_result( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError> { - let f = self - .fork_result_impl - .as_ref() - .ok_or(DiagramError::CannotForkResult)?; - f(builder, output) - } - - pub(super) fn split<'a>( - &self, - builder: &mut Builder, - output: DynOutput, - split_op: &'a SplitOp, - ) -> Result, DiagramError> { - let f = self - .split_impl - .as_ref() - .ok_or(DiagramError::NotSplittable)?; - f(builder, output, split_op) - } -} - -type CreateNodeFn = - RefCell Result>>; -type ForkCloneFn = - Box Result, DiagramError>>; 
- -pub struct CommonOperations<'a, Deserialize, Serialize, Cloneable> { - registry: &'a mut NodeRegistry, - _ignore: PhantomData<(Deserialize, Serialize, Cloneable)>, -} - -impl<'a, DeserializeImpl, SerializeImpl, Cloneable> - CommonOperations<'a, DeserializeImpl, SerializeImpl, Cloneable> -{ - /// Register a node builder with the specified common operations. - /// - /// # Arguments - /// - /// * `id` - Id of the builder, this must be unique. - /// * `name` - Friendly name for the builder, this is only used for display purposes. - /// * `f` - The node builder to register. - pub fn register_node_builder( - self, - options: NodeBuilderOptions, - mut f: impl FnMut(&mut Builder, Config) -> Node + 'static, - ) -> RegistrationBuilder<'a, Request, Response, Streams> - where - Config: JsonSchema + DeserializeOwned, - Request: Send + Sync + 'static, - Response: Send + Sync + 'static, - Streams: StreamPack, - DeserializeImpl: DeserializeMessage, - SerializeImpl: SerializeMessage, - Cloneable: DynForkClone, - { - register_deserialize::(&mut self.registry.data); - register_serialize::(&mut self.registry.data); - - let registration = NodeRegistration::new( - NodeMetadata { - id: options.id.clone(), - name: options.name.unwrap_or(options.id.clone()), - request: RequestMetadata { - schema: DeserializeImpl::json_schema(&mut self.registry.data.schema_generator) - .unwrap_or_else(|| { - self.registry.data.schema_generator.subschema_for::<()>() - }), - deserializable: DeserializeImpl::deserializable(), - }, - response: ResponseMetadata::new( - SerializeImpl::json_schema(&mut self.registry.data.schema_generator) - .unwrap_or_else(|| { - self.registry.data.schema_generator.subschema_for::<()>() - }), - SerializeImpl::serializable(), - Cloneable::CLONEABLE, - ), - config_schema: self - .registry - .data - .schema_generator - .subschema_for::(), - }, - RefCell::new(Box::new(move |builder, config| { - let config = serde_json::from_value(config)?; - let n = f(builder, config); - 
Ok(DynNode::new(n.output, n.input)) - })), - if Cloneable::CLONEABLE { - Some(Box::new(|builder, output, amount| { - Cloneable::dyn_fork_clone(builder, output, amount) - })) - } else { - None - }, - ); - self.registry.nodes.insert(options.id.clone(), registration); - - // SAFETY: We inserted an entry at this ID a moment ago - let node = self.registry.nodes.get_mut(&options.id).unwrap(); - - RegistrationBuilder:: { - node, - data: &mut self.registry.data, - _ignore: Default::default(), - } - } - - /// Opt out of deserializing the request of the node. Use this to build a - /// node whose request type is not deserializable. - pub fn no_request_deserializing( - self, - ) -> CommonOperations<'a, OpaqueMessageDeserializer, SerializeImpl, Cloneable> { - CommonOperations { - registry: self.registry, - _ignore: Default::default(), - } - } - - /// Opt out of serializing the response of the node. Use this to build a - /// node whose response type is not serializable. - pub fn no_response_serializing( - self, - ) -> CommonOperations<'a, DeserializeImpl, OpaqueMessageSerializer, Cloneable> { - CommonOperations { - registry: self.registry, - _ignore: Default::default(), - } - } - - /// Opt out of cloning the response of the node. Use this to build a node - /// whose response type is not cloneable. - pub fn no_response_cloning( - self, - ) -> CommonOperations<'a, DeserializeImpl, SerializeImpl, NotSupported> { - CommonOperations { - registry: self.registry, - _ignore: Default::default(), - } - } -} - -pub struct RegistrationBuilder<'a, Request, Response, Streams> { - node: &'a mut NodeRegistration, - data: &'a mut MessageRegistry, - _ignore: PhantomData<(Request, Response, Streams)>, -} - -impl<'a, Request, Response, Streams> RegistrationBuilder<'a, Request, Response, Streams> { - /// Mark the node as having a unzippable response. This is required in order for the node - /// to be able to be connected to a "Unzip" operation. 
- pub fn with_unzip(self) -> Self - where - DefaultImpl: DynUnzip, - { - self.with_unzip_impl::() - } - - /// Mark the node as having an unzippable response whose elements are not serializable. - pub fn with_unzip_unserializable(self) -> Self - where - DefaultImpl: DynUnzip, - { - self.with_unzip_impl::() - } - - fn with_unzip_impl(self) -> Self - where - UnzipImpl: DynUnzip, - { - self.node.metadata.response.unzip_slots = UnzipImpl::UNZIP_SLOTS; - self.node.unzip_impl = if UnzipImpl::UNZIP_SLOTS > 0 { - Some(Box::new(|builder, output| { - UnzipImpl::dyn_unzip(builder, output) - })) - } else { - None - }; - - UnzipImpl::on_register(self.data); - - self - } - - /// Mark the node as having a [`Result<_, _>`] response. This is required in order for the node - /// to be able to be connected to a "Fork Result" operation. - pub fn with_fork_result(self) -> Self - where - DefaultImpl: DynForkResult, - { - self.node.metadata.response.fork_result = true; - self.node.fork_result_impl = Some(Box::new(|builder, output| { - >::dyn_fork_result(builder, output) - })); - - self - } - - /// Mark the node as having a splittable response. This is required in order - /// for the node to be able to be connected to a "Split" operation. - pub fn with_split(self) -> Self - where - DefaultImpl: DynSplit, - { - self.with_split_impl::() - } - - /// Mark the node as having a splittable response but the items from the split - /// are unserializable. 
- pub fn with_split_unserializable(self) -> Self - where - DefaultImpl: DynSplit, - { - self.with_split_impl::() - } - - pub fn with_split_impl(self) -> Self - where - SplitImpl: DynSplit, - { - self.node.metadata.response.splittable = true; - self.node.split_impl = Some(Box::new(|builder, output, split_op| { - SplitImpl::dyn_split(builder, output, split_op) - })); - - SplitImpl::on_register(self.data); - - self - } -} - -pub trait IntoNodeRegistration { - fn into_node_registration( - self, - id: BuilderId, - name: String, - schema_generator: &mut SchemaGenerator, - ) -> NodeRegistration; -} - -pub struct NodeRegistry { - nodes: HashMap, - pub(super) data: MessageRegistry, -} - -pub struct MessageRegistry { - /// List of all request and response types used in all registered nodes, this only - /// contains serializable types, non serializable types are opaque and is only compatible - /// with itself. - schema_generator: SchemaGenerator, - - pub(super) deserialize_impls: HashMap< - TypeId, - Box) -> Result>, - >, - - pub(super) serialize_impls: HashMap< - TypeId, - Box Result, DiagramError>>, - >, - - pub(super) join_impls: HashMap< - TypeId, - Box) -> Result>, - >, -} - -impl Default for NodeRegistry { - fn default() -> Self { - let mut settings = SchemaSettings::default(); - settings.definitions_path = "#/types/".to_string(); - NodeRegistry { - nodes: Default::default(), - data: MessageRegistry { - schema_generator: SchemaGenerator::new(settings), - deserialize_impls: HashMap::new(), - serialize_impls: HashMap::new(), - join_impls: HashMap::new(), - }, - } - } -} - -impl NodeRegistry { - pub fn new() -> Self { - Self::default() - } - - /// Register a node builder with all the common operations (deserialize the - /// request, serialize the response, and clone the response) enabled. - /// - /// You will receive a [`RegistrationBuilder`] which you can then use to - /// enable more operations around your node, such as fork result, split, - /// or unzip. 
The data types of your node need to be suitable for those - /// operations or else the compiler will not allow you to enable them. - /// - /// ``` - /// use bevy_impulse::{NodeBuilderOptions, NodeRegistry}; - /// - /// let mut registry = NodeRegistry::default(); - /// registry.register_node_builder( - /// NodeBuilderOptions::new("echo".to_string()), - /// |builder, _config: ()| builder.create_map_block(|msg: String| msg) - /// ); - /// ``` - /// - /// # Arguments - /// - /// * `id` - Id of the builder, this must be unique. - /// * `name` - Friendly name for the builder, this is only used for display purposes. - /// * `f` - The node builder to register. - pub fn register_node_builder( - &mut self, - options: NodeBuilderOptions, - builder: impl FnMut(&mut Builder, Config) -> Node + 'static, - ) -> RegistrationBuilder - where - Config: JsonSchema + DeserializeOwned, - Request: Send + Sync + 'static + DynType + DeserializeOwned, - Response: Send + Sync + 'static + DynType + Serialize + Clone, - { - self.opt_out().register_node_builder(options, builder) - } - - /// In some cases the common operations of deserialization, serialization, - /// and cloning cannot be performed for the request or response of a node. - /// When that happens you can still register your node builder by calling - /// this function and explicitly disabling the common operations that your - /// node cannot support. - /// - /// - /// In order for the request to be deserializable, it must implement [`schemars::JsonSchema`] and [`serde::de::DeserializeOwned`]. - /// In order for the response to be serializable, it must implement [`schemars::JsonSchema`] and [`serde::Serialize`]. 
- /// - /// ``` - /// use schemars::JsonSchema; - /// use serde::{Deserialize, Serialize}; - /// - /// #[derive(JsonSchema, Deserialize)] - /// struct DeserializableRequest {} - /// - /// #[derive(JsonSchema, Serialize)] - /// struct SerializableResponse {} - /// ``` - /// - /// If your node have a request or response that is not serializable, there is still - /// a way to register it. - /// - /// ``` - /// use bevy_impulse::{NodeBuilderOptions, NodeRegistry}; - /// - /// struct NonSerializable { - /// data: String - /// } - /// - /// let mut registry = NodeRegistry::default(); - /// registry - /// .opt_out() - /// .no_request_deserializing() - /// .no_response_serializing() - /// .no_response_cloning() - /// .register_node_builder( - /// NodeBuilderOptions::new("echo"), - /// |builder, _config: ()| { - /// builder.create_map_block(|msg: NonSerializable| msg) - /// } - /// ); - /// ``` - /// - /// Note that nodes registered without deserialization cannot be connected - /// to the workflow start, and nodes registered without serialization cannot - /// be connected to the workflow termination. - pub fn opt_out( - &mut self, - ) -> CommonOperations { - CommonOperations { - registry: self, - _ignore: Default::default(), - } - } - - pub(super) fn get_registration(&self, id: &Q) -> Result<&NodeRegistration, DiagramError> - where - Q: Borrow + ?Sized, - { - let k = id.borrow(); - self.nodes - .get(k) - .ok_or(DiagramError::BuilderNotFound(k.to_string())) - } -} - -impl Serialize for NodeRegistry { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut s = serializer.serialize_struct("NodeRegistry", 2)?; - // serialize only the nodes metadata - s.serialize_field( - "nodes", - // Since the serializer methods are consuming, we can't call `serialize_struct` and `collect_map`. - // This genius solution of creating an inline struct and impl `Serialize` on it is based on - // the code that `#[derive(Serialize)]` generates. 
- { - struct SerializeWith<'a> { - value: &'a NodeRegistry, - } - impl<'a> Serialize for SerializeWith<'a> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.collect_map( - self.value - .nodes - .iter() - .map(|(k, v)| (k.clone(), &v.metadata)), - ) - } - } - &SerializeWith { value: self } - }, - )?; - s.serialize_field("types", self.data.schema_generator.definitions())?; - s.end() - } -} - -#[non_exhaustive] -pub struct NodeBuilderOptions { - pub id: BuilderId, - pub name: Option, -} - -impl NodeBuilderOptions { - pub fn new(id: impl ToString) -> Self { - Self { - id: id.to_string(), - name: None, - } - } - - pub fn with_name(mut self, name: impl ToString) -> Self { - self.name = Some(name.to_string()); - self - } -} - -#[cfg(test)] -mod tests { - use schemars::JsonSchema; - use serde::Deserialize; - - use super::*; - - fn multiply3(i: i64) -> i64 { - i * 3 - } - - #[test] - fn test_register_node_builder() { - let mut registry = NodeRegistry::default(); - registry - .opt_out() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), - |builder, _config: ()| builder.create_map_block(multiply3), - ); - let registration = registry.get_registration("multiply3_uncloneable").unwrap(); - assert!(registration.metadata.request.deserializable); - assert!(registration.metadata.response.serializable); - assert!(!registration.metadata.response.cloneable); - assert_eq!(registration.metadata.response.unzip_slots, 0); - } - - #[test] - fn test_register_cloneable_node() { - let mut registry = NodeRegistry::default(); - registry.register_node_builder( - NodeBuilderOptions::new("multiply3").with_name("Test Name"), - |builder, _config: ()| builder.create_map_block(multiply3), - ); - let registration = registry.get_registration("multiply3").unwrap(); - assert!(registration.metadata.request.deserializable); - assert!(registration.metadata.response.serializable); - 
assert!(registration.metadata.response.cloneable); - assert_eq!(registration.metadata.response.unzip_slots, 0); - } - - #[test] - fn test_register_unzippable_node() { - let mut registry = NodeRegistry::default(); - let tuple_resp = |_: ()| -> (i64,) { (1,) }; - registry - .opt_out() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| builder.create_map_block(tuple_resp), - ) - .with_unzip(); - let registration = registry.get_registration("multiply3_uncloneable").unwrap(); - assert!(registration.metadata.request.deserializable); - assert!(registration.metadata.response.serializable); - assert!(!registration.metadata.response.cloneable); - assert_eq!(registration.metadata.response.unzip_slots, 1); - } - - #[test] - fn test_register_splittable_node() { - let mut registry = NodeRegistry::default(); - let vec_resp = |_: ()| -> Vec { vec![1, 2] }; - registry - .register_node_builder( - NodeBuilderOptions::new("vec_resp").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| builder.create_map_block(vec_resp), - ) - .with_split(); - let registration = registry.get_registration("vec_resp").unwrap(); - assert!(registration.metadata.response.splittable); - - let map_resp = |_: ()| -> HashMap { HashMap::new() }; - registry - .register_node_builder( - NodeBuilderOptions::new("map_resp").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), - ) - .with_split(); - - let registration = registry.get_registration("map_resp").unwrap(); - assert!(registration.metadata.response.splittable); - - registry.register_node_builder( - NodeBuilderOptions::new("not_splittable").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), - ); - let registration = registry.get_registration("not_splittable").unwrap(); - assert!(!registration.metadata.response.splittable); - 
} - - #[test] - fn test_register_with_config() { - let mut registry = NodeRegistry::default(); - - #[derive(Deserialize, JsonSchema)] - struct TestConfig { - by: i64, - } - - registry.register_node_builder( - NodeBuilderOptions::new("multiply").with_name("Test Name"), - move |builder: &mut Builder, config: TestConfig| { - builder.create_map_block(move |operand: i64| operand * config.by) - }, - ); - assert!(registry.get_registration("multiply").is_ok()); - } - - struct NonSerializableRequest {} - - #[test] - fn test_register_opaque_node() { - let opaque_request_map = |_: NonSerializableRequest| {}; - - let mut registry = NodeRegistry::default(); - registry - .opt_out() - .no_request_deserializing() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("opaque_request_map").with_name("Test Name"), - move |builder, _config: ()| builder.create_map_block(opaque_request_map), - ); - assert!(registry.get_registration("opaque_request_map").is_ok()); - let registration = registry.get_registration("opaque_request_map").unwrap(); - assert!(!registration.metadata.request.deserializable); - assert!(registration.metadata.response.serializable); - assert!(!registration.metadata.response.cloneable); - assert_eq!(registration.metadata.response.unzip_slots, 0); - - let opaque_response_map = |_: ()| NonSerializableRequest {}; - registry - .opt_out() - .no_response_serializing() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("opaque_response_map").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| { - builder.create_map_block(opaque_response_map) - }, - ); - assert!(registry.get_registration("opaque_response_map").is_ok()); - let registration = registry.get_registration("opaque_response_map").unwrap(); - assert!(registration.metadata.request.deserializable); - assert!(!registration.metadata.response.serializable); - assert!(!registration.metadata.response.cloneable); - 
assert_eq!(registration.metadata.response.unzip_slots, 0); - - let opaque_req_resp_map = |_: NonSerializableRequest| NonSerializableRequest {}; - registry - .opt_out() - .no_request_deserializing() - .no_response_serializing() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("opaque_req_resp_map").with_name("Test Name"), - move |builder: &mut Builder, _config: ()| { - builder.create_map_block(opaque_req_resp_map) - }, - ); - assert!(registry.get_registration("opaque_req_resp_map").is_ok()); - let registration = registry.get_registration("opaque_req_resp_map").unwrap(); - assert!(!registration.metadata.request.deserializable); - assert!(!registration.metadata.response.serializable); - assert!(!registration.metadata.response.cloneable); - assert_eq!(registration.metadata.response.unzip_slots, 0); - } -} diff --git a/src/diagram/registration.rs b/src/diagram/registration.rs new file mode 100644 index 00000000..66afb59f --- /dev/null +++ b/src/diagram/registration.rs @@ -0,0 +1,1419 @@ +use std::{ + any::{type_name, Any, TypeId}, + borrow::Borrow, + cell::RefCell, + collections::HashMap, + fmt::Debug, + marker::PhantomData, +}; + +use crate::{Builder, InputSlot, Node, Output, StreamPack}; +use bevy_ecs::entity::Entity; +use schemars::{ + gen::{SchemaGenerator, SchemaSettings}, + schema::Schema, + JsonSchema, +}; +use serde::{ + de::DeserializeOwned, + ser::{SerializeMap, SerializeStruct}, + Serialize, +}; +use serde_json::json; +use tracing::debug; + +use crate::SerializeMessage; + +use super::{ + fork_clone::DynForkClone, + fork_result::DynForkResult, + impls::{DefaultImpl, DefaultImplMarker, NotSupported}, + join::register_join_impl, + unzip::DynUnzip, + BuilderId, DefaultDeserializer, DefaultSerializer, DeserializeMessage, DiagramError, DynSplit, + DynSplitOutputs, DynType, OpaqueMessageDeserializer, OpaqueMessageSerializer, SplitOp, +}; + +/// A type erased [`crate::InputSlot`] +#[derive(Copy, Clone, Debug)] +pub struct DynInputSlot { + 
scope: Entity, + source: Entity, + pub(super) type_id: TypeId, +} + +impl DynInputSlot { + pub(super) fn scope(&self) -> Entity { + self.scope + } + + pub(super) fn id(&self) -> Entity { + self.source + } +} + +impl From> for DynInputSlot { + fn from(input: InputSlot) -> Self { + Self { + scope: input.scope(), + source: input.id(), + type_id: TypeId::of::(), + } + } +} + +#[derive(Debug)] +/// A type erased [`crate::Output`] +pub struct DynOutput { + scope: Entity, + target: Entity, + pub(super) type_id: TypeId, +} + +impl DynOutput { + pub(super) fn into_output(self) -> Result, DiagramError> + where + T: Send + Sync + 'static + Any, + { + if self.type_id != TypeId::of::() { + Err(DiagramError::TypeMismatch) + } else { + Ok(Output::::new(self.scope, self.target)) + } + } + + pub(super) fn scope(&self) -> Entity { + self.scope + } + + pub(super) fn id(&self) -> Entity { + self.target + } +} + +impl From> for DynOutput +where + T: Send + Sync + 'static, +{ + fn from(output: Output) -> Self { + Self { + scope: output.scope(), + target: output.id(), + type_id: TypeId::of::(), + } + } +} + +/// A type erased [`bevy_impulse::Node`] +pub(super) struct DynNode { + pub(super) input: DynInputSlot, + pub(super) output: DynOutput, +} + +impl DynNode { + fn new(output: Output, input: InputSlot) -> Self + where + Request: 'static, + Response: Send + Sync + 'static, + { + Self { + input: input.into(), + output: output.into(), + } + } +} + +impl From> for DynNode +where + Request: 'static, + Response: Send + Sync + 'static, + Streams: StreamPack, +{ + fn from(node: Node) -> Self { + Self { + input: node.input.into(), + output: node.output.into(), + } + } +} + +#[derive(Serialize)] +pub struct NodeRegistration { + pub(super) id: BuilderId, + pub(super) name: String, + /// type name of the request + pub(super) request: &'static str, + /// type name of the response + pub(super) response: &'static str, + pub(super) config_schema: Schema, + + /// Creates an instance of the registered 
node. + #[serde(skip)] + create_node_impl: CreateNodeFn, +} + +impl NodeRegistration { + pub(super) fn create_node( + &self, + builder: &mut Builder, + config: serde_json::Value, + ) -> Result { + let n = (self.create_node_impl.borrow_mut())(builder, config)?; + debug!( + "created node of {}, output: {:?}, input: {:?}", + self.id, n.output, n.input + ); + Ok(n) + } +} + +type CreateNodeFn = + RefCell Result>>; +type DeserializeFn = + Box) -> Result>; +type SerializeFn = + Box Result, DiagramError>>; +type ForkCloneFn = + Box Result, DiagramError>>; +type ForkResultFn = + Box Result<(DynOutput, DynOutput), DiagramError>>; +type SplitFn = Box< + dyn for<'a> Fn( + &mut Builder, + DynOutput, + &'a SplitOp, + ) -> Result, DiagramError>, +>; +type JoinFn = Box) -> Result>; + +#[must_use] +pub struct CommonOperations<'a, Deserialize, Serialize, Cloneable> { + registry: &'a mut DiagramElementRegistry, + _ignore: PhantomData<(Deserialize, Serialize, Cloneable)>, +} + +impl<'a, DeserializeImpl, SerializeImpl, Cloneable> + CommonOperations<'a, DeserializeImpl, SerializeImpl, Cloneable> +{ + /// Register a node builder with the specified common operations. + /// + /// # Arguments + /// + /// * `id` - Id of the builder, this must be unique. + /// * `name` - Friendly name for the builder, this is only used for display purposes. + /// * `f` - The node builder to register. 
+ pub fn register_node_builder( + self, + options: NodeBuilderOptions, + mut f: impl FnMut(&mut Builder, Config) -> Node + 'static, + ) -> NodeRegistrationBuilder<'a, Request, Response, Streams> + where + Config: JsonSchema + DeserializeOwned, + Request: Send + Sync + 'static, + Response: Send + Sync + 'static, + Streams: StreamPack, + DeserializeImpl: DeserializeMessage, + SerializeImpl: SerializeMessage, + Cloneable: DynForkClone, + { + self.registry + .messages + .register_deserialize::(); + self.registry + .messages + .register_serialize::(); + self.registry + .messages + .register_fork_clone::(); + + let registration = NodeRegistration { + id: options.id.clone(), + name: options.name.unwrap_or(options.id.clone()), + request: type_name::(), + response: type_name::(), + config_schema: self + .registry + .messages + .schema_generator + .subschema_for::(), + create_node_impl: RefCell::new(Box::new(move |builder, config| { + let config = serde_json::from_value(config)?; + let n = f(builder, config); + Ok(DynNode::new(n.output, n.input)) + })), + }; + self.registry.nodes.insert(options.id.clone(), registration); + + NodeRegistrationBuilder::::new(&mut self.registry.messages) + } + + /// Register a message with the specified common operations. + pub fn register_message(self) -> MessageRegistrationBuilder<'a, Message> + where + Message: Send + Sync + 'static, + DeserializeImpl: DeserializeMessage, + SerializeImpl: SerializeMessage + SerializeMessage>, + Cloneable: DynForkClone, + { + self.registry + .messages + .register_deserialize::(); + self.registry + .messages + .register_serialize::(); + self.registry + .messages + .register_fork_clone::(); + register_join_impl::(&mut self.registry.messages); + + MessageRegistrationBuilder::::new(&mut self.registry.messages) + } + + /// Opt out of deserializing the request of the node. Use this to build a + /// node whose request type is not deserializable. 
+ pub fn no_request_deserializing( + self, + ) -> CommonOperations<'a, OpaqueMessageDeserializer, SerializeImpl, Cloneable> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } + + /// Opt out of serializing the response of the node. Use this to build a + /// node whose response type is not serializable. + pub fn no_response_serializing( + self, + ) -> CommonOperations<'a, DeserializeImpl, OpaqueMessageSerializer, Cloneable> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } + + /// Opt out of cloning the response of the node. Use this to build a node + /// whose response type is not cloneable. + pub fn no_response_cloning( + self, + ) -> CommonOperations<'a, DeserializeImpl, SerializeImpl, NotSupported> { + CommonOperations { + registry: self.registry, + _ignore: Default::default(), + } + } +} + +pub struct MessageRegistrationBuilder<'a, Message> { + data: &'a mut MessageRegistry, + _ignore: PhantomData, +} + +impl<'a, Message> MessageRegistrationBuilder<'a, Message> +where + Message: Any, +{ + fn new(registry: &'a mut MessageRegistry) -> Self { + Self { + data: registry, + _ignore: Default::default(), + } + } + + /// Mark the node as having a unzippable response. This is required in order for the node + /// to be able to be connected to a "Unzip" operation. + pub fn with_unzip(&mut self) -> &mut Self + where + DefaultImplMarker<(Message, DefaultSerializer)>: DynUnzip, + { + self.data.register_unzip::(); + self + } + + /// Mark the node as having an unzippable response whose elements are not serializable. + pub fn with_unzip_minimal(&mut self) -> &mut Self + where + DefaultImplMarker<(Message, NotSupported)>: DynUnzip, + { + self.data.register_unzip::(); + self + } + + /// Mark the node as having a [`Result<_, _>`] response. This is required in order for the node + /// to be able to be connected to a "Fork Result" operation. 
+ pub fn with_fork_result(&mut self) -> &mut Self + where + DefaultImpl: DynForkResult, + { + self.data.register_fork_result::(); + self + } + + /// Mark the node as having a splittable response. This is required in order + /// for the node to be able to be connected to a "Split" operation. + pub fn with_split(&mut self) -> &mut Self + where + DefaultImpl: DynSplit, + { + self.data + .register_split::(); + self + } + + /// Mark the node as having a splittable response but the items from the split + /// are unserializable. + pub fn with_split_minimal(&mut self) -> &mut Self + where + DefaultImpl: DynSplit, + { + self.data + .register_split::(); + self + } +} + +pub struct NodeRegistrationBuilder<'a, Request, Response, Streams> { + registry: &'a mut MessageRegistry, + _ignore: PhantomData<(Request, Response, Streams)>, +} + +impl<'a, Request, Response, Streams> NodeRegistrationBuilder<'a, Request, Response, Streams> +where + Request: Any, + Response: Any, +{ + fn new(registry: &'a mut MessageRegistry) -> Self { + Self { + registry, + _ignore: Default::default(), + } + } + + /// Mark the node as having a unzippable response. This is required in order for the node + /// to be able to be connected to a "Unzip" operation. + pub fn with_unzip(&mut self) -> &mut Self + where + DefaultImplMarker<(Response, DefaultSerializer)>: DynUnzip, + { + MessageRegistrationBuilder::new(self.registry).with_unzip(); + self + } + + /// Mark the node as having an unzippable response whose elements are not serializable. + pub fn with_unzip_unserializable(&mut self) -> &mut Self + where + DefaultImplMarker<(Response, NotSupported)>: DynUnzip, + { + MessageRegistrationBuilder::new(self.registry).with_unzip_minimal(); + self + } + + /// Mark the node as having a [`Result<_, _>`] response. This is required in order for the node + /// to be able to be connected to a "Fork Result" operation. 
+ pub fn with_fork_result(&mut self) -> &mut Self + where + DefaultImpl: DynForkResult, + { + MessageRegistrationBuilder::new(self.registry).with_fork_result(); + self + } + + /// Mark the node as having a splittable response. This is required in order + /// for the node to be able to be connected to a "Split" operation. + pub fn with_split(&mut self) -> &mut Self + where + DefaultImpl: DynSplit, + { + MessageRegistrationBuilder::new(self.registry).with_split(); + self + } + + /// Mark the node as having a splittable response but the items from the split + /// are unserializable. + pub fn with_split_unserializable(&mut self) -> &mut Self + where + DefaultImpl: DynSplit, + { + MessageRegistrationBuilder::new(self.registry).with_split_minimal(); + self + } +} + +pub trait IntoNodeRegistration { + fn into_node_registration( + self, + id: BuilderId, + name: String, + schema_generator: &mut SchemaGenerator, + ) -> NodeRegistration; +} + +#[derive(Serialize)] +pub struct DiagramElementRegistry { + pub(super) nodes: HashMap, + #[serde(flatten)] + pub(super) messages: MessageRegistry, +} + +#[derive(Default)] +pub(super) struct MessageOperation { + deserialize_impl: Option, + serialize_impl: Option, + fork_clone_impl: Option, + unzip_impl: Option>, + fork_result_impl: Option, + split_impl: Option, + join_impl: Option, +} + +impl MessageOperation { + fn new() -> Self { + Self::default() + } + + /// Try to deserialize `output` into `input_type`. If `output` is not `serde_json::Value`, this does nothing. + pub(super) fn deserialize( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result { + let f = self + .deserialize_impl + .as_ref() + .ok_or(DiagramError::NotSerializable)?; + f(builder, output.into_output()?) 
+ } + + pub(super) fn serialize( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError> { + let f = self + .serialize_impl + .as_ref() + .ok_or(DiagramError::NotSerializable)?; + f(builder, output) + } + + pub(super) fn fork_clone( + &self, + builder: &mut Builder, + output: DynOutput, + amount: usize, + ) -> Result, DiagramError> { + let f = self + .fork_clone_impl + .as_ref() + .ok_or(DiagramError::NotCloneable)?; + f(builder, output, amount) + } + + pub(super) fn unzip( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError> { + let unzip_impl = &self + .unzip_impl + .as_ref() + .ok_or(DiagramError::NotUnzippable)?; + unzip_impl.dyn_unzip(builder, output) + } + + pub(super) fn fork_result( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError> { + let f = self + .fork_result_impl + .as_ref() + .ok_or(DiagramError::CannotForkResult)?; + f(builder, output) + } + + pub(super) fn split<'a>( + &self, + builder: &mut Builder, + output: DynOutput, + split_op: &'a SplitOp, + ) -> Result, DiagramError> { + let f = self + .split_impl + .as_ref() + .ok_or(DiagramError::NotSplittable)?; + f(builder, output, split_op) + } + + pub(super) fn join( + &self, + builder: &mut Builder, + outputs: OutputIter, + ) -> Result + where + OutputIter: IntoIterator, + { + let f = self.join_impl.as_ref().ok_or(DiagramError::NotJoinable)?; + f(builder, outputs.into_iter().collect()) + } +} + +impl Serialize for MessageOperation { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_map(None)?; + if self.deserialize_impl.is_some() { + s.serialize_entry("deserialize", &serde_json::Value::Null)?; + } + if self.serialize_impl.is_some() { + s.serialize_entry("serialize", &serde_json::Value::Null)?; + } + if self.fork_clone_impl.is_some() { + s.serialize_entry("fork_clone", &serde_json::Value::Null)?; + } + if let 
Some(unzip_impl) = &self.unzip_impl { + s.serialize_entry("unzip", &json!({"output_types": unzip_impl.output_types()}))?; + } + if self.fork_result_impl.is_some() { + s.serialize_entry("fork_result", &serde_json::Value::Null)?; + } + if self.split_impl.is_some() { + s.serialize_entry("split", &serde_json::Value::Null)?; + } + if self.join_impl.is_some() { + s.serialize_entry("join", &serde_json::Value::Null)?; + } + s.end() + } +} + +pub struct MessageRegistration { + type_name: &'static str, + schema: Option, + operations: MessageOperation, +} + +impl MessageRegistration { + fn new() -> Self { + Self { + type_name: type_name::(), + schema: None, + operations: MessageOperation::new(), + } + } +} + +impl Serialize for MessageRegistration { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_struct("MessageRegistration", 3)?; + s.serialize_field("schema", &self.schema)?; + s.serialize_field("operations", &self.operations)?; + s.end() + } +} + +#[derive(Serialize)] +pub struct MessageRegistry { + #[serde(serialize_with = "MessageRegistry::serialize_messages")] + messages: HashMap, + + #[serde( + rename = "schemas", + serialize_with = "MessageRegistry::serialize_schemas" + )] + schema_generator: SchemaGenerator, +} + +impl MessageRegistry { + fn get(&self) -> Option<&MessageRegistration> + where + T: Any, + { + self.messages.get(&TypeId::of::()) + } + + pub(super) fn deserialize( + &self, + target_type: &TypeId, + builder: &mut Builder, + output: DynOutput, + ) -> Result { + if output.type_id != TypeId::of::() || &output.type_id == target_type { + Ok(output) + } else if let Some(reg) = self.messages.get(target_type) { + reg.operations.deserialize(builder, output) + } else { + Err(DiagramError::NotSerializable) + } + } + + /// Register a deserialize function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_deserialize(&mut self) -> bool + where + T: Send + Sync + 'static + Any, + Deserializer: DeserializeMessage, + { + let reg = self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()); + let ops = &mut reg.operations; + if !Deserializer::deserializable() || ops.deserialize_impl.is_some() { + return false; + } + + debug!( + "register deserialize for type: {}, with deserializer: {}", + std::any::type_name::(), + std::any::type_name::() + ); + ops.deserialize_impl = Some(Box::new(|builder, output| { + debug!("deserialize output: {:?}", output); + let receiver = + builder.create_map_block(|json: serde_json::Value| Deserializer::from_json(json)); + builder.connect(output, receiver.input); + let deserialized_output = receiver + .output + .chain(builder) + .cancel_on_err() + .output() + .into(); + debug!("deserialized output: {:?}", deserialized_output); + Ok(deserialized_output) + })); + + reg.schema = Deserializer::json_schema(&mut self.schema_generator); + + true + } + + pub(super) fn serialize( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError> { + if output.type_id == TypeId::of::() { + output.into_output() + } else if let Some(reg) = self.messages.get(&output.type_id) { + reg.operations.serialize(builder, output) + } else { + Err(DiagramError::NotSerializable) + } + } + + /// Register a serialize function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_serialize(&mut self) -> bool + where + T: Send + Sync + 'static + Any, + Serializer: SerializeMessage, + { + let reg = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()); + let ops = &mut reg.operations; + if !Serializer::serializable() || ops.serialize_impl.is_some() { + return false; + } + + debug!( + "register serialize for type: {}, with serializer: {}", + std::any::type_name::(), + std::any::type_name::() + ); + ops.serialize_impl = Some(Box::new(|builder, output| { + debug!("serialize output: {:?}", output); + let n = builder.create_map_block(|resp: T| Serializer::to_json(&resp)); + builder.connect(output.into_output()?, n.input); + let serialized_output = n.output.chain(builder).cancel_on_err().output(); + debug!("serialized output: {:?}", serialized_output); + Ok(serialized_output) + })); + + reg.schema = Serializer::json_schema(&mut self.schema_generator); + + true + } + + pub(super) fn fork_clone( + &self, + builder: &mut Builder, + output: DynOutput, + amount: usize, + ) -> Result, DiagramError> { + if let Some(reg) = self.messages.get(&output.type_id) { + reg.operations.fork_clone(builder, output, amount) + } else { + Err(DiagramError::NotCloneable) + } + } + + /// Register a fork_clone function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_fork_clone(&mut self) -> bool + where + T: Any, + F: DynForkClone, + { + let ops = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + if !F::CLONEABLE || ops.fork_clone_impl.is_some() { + return false; + } + + ops.fork_clone_impl = Some(Box::new(|builder, output, amount| { + F::dyn_fork_clone(builder, output, amount) + })); + + true + } + + pub(super) fn unzip( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError> { + if let Some(reg) = self.messages.get(&output.type_id) { + reg.operations.unzip(builder, output) + } else { + Err(DiagramError::NotUnzippable) + } + } + + /// Register a unzip function if not already registered, returns true if the new + /// function is registered. + pub(super) fn register_unzip(&mut self) -> bool + where + T: Any, + Serializer: 'static, + DefaultImplMarker<(T, Serializer)>: DynUnzip, + { + let unzip_impl = DefaultImplMarker::<(T, Serializer)>::new(); + unzip_impl.on_register(self); + + let ops = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + if ops.unzip_impl.is_some() { + return false; + } + ops.unzip_impl = Some(Box::new(unzip_impl)); + + true + } + + pub(super) fn fork_result( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result<(DynOutput, DynOutput), DiagramError> { + if let Some(reg) = self.messages.get(&output.type_id) { + reg.operations.fork_result(builder, output) + } else { + Err(DiagramError::CannotForkResult) + } + } + + /// Register a fork_result function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_fork_result(&mut self) -> bool + where + T: Any, + F: DynForkResult, + { + let ops = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + if ops.fork_result_impl.is_some() { + return false; + } + + ops.fork_result_impl = Some(Box::new(|builder, output| { + F::dyn_fork_result(builder, output) + })); + + true + } + + pub(super) fn split<'b>( + &self, + builder: &mut Builder, + output: DynOutput, + split_op: &'b SplitOp, + ) -> Result, DiagramError> { + if let Some(reg) = self.messages.get(&output.type_id) { + reg.operations.split(builder, output, split_op) + } else { + Err(DiagramError::NotSplittable) + } + } + + /// Register a split function if not already registered, returns true if the new + /// function is registered. + pub(super) fn register_split(&mut self) -> bool + where + T: Any, + F: DynSplit, + { + let ops = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + if ops.split_impl.is_some() { + return false; + } + + ops.split_impl = Some(Box::new(|builder, output, split_op| { + F::dyn_split(builder, output, split_op) + })); + F::on_register(self); + + true + } + + pub(super) fn join( + &self, + builder: &mut Builder, + outputs: OutputIter, + ) -> Result + where + OutputIter: IntoIterator, + { + let mut i = outputs.into_iter().peekable(); + let output_type_id = if let Some(o) = i.peek() { + Some(o.type_id.clone()) + } else { + None + }; + + if let Some(output_type_id) = output_type_id { + if let Some(reg) = self.messages.get(&output_type_id) { + reg.operations.join(builder, i) + } else { + Err(DiagramError::NotJoinable) + } + } else { + Err(DiagramError::NotJoinable) + } + } + + /// Register a join function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_join(&mut self, f: JoinFn) -> bool + where + T: Any, + { + let ops = &mut self + .messages + .entry(TypeId::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + if ops.join_impl.is_some() { + return false; + } + + ops.join_impl = Some(f); + + true + } + + fn serialize_messages( + v: &HashMap, + serializer: S, + ) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_map(Some(v.len()))?; + for msg in v.values() { + // should we use short name? It makes the serialized json more readable at the cost + // of greatly increased chance of key conflicts. + // let short_name = { + // if let Some(start) = msg.type_name.rfind(":") { + // &msg.type_name[start + 1..] + // } else { + // msg.type_name + // } + // }; + s.serialize_entry(msg.type_name, msg)?; + } + s.end() + } + + fn serialize_schemas(v: &SchemaGenerator, serializer: S) -> Result + where + S: serde::Serializer, + { + v.definitions().serialize(serializer) + } +} + +impl Default for DiagramElementRegistry { + fn default() -> Self { + let mut settings = SchemaSettings::default(); + settings.definitions_path = "#/schemas/".to_string(); + DiagramElementRegistry { + nodes: Default::default(), + messages: MessageRegistry { + schema_generator: SchemaGenerator::new(settings), + messages: HashMap::new(), + }, + } + } +} + +impl DiagramElementRegistry { + pub fn new() -> Self { + Self::default() + } + + /// Register a node builder with all the common operations (deserialize the + /// request, serialize the response, and clone the response) enabled. + /// + /// You will receive a [`RegistrationBuilder`] which you can then use to + /// enable more operations around your node, such as fork result, split, + /// or unzip. The data types of your node need to be suitable for those + /// operations or else the compiler will not allow you to enable them. 
+ /// + /// ``` + /// use bevy_impulse::{NodeBuilderOptions, DiagramElementRegistry}; + /// + /// let mut registry = DiagramElementRegistry::new(); + /// registry.register_node_builder( + /// NodeBuilderOptions::new("echo".to_string()), + /// |builder, _config: ()| builder.create_map_block(|msg: String| msg) + /// ); + /// ``` + /// + /// # Arguments + /// + /// * `id` - Id of the builder, this must be unique. + /// * `name` - Friendly name for the builder, this is only used for display purposes. + /// * `f` - The node builder to register. + pub fn register_node_builder( + &mut self, + options: NodeBuilderOptions, + builder: impl FnMut(&mut Builder, Config) -> Node + 'static, + ) -> NodeRegistrationBuilder + where + Config: JsonSchema + DeserializeOwned, + Request: Send + Sync + 'static + DynType + DeserializeOwned, + Response: Send + Sync + 'static + DynType + Serialize + Clone, + { + self.opt_out().register_node_builder(options, builder) + } + + /// In some cases the common operations of deserialization, serialization, + /// and cloning cannot be performed for the request or response of a node. + /// When that happens you can still register your node builder by calling + /// this function and explicitly disabling the common operations that your + /// node cannot support. + /// + /// + /// In order for the request to be deserializable, it must implement [`schemars::JsonSchema`] and [`serde::de::DeserializeOwned`]. + /// In order for the response to be serializable, it must implement [`schemars::JsonSchema`] and [`serde::Serialize`]. + /// + /// ``` + /// use schemars::JsonSchema; + /// use serde::{Deserialize, Serialize}; + /// + /// #[derive(JsonSchema, Deserialize)] + /// struct DeserializableRequest {} + /// + /// #[derive(JsonSchema, Serialize)] + /// struct SerializableResponse {} + /// ``` + /// + /// If your node have a request or response that is not serializable, there is still + /// a way to register it. 
+ /// + /// ``` + /// use bevy_impulse::{NodeBuilderOptions, DiagramElementRegistry}; + /// + /// struct NonSerializable { + /// data: String + /// } + /// + /// let mut registry = DiagramElementRegistry::new(); + /// registry + /// .opt_out() + /// .no_request_deserializing() + /// .no_response_serializing() + /// .no_response_cloning() + /// .register_node_builder( + /// NodeBuilderOptions::new("echo"), + /// |builder, _config: ()| { + /// builder.create_map_block(|msg: NonSerializable| msg) + /// } + /// ); + /// ``` + /// + /// Note that nodes registered without deserialization cannot be connected + /// to the workflow start, and nodes registered without serialization cannot + /// be connected to the workflow termination. + pub fn opt_out( + &mut self, + ) -> CommonOperations { + CommonOperations { + registry: self, + _ignore: Default::default(), + } + } + + pub fn get_node_registration(&self, id: &Q) -> Result<&NodeRegistration, DiagramError> + where + Q: Borrow + ?Sized, + { + let k = id.borrow(); + self.nodes + .get(k) + .ok_or(DiagramError::BuilderNotFound(k.to_string())) + } + + pub fn get_message_registration(&self) -> Option<&MessageRegistration> + where + T: Any, + { + self.messages.get::() + } +} + +#[non_exhaustive] +pub struct NodeBuilderOptions { + pub id: BuilderId, + pub name: Option, +} + +impl NodeBuilderOptions { + pub fn new(id: impl ToString) -> Self { + Self { + id: id.to_string(), + name: None, + } + } + + pub fn with_name(mut self, name: impl ToString) -> Self { + self.name = Some(name.to_string()); + self + } +} + +#[cfg(test)] +mod tests { + use schemars::JsonSchema; + use serde::Deserialize; + + use super::*; + + fn multiply3(i: i64) -> i64 { + i * 3 + } + + /// Some extra impl only used in tests (for now). + /// If these impls are needed outside tests, then move them to the main impl. 
+ impl MessageOperation { + fn deserializable(&self) -> bool { + self.deserialize_impl.is_some() + } + + fn serializable(&self) -> bool { + self.serialize_impl.is_some() + } + + fn cloneable(&self) -> bool { + self.fork_clone_impl.is_some() + } + + fn unzippable(&self) -> bool { + self.unzip_impl.is_some() + } + + fn can_fork_result(&self) -> bool { + self.fork_result_impl.is_some() + } + + fn splittable(&self) -> bool { + self.split_impl.is_some() + } + + fn joinable(&self) -> bool { + self.join_impl.is_some() + } + } + + #[test] + fn test_register_node_builder() { + let mut registry = DiagramElementRegistry::new(); + registry.opt_out().register_node_builder( + NodeBuilderOptions::new("multiply3").with_name("Test Name"), + |builder, _config: ()| builder.create_map_block(multiply3), + ); + let req_ops = ®istry.messages.get::().unwrap().operations; + let resp_ops = ®istry.messages.get::().unwrap().operations; + assert!(req_ops.deserializable()); + assert!(resp_ops.serializable()); + assert!(resp_ops.cloneable()); + assert!(!resp_ops.unzippable()); + assert!(!resp_ops.can_fork_result()); + assert!(!resp_ops.splittable()); + assert!(!resp_ops.joinable()); + } + + #[test] + fn test_register_cloneable_node() { + let mut registry = DiagramElementRegistry::new(); + registry.register_node_builder( + NodeBuilderOptions::new("multiply3").with_name("Test Name"), + |builder, _config: ()| builder.create_map_block(multiply3), + ); + let req_ops = ®istry.messages.get::().unwrap().operations; + let resp_ops = ®istry.messages.get::().unwrap().operations; + assert!(req_ops.deserializable()); + assert!(resp_ops.serializable()); + assert!(resp_ops.cloneable()); + } + + #[test] + fn test_register_unzippable_node() { + let mut registry = DiagramElementRegistry::new(); + let tuple_resp = |_: ()| -> (i64,) { (1,) }; + registry + .opt_out() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), + move |builder: &mut 
Builder, _config: ()| builder.create_map_block(tuple_resp), + ) + .with_unzip(); + let req_ops = ®istry.messages.get::<()>().unwrap().operations; + let resp_ops = ®istry.messages.get::<(i64,)>().unwrap().operations; + assert!(req_ops.deserializable()); + assert!(resp_ops.serializable()); + assert!(resp_ops.unzippable()); + } + + #[test] + fn test_register_splittable_node() { + let mut registry = DiagramElementRegistry::new(); + let vec_resp = |_: ()| -> Vec { vec![1, 2] }; + + registry + .register_node_builder( + NodeBuilderOptions::new("vec_resp").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(vec_resp), + ) + .with_split(); + assert!(registry + .messages + .get::>() + .unwrap() + .operations + .splittable()); + + let map_resp = |_: ()| -> HashMap { HashMap::new() }; + registry + .register_node_builder( + NodeBuilderOptions::new("map_resp").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), + ) + .with_split(); + assert!(registry + .messages + .get::>() + .unwrap() + .operations + .splittable()); + + registry.register_node_builder( + NodeBuilderOptions::new("not_splittable").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| builder.create_map_block(map_resp), + ); + // even though we didn't register with `with_split`, it is still splittable because we + // previously registered another splittable node with the same response type. 
+ assert!(registry + .messages + .get::>() + .unwrap() + .operations + .splittable()); + } + + #[test] + fn test_register_with_config() { + let mut registry = DiagramElementRegistry::new(); + + #[derive(Deserialize, JsonSchema)] + struct TestConfig { + by: i64, + } + + registry.register_node_builder( + NodeBuilderOptions::new("multiply").with_name("Test Name"), + move |builder: &mut Builder, config: TestConfig| { + builder.create_map_block(move |operand: i64| operand * config.by) + }, + ); + assert!(registry.get_node_registration("multiply").is_ok()); + } + + struct NonSerializableRequest {} + + #[test] + fn test_register_opaque_node() { + let opaque_request_map = |_: NonSerializableRequest| {}; + + let mut registry = DiagramElementRegistry::new(); + registry + .opt_out() + .no_request_deserializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_request_map").with_name("Test Name"), + move |builder, _config: ()| builder.create_map_block(opaque_request_map), + ); + assert!(registry.get_node_registration("opaque_request_map").is_ok()); + let req_ops = ®istry + .messages + .get::() + .unwrap() + .operations; + let resp_ops = ®istry.messages.get::<()>().unwrap().operations; + assert!(!req_ops.deserializable()); + assert!(resp_ops.serializable()); + + let opaque_response_map = |_: ()| NonSerializableRequest {}; + registry + .opt_out() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_response_map").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| { + builder.create_map_block(opaque_response_map) + }, + ); + assert!(registry + .get_node_registration("opaque_response_map") + .is_ok()); + let req_ops = ®istry.messages.get::<()>().unwrap().operations; + let resp_ops = ®istry + .messages + .get::() + .unwrap() + .operations; + assert!(req_ops.deserializable()); + assert!(!resp_ops.serializable()); + + let opaque_req_resp_map = |_: NonSerializableRequest| 
NonSerializableRequest {}; + registry + .opt_out() + .no_request_deserializing() + .no_response_serializing() + .no_response_cloning() + .register_node_builder( + NodeBuilderOptions::new("opaque_req_resp_map").with_name("Test Name"), + move |builder: &mut Builder, _config: ()| { + builder.create_map_block(opaque_req_resp_map) + }, + ); + assert!(registry + .get_node_registration("opaque_req_resp_map") + .is_ok()); + let req_ops = ®istry + .messages + .get::() + .unwrap() + .operations; + let resp_ops = ®istry + .messages + .get::() + .unwrap() + .operations; + assert!(!req_ops.deserializable()); + assert!(!resp_ops.serializable()); + } + + #[test] + fn test_register_message() { + let mut registry = DiagramElementRegistry::new(); + + #[derive(Deserialize, Serialize, JsonSchema, Clone)] + struct TestMessage; + + registry.opt_out().register_message::(); + + let ops = ®istry + .get_message_registration::() + .unwrap() + .operations; + assert!(ops.deserializable()); + assert!(ops.serializable()); + assert!(ops.cloneable()); + assert!(!ops.unzippable()); + assert!(!ops.can_fork_result()); + assert!(!ops.splittable()); + assert!(ops.joinable()); + } + + #[test] + fn test_serialize_registry() { + let mut reg = DiagramElementRegistry::new(); + + #[derive(Deserialize, Serialize, JsonSchema, Clone)] + struct Foo { + hello: String, + } + + #[derive(Deserialize, Serialize, JsonSchema, Clone)] + struct Bar { + foo: Foo, + } + + struct Opaque; + + reg.opt_out() + .no_request_deserializing() + .register_node_builder(NodeBuilderOptions::new("test"), |builder, _config: ()| { + builder.create_map_block(|_: Opaque| { + ( + Foo { + hello: "hello".to_string(), + }, + Bar { + foo: Foo { + hello: "world".to_string(), + }, + }, + ) + }) + }) + .with_unzip(); + + // print out a pretty json for manual inspection + println!("{}", serde_json::to_string_pretty(®).unwrap()); + + // test that schema refs are pointing to the correct path + let value = serde_json::to_value(®).unwrap(); + let 
messages = &value["messages"]; + let schemas = &value["schemas"]; + let bar_schema = &messages[type_name::()]["schema"]; + assert_eq!(bar_schema["$ref"].as_str().unwrap(), "#/schemas/Bar"); + assert!(schemas.get("Bar").is_some()); + assert!(schemas.get("Foo").is_some()); + } +} diff --git a/src/diagram/serialization.rs b/src/diagram/serialization.rs index a3ba4b22..4c3c2e96 100644 --- a/src/diagram/serialization.rs +++ b/src/diagram/serialization.rs @@ -1,10 +1,5 @@ -use std::any::TypeId; - use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema}; use serde::{de::DeserializeOwned, Serialize}; -use tracing::debug; - -use super::MessageRegistry; #[derive(thiserror::Error, Debug)] pub enum SerializationError { @@ -35,51 +30,6 @@ where } } -#[derive(Clone, Debug, Serialize)] -pub struct RequestMetadata { - /// The JSON Schema of the request. - pub(super) schema: Schema, - - /// Indicates if the request is deserializable. - pub(super) deserializable: bool, -} - -#[derive(Clone, Debug, Serialize)] -pub struct ResponseMetadata { - /// The JSON Schema of the response. - pub(super) schema: Schema, - - /// Indicates if the response is serializable. - pub(super) serializable: bool, - - /// Indicates if the response is cloneable, a node must have a cloneable response - /// in order to connect it to a "fork clone" operation. - pub(super) cloneable: bool, - - /// The number of unzip slots that a response have, a value of 0 means that the response - /// cannot be unzipped. This should be > 0 only if the response is a tuple. 
- pub(super) unzip_slots: usize, - - /// Indicates if the response can fork result - pub(super) fork_result: bool, - - /// Indiciates if the response can be split - pub(super) splittable: bool, -} - -impl ResponseMetadata { - pub(super) fn new(schema: Schema, serializable: bool, cloneable: bool) -> ResponseMetadata { - ResponseMetadata { - schema, - serializable, - cloneable, - unzip_slots: 0, - fork_result: false, - splittable: false, - } - } -} - pub trait SerializeMessage { fn type_name() -> String; @@ -189,65 +139,3 @@ impl DeserializeMessage for OpaqueMessageDeserializer { false } } - -pub(super) fn register_deserialize(registry: &mut MessageRegistry) -where - Deserializer: DeserializeMessage, - T: Send + Sync + 'static, -{ - if registry.deserialize_impls.contains_key(&TypeId::of::()) - || !Deserializer::deserializable() - { - return; - } - - debug!( - "register deserialize for type: {}, with deserializer: {}", - std::any::type_name::(), - std::any::type_name::() - ); - registry.deserialize_impls.insert( - TypeId::of::(), - Box::new(|builder, output| { - debug!("deserialize output: {:?}", output); - let receiver = - builder.create_map_block(|json: serde_json::Value| Deserializer::from_json(json)); - builder.connect(output, receiver.input); - let deserialized_output = receiver - .output - .chain(builder) - .cancel_on_err() - .output() - .into(); - debug!("deserialized output: {:?}", deserialized_output); - Ok(deserialized_output) - }), - ); -} - -pub(super) fn register_serialize(registry: &mut MessageRegistry) -where - Serializer: SerializeMessage, - T: Send + Sync + 'static, -{ - if registry.serialize_impls.contains_key(&TypeId::of::()) || !Serializer::serializable() { - return; - } - - debug!( - "register serialize for type: {}, with serializer: {}", - std::any::type_name::(), - std::any::type_name::() - ); - registry.serialize_impls.insert( - TypeId::of::(), - Box::new(|builder, output| { - debug!("serialize output: {:?}", output); - let n = 
builder.create_map_block(|resp: T| Serializer::to_json(&resp)); - builder.connect(output.into_output()?, n.input); - let serialized_output = n.output.chain(builder).cancel_on_err().output(); - debug!("serialized output: {:?}", serialized_output); - Ok(serialized_output) - }), - ); -} diff --git a/src/diagram/split_serialized.rs b/src/diagram/split_serialized.rs index 585ffe16..333c8be4 100644 --- a/src/diagram/split_serialized.rs +++ b/src/diagram/split_serialized.rs @@ -30,7 +30,7 @@ use crate::{ use super::{ impls::{DefaultImpl, NotSupported}, join::register_join_impl, - register_serialize, DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, + DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, }; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -256,7 +256,7 @@ where } fn on_register(registry: &mut MessageRegistry) { - register_serialize::(registry); + registry.register_serialize::(); register_join_impl::(registry); } } diff --git a/src/diagram/testing.rs b/src/diagram/testing.rs index fc0a553a..0e170d36 100644 --- a/src/diagram/testing.rs +++ b/src/diagram/testing.rs @@ -1,16 +1,20 @@ use std::error::Error; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + use crate::{ testing::TestingContext, Builder, RequestExt, RunCommandsOnWorldExt, Service, StreamPack, }; use super::{ - Diagram, DiagramError, DiagramStart, DiagramTerminate, NodeBuilderOptions, NodeRegistry, + Diagram, DiagramElementRegistry, DiagramError, DiagramStart, DiagramTerminate, + NodeBuilderOptions, }; pub(super) struct DiagramTestFixture { pub(super) context: TestingContext, - pub(super) registry: NodeRegistry, + pub(super) registry: DiagramElementRegistry, } impl DiagramTestFixture { @@ -63,10 +67,17 @@ impl DiagramTestFixture { } } +#[derive(Serialize, Deserialize, JsonSchema)] +struct Uncloneable(T); + fn multiply3(i: i64) -> i64 { i * 3 } +fn multiply3_uncloneable(i: i64) -> Uncloneable { + Uncloneable(i * 3) +} + fn 
multiply3_5(x: i64) -> (i64, i64) { (x * 3, x * 5) } @@ -84,14 +95,14 @@ fn opaque_response(_: i64) -> Unserializable { } /// create a new node registry with some basic nodes registered -fn new_registry_with_basic_nodes() -> NodeRegistry { - let mut registry = NodeRegistry::default(); +fn new_registry_with_basic_nodes() -> DiagramElementRegistry { + let mut registry = DiagramElementRegistry::new(); registry .opt_out() .no_response_cloning() .register_node_builder( NodeBuilderOptions::new("multiply3_uncloneable"), - |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3), + |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3_uncloneable), ); registry.register_node_builder( NodeBuilderOptions::new("multiply3"), @@ -105,7 +116,7 @@ fn new_registry_with_basic_nodes() -> NodeRegistry { .with_unzip(); registry.register_node_builder( - NodeBuilderOptions::new("multiplyBy"), + NodeBuilderOptions::new("multiply_by"), |builder: &mut Builder, config: i64| builder.create_map_block(move |a: i64| a * config), ); diff --git a/src/diagram/transform.rs b/src/diagram/transform.rs index 9ad17a67..c98d5971 100644 --- a/src/diagram/transform.rs +++ b/src/diagram/transform.rs @@ -8,7 +8,7 @@ use tracing::debug; use crate::{Builder, Output}; -use super::{DiagramError, DynOutput, NextOperation, NodeRegistry}; +use super::{DiagramElementRegistry, DiagramError, DynOutput, NextOperation}; #[derive(Error, Debug)] pub enum TransformError { @@ -31,7 +31,7 @@ pub struct TransformOp { pub(super) fn transform_output( builder: &mut Builder, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, output: DynOutput, transform_op: &TransformOp, ) -> Result, DiagramError> { @@ -40,12 +40,7 @@ pub(super) fn transform_output( let json_output = if output.type_id == TypeId::of::() { output.into_output() } else { - let serialize = registry - .data - .serialize_impls - .get(&output.type_id) - .ok_or(DiagramError::NotSerializable)?; - serialize(builder, output) + 
registry.messages.serialize(builder, output) }?; let program = Program::compile(&transform_op.cel).map_err(|err| TransformError::Parse(err))?; @@ -90,7 +85,7 @@ mod tests { "ops": { "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "transform", }, "transform": { diff --git a/src/diagram/unzip.rs b/src/diagram/unzip.rs index 648aca53..1afd11e7 100644 --- a/src/diagram/unzip.rs +++ b/src/diagram/unzip.rs @@ -6,10 +6,9 @@ use tracing::debug; use crate::Builder; use super::{ - impls::{DefaultImpl, NotSupported}, + impls::{DefaultImplMarker, NotSupportedMarker}, join::register_join_impl, - register_serialize as register_serialize_impl, DiagramError, DynOutput, MessageRegistry, - NextOperation, SerializeMessage, + DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -18,38 +17,51 @@ pub struct UnzipOp { pub(super) next: Vec, } -pub trait DynUnzip { - const UNZIP_SLOTS: usize; +pub trait DynUnzip { + /// Returns a list of type names that this message unzips to. + fn output_types(&self) -> Vec<&'static str>; - fn dyn_unzip(builder: &mut Builder, output: DynOutput) -> Result, DiagramError>; + fn dyn_unzip( + &self, + builder: &mut Builder, + output: DynOutput, + ) -> Result, DiagramError>; /// Called when a node is registered. - fn on_register(registry: &mut MessageRegistry); + fn on_register(&self, registry: &mut MessageRegistry); } -impl DynUnzip for NotSupported { - const UNZIP_SLOTS: usize = 0; +impl DynUnzip for NotSupportedMarker { + fn output_types(&self) -> Vec<&'static str> { + Vec::new() + } fn dyn_unzip( + &self, _builder: &mut Builder, _output: DynOutput, ) -> Result, DiagramError> { Err(DiagramError::NotUnzippable) } - fn on_register(_registry: &mut MessageRegistry) {} + fn on_register(&self, _registry: &mut MessageRegistry) {} } macro_rules! 
dyn_unzip_impl { ($len:literal, $(($P:ident, $o:ident)),*) => { - impl<$($P),*, Serializer> DynUnzip<($($P,)*), Serializer> for DefaultImpl + impl<$($P),*, Serializer> DynUnzip for DefaultImplMarker<(($($P,)*), Serializer)> where $($P: Send + Sync + 'static),*, Serializer: $(SerializeMessage<$P> +)* $(SerializeMessage> +)*, { - const UNZIP_SLOTS: usize = $len; + fn output_types(&self) -> Vec<&'static str> { + vec![$( + std::any::type_name::<$P>(), + )*] + } fn dyn_unzip( + &self, builder: &mut Builder, output: DynOutput ) -> Result, DiagramError> { @@ -66,12 +78,12 @@ macro_rules! dyn_unzip_impl { Ok(outputs) } - fn on_register(registry: &mut MessageRegistry) + fn on_register(&self, registry: &mut MessageRegistry) { // Register serialize functions for all items in the tuple. // For a tuple of (T1, T2, T3), registers serialize for T1, T2 and T3. $( - register_serialize_impl::<$P, Serializer>(registry); + registry.register_serialize::<$P, Serializer>(); )* // Register join impls for T1, T2, T3... 
@@ -102,7 +114,7 @@ mod tests { "ops": { "op1": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": "unzip" }, "unzip": { @@ -136,17 +148,17 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, "op3": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, "op4": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, }, @@ -203,7 +215,7 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, }, @@ -238,7 +250,7 @@ mod tests { }, "op2": { "type": "node", - "builder": "multiply3_uncloneable", + "builder": "multiply3", "next": { "builtin": "terminate" }, }, }, diff --git a/src/diagram/workflow_builder.rs b/src/diagram/workflow_builder.rs index 777853cd..c1156be9 100644 --- a/src/diagram/workflow_builder.rs +++ b/src/diagram/workflow_builder.rs @@ -9,8 +9,8 @@ use crate::{ use super::{ fork_clone::DynForkClone, impls::DefaultImpl, split_chain, transform::transform_output, - BuiltinTarget, Diagram, DiagramError, DiagramOperation, DiagramScope, DynInputSlot, DynOutput, - NextOperation, NodeOp, NodeRegistry, OperationId, SourceOperation, + BuiltinTarget, Diagram, DiagramElementRegistry, DiagramError, DiagramOperation, DiagramScope, + DynInputSlot, DynOutput, NextOperation, NodeOp, OperationId, SourceOperation, }; struct Vertex<'a> { @@ -40,7 +40,7 @@ enum EdgeState<'a> { pub(super) fn create_workflow<'a, Streams: StreamPack>( scope: DiagramScope, builder: &mut Builder, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, diagram: &'a Diagram, ) -> Result<(), DiagramError> { // first create all the vertices @@ -156,7 +156,7 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>( for (op_id, op) in &diagram.ops { 
match op { DiagramOperation::Node(node_op) => { - let reg = registry.get_registration(&node_op.builder)?; + let reg = registry.get_node_registration(&node_op.builder)?; let n = reg.create_node(builder, node_op.config.clone())?; inputs.insert(op_id, n.input); add_edge( @@ -250,8 +250,8 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>( for edge_id in terminate_edges { let edge = edges.remove(&edge_id).ok_or(unknown_diagram_error!())?; match edge.state { - EdgeState::Ready { output, origin } => { - let serialized_output = serialize(builder, registry, output, origin)?; + EdgeState::Ready { output, origin: _ } => { + let serialized_output = registry.messages.serialize(builder, output)?; builder.connect(serialized_output, scope.terminate); } EdgeState::Pending => return Err(DiagramError::BadInterconnectChain), @@ -263,7 +263,7 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>( fn connect_vertex<'a>( builder: &mut Builder, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, edges: &mut HashMap>, inputs: &HashMap<&OperationId, DynInputSlot>, target: &'a Vertex, @@ -298,10 +298,9 @@ fn connect_vertex<'a>( } let joined_output = if join_op.no_serialize.unwrap_or(false) { - let join_impl = ®istry.data.join_impls[&ordered_outputs[0].type_id]; - join_impl(builder, ordered_outputs)? + registry.messages.join(builder, ordered_outputs)? 
} else { - serialize_and_join(builder, ®istry.data, ordered_outputs)?.into() + serialize_and_join(builder, ®istry.messages, ordered_outputs)?.into() }; let out_edge = edges @@ -325,7 +324,7 @@ fn connect_vertex<'a>( fn connect_edge<'a>( builder: &mut Builder, - registry: &NodeRegistry, + registry: &DiagramElementRegistry, edges: &mut HashMap>, inputs: &HashMap<&OperationId, DynInputSlot>, edge_id: usize, @@ -354,7 +353,9 @@ fn connect_edge<'a>( DiagramOperation::Node(_) => { let input = inputs[target.op_id]; let deserialized_output = - deserialize(builder, registry, output, target, input.type_id)?; + registry + .messages + .deserialize(&input.type_id, builder, output)?; dyn_connect(builder, deserialized_output, input)?; } DiagramOperation::ForkClone(fork_clone_op) => { @@ -364,14 +365,7 @@ fn connect_edge<'a>( builder, output, amount, ) } else { - let origin = if let Some(origin_node) = origin { - origin_node - } else { - return Err(DiagramError::NotCloneable); - }; - - let reg = registry.get_registration(&origin.builder)?; - reg.fork_clone(builder, output, amount) + registry.messages.fork_clone(builder, output, amount) }?; for (o, e) in outputs.into_iter().zip(target.out_edges.iter()) { let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; @@ -382,14 +376,7 @@ fn connect_edge<'a>( let outputs = if output.type_id == TypeId::of::() { Err(DiagramError::NotUnzippable) } else { - let origin = if let Some(origin_node) = origin { - origin_node - } else { - return Err(DiagramError::NotUnzippable); - }; - - let reg = registry.get_registration(&origin.builder)?; - reg.unzip(builder, output) + registry.messages.unzip(builder, output) }?; if outputs.len() < unzip_op.next.len() { return Err(DiagramError::NotUnzippable); @@ -403,14 +390,7 @@ fn connect_edge<'a>( let (ok, err) = if output.type_id == TypeId::of::() { Err(DiagramError::CannotForkResult) } else { - let origin = if let Some(origin_node) = origin { - origin_node - } else { - return 
Err(DiagramError::CannotForkResult); - }; - - let reg = registry.get_registration(&origin.builder)?; - reg.fork_result(builder, output) + registry.messages.fork_result(builder, output) }?; { let out_edge = edges @@ -433,14 +413,7 @@ fn connect_edge<'a>( let chain = output.into_output::()?.chain(builder); split_chain(chain, split_op) } else { - let origin = if let Some(origin_node) = origin { - origin_node - } else { - return Err(DiagramError::NotSplittable); - }; - - let reg = registry.get_registration(&origin.builder)?; - reg.split(builder, output, split_op) + registry.messages.split(builder, output, split_op) }?; // Because of how we build `out_edges`, if the split op uses the `remaining` slot, @@ -513,56 +486,3 @@ fn dyn_connect( builder.connect(typed_output, typed_input); Ok(()) } - -/// Try to deserialize `output` into `input_type`. If `output` is not `serde_json::Value`, this does nothing. -fn deserialize( - builder: &mut Builder, - registry: &NodeRegistry, - output: DynOutput, - target: &Vertex, - input_type: TypeId, -) -> Result { - if output.type_id != TypeId::of::() || output.type_id == input_type { - Ok(output) - } else { - let serialized = output.into_output::()?; - match target.op { - DiagramOperation::Node(node_op) => { - let reg = registry.get_registration(&node_op.builder)?; - if reg.metadata.request.deserializable { - let deserialize_impl = ®istry.data.deserialize_impls[&input_type]; - deserialize_impl(builder, serialized) - } else { - Err(DiagramError::NotSerializable) - } - } - _ => Err(DiagramError::NotSerializable), - } - } -} - -fn serialize( - builder: &mut Builder, - registry: &NodeRegistry, - output: DynOutput, - origin: Option<&NodeOp>, -) -> Result, DiagramError> { - if output.type_id == TypeId::of::() { - output.into_output() - } else { - // Cannot serialize if we don't know the origin, as we need it to know which serialize impl to use. 
- let origin = if let Some(origin) = origin { - origin - } else { - return Err(DiagramError::NotSerializable); - }; - - let reg = registry.get_registration(&origin.builder)?; - if reg.metadata.response.serializable { - let serialize_impl = ®istry.data.serialize_impls[&output.type_id]; - serialize_impl(builder, output) - } else { - Err(DiagramError::NotSerializable) - } - } -} From e430c1421a34d9722c0f86ec1151ebe391d7bd83 Mon Sep 17 00:00:00 2001 From: Grey Date: Fri, 14 Feb 2025 15:45:58 +0800 Subject: [PATCH 11/20] Lock in specific features for uuid on wasm builds (#54) Signed-off-by: Michael X. Grey --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 8164d7b8..347cbdef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,9 @@ tracing = "0.1.41" strum = { version = "0.26.3", optional = true, features = ["derive"] } semver = { version = "1.0.24", optional = true } +[target.'cfg(target_arch = "wasm32")'.dependencies] +uuid = { version = "1.13.1", default-features = false, features = ["js"] } + [features] single_threaded_async = ["dep:async-task"] diagram = [ From 0ee3fe2c8fc73d65f400420fe4b48f249a9922bd Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Tue, 18 Feb 2025 15:56:03 +0800 Subject: [PATCH 12/20] avoid conflict in take() (#55) Signed-off-by: Teo Koon Peng --- src/builder.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/builder.rs b/src/builder.rs index 0661023a..4476d060 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1198,7 +1198,9 @@ mod tests { context.command(|commands| commands.request(input, workflow).take_response()); context.run_with_conditions(&mut promise, Duration::from_secs(2)); - assert!(promise.take().available().is_some_and(|v| v == expectation)); + assert!(Promise::take(&mut promise) + .available() + .is_some_and(|v| v == expectation)); assert!(context.no_unhandled_errors()); } } From b66535c8032cbd407b0756e967b950c7759801bf Mon Sep 17 00:00:00 2001 From: Grey 
Date: Fri, 21 Feb 2025 11:50:18 +0800 Subject: [PATCH 13/20] Support for dynamic dictionaries of buffers (#52) Signed-off-by: Michael X. Grey Signed-off-by: Teo Koon Peng Co-authored-by: Teo Koon Peng --- .github/workflows/ci_linux.yaml | 2 +- macros/Cargo.toml | 1 + macros/src/buffer.rs | 493 +++++++ macros/src/lib.rs | 32 +- src/buffer.rs | 300 +++- src/buffer/any_buffer.rs | 1327 ++++++++++++++++++ src/buffer/buffer_access_lifecycle.rs | 30 +- src/buffer/buffer_key_builder.rs | 8 +- src/buffer/buffer_map.rs | 894 ++++++++++++ src/buffer/buffer_storage.rs | 116 +- src/buffer/bufferable.rs | 210 ++- src/buffer/{buffered.rs => buffering.rs} | 560 +++++--- src/buffer/json_buffer.rs | 1598 ++++++++++++++++++++++ src/buffer/manage_buffer.rs | 9 + src/builder.rs | 150 +- src/chain.rs | 27 +- src/diagram.rs | 64 +- src/diagram/registration.rs | 2 +- src/gate.rs | 4 +- src/lib.rs | 15 +- src/operation/cleanup.rs | 4 +- src/operation/join.rs | 4 +- src/operation/listen.rs | 4 +- src/operation/operate_buffer_access.rs | 22 +- src/operation/operate_gate.rs | 6 +- src/operation/scope.rs | 19 +- src/re_exports.rs | 21 + src/testing.rs | 113 +- 28 files changed, 5448 insertions(+), 587 deletions(-) create mode 100644 macros/src/buffer.rs create mode 100644 src/buffer/any_buffer.rs create mode 100644 src/buffer/buffer_map.rs rename src/buffer/{buffered.rs => buffering.rs} (54%) create mode 100644 src/buffer/json_buffer.rs create mode 100644 src/re_exports.rs diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index 116de865..660d6933 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -50,5 +50,5 @@ jobs: run: cargo test --features single_threaded_async - name: Build docs - run: cargo doc + run: cargo doc --all-features diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 12fa2bb3..c8627251 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -16,3 +16,4 @@ proc-macro = true [dependencies] syn = "2.0" quote = 
"1.0" +proc-macro2 = "1.0.93" diff --git a/macros/src/buffer.rs b/macros/src/buffer.rs new file mode 100644 index 00000000..b231da09 --- /dev/null +++ b/macros/src/buffer.rs @@ -0,0 +1,493 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse_quote, Field, Generics, Ident, ImplGenerics, ItemStruct, Type, TypeGenerics, TypePath, + Visibility, WhereClause, +}; + +use crate::Result; + +const JOINED_ATTR_TAG: &'static str = "joined"; +const KEY_ATTR_TAG: &'static str = "key"; + +pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let StructConfig { + buffer_struct_name: buffer_struct_ident, + } = StructConfig::from_data_struct(&input_struct, &JOINED_ATTR_TAG); + let buffer_struct_vis = &input_struct.vis; + + let (field_ident, _, field_config) = + get_fields_map(&input_struct.fields, FieldSettings::for_joined())?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let noncopy = field_config.iter().any(|config| config.noncopy); + + let buffer_struct: ItemStruct = generate_buffer_struct( + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_clone = impl_buffer_clone( + &buffer_struct_ident, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + noncopy, + ); + + let impl_select_buffers = impl_select_buffers( + struct_ident, + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_map_layout = + impl_buffer_map_layout(&buffer_struct, &field_ident, &field_config)?; + let impl_joined = impl_joined(&buffer_struct, &input_struct, &field_ident)?; + + let gen = quote! 
{ + impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause { + type Buffers = #buffer_struct_ident #ty_generics; + } + + #buffer_struct + + #impl_buffer_clone + + #impl_select_buffers + + #impl_buffer_map_layout + + #impl_joined + }; + + Ok(gen.into()) +} + +pub(crate) fn impl_buffer_key_map(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let StructConfig { + buffer_struct_name: buffer_struct_ident, + } = StructConfig::from_data_struct(&input_struct, &KEY_ATTR_TAG); + let buffer_struct_vis = &input_struct.vis; + + let (field_ident, field_type, field_config) = + get_fields_map(&input_struct.fields, FieldSettings::for_key())?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let noncopy = field_config.iter().any(|config| config.noncopy); + + let buffer_struct: ItemStruct = generate_buffer_struct( + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_clone = impl_buffer_clone( + &buffer_struct_ident, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + noncopy, + ); + + let impl_select_buffers = impl_select_buffers( + struct_ident, + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_map_layout = + impl_buffer_map_layout(&buffer_struct, &field_ident, &field_config)?; + let impl_accessed = impl_accessed(&buffer_struct, &input_struct, &field_ident, &field_type)?; + + let gen = quote! 
{ + impl #impl_generics ::bevy_impulse::Accessor for #struct_ident #ty_generics #where_clause { + type Buffers = #buffer_struct_ident #ty_generics; + } + + #buffer_struct + + #impl_buffer_clone + + #impl_select_buffers + + #impl_buffer_map_layout + + #impl_accessed + }; + + Ok(gen.into()) +} + +/// Code that are currently unused but could be used in the future, move them out of this mod if +/// they are ever used. +#[allow(unused)] +mod _unused { + use super::*; + + /// Converts a list of generics to a [`PhantomData`] TypePath. + /// e.g. `::std::marker::PhantomData` + fn to_phantom_data(generics: &Generics) -> TypePath { + let lifetimes: Vec = generics + .lifetimes() + .map(|lt| { + let lt = <.lifetime; + let ty: Type = parse_quote! { & #lt () }; + ty + }) + .collect(); + let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect(); + parse_quote! { ::std::marker::PhantomData } + } +} + +struct StructConfig { + buffer_struct_name: Ident, +} + +impl StructConfig { + fn from_data_struct(data_struct: &ItemStruct, attr_tag: &str) -> Self { + let mut config = Self { + buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident), + }; + + let attr = data_struct + .attrs + .iter() + .find(|attr| attr.path().is_ident(attr_tag)); + + if let Some(attr) = attr { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffers_struct_name") { + config.buffer_struct_name = meta.value()?.parse()?; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. 
+ .unwrap(); + } + + config + } +} + +struct FieldSettings { + default_buffer: fn(&Type) -> Type, + attr_tag: &'static str, +} + +impl FieldSettings { + fn for_joined() -> Self { + Self { + default_buffer: Self::default_field_for_joined, + attr_tag: JOINED_ATTR_TAG, + } + } + + fn for_key() -> Self { + Self { + default_buffer: Self::default_field_for_key, + attr_tag: KEY_ATTR_TAG, + } + } + + fn default_field_for_joined(ty: &Type) -> Type { + parse_quote! { ::bevy_impulse::Buffer<#ty> } + } + + fn default_field_for_key(ty: &Type) -> Type { + parse_quote! { <#ty as ::bevy_impulse::BufferKeyLifecycle>::TargetBuffer } + } +} + +struct FieldConfig { + buffer: Type, + noncopy: bool, +} + +impl FieldConfig { + fn from_field(field: &Field, settings: &FieldSettings) -> Self { + let ty = &field.ty; + let mut config = Self { + buffer: (settings.default_buffer)(ty), + noncopy: false, + }; + + for attr in field + .attrs + .iter() + .filter(|attr| attr.path().is_ident(settings.attr_tag)) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffer") { + config.buffer = meta.value()?.parse()?; + } + if meta.path.is_ident("noncopy_buffer") { + config.noncopy = true; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. 
+ .unwrap(); + } + + config + } +} + +fn get_fields_map( + fields: &syn::Fields, + settings: FieldSettings, +) -> Result<(Vec<&Ident>, Vec<&Type>, Vec)> { + match fields { + syn::Fields::Named(data) => { + let mut idents = Vec::new(); + let mut types = Vec::new(); + let mut configs = Vec::new(); + for field in &data.named { + let ident = field + .ident + .as_ref() + .ok_or("expected named fields".to_string())?; + idents.push(ident); + types.push(&field.ty); + configs.push(FieldConfig::from_field(field, &settings)); + } + Ok((idents, types, configs)) + } + _ => return Err("expected named fields".to_string()), + } +} + +fn generate_buffer_struct( + buffer_struct_ident: &Ident, + buffer_struct_vis: &Visibility, + impl_generics: &ImplGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + buffer: &Vec<&Type>, +) -> ItemStruct { + parse_quote! { + #[allow(non_camel_case_types, unused)] + #buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause { + #( + #buffer_struct_vis #field_ident: #buffer, + )* + } + } +} + +fn impl_select_buffers( + struct_ident: &Ident, + buffer_struct_ident: &Ident, + buffer_struct_vis: &Visibility, + impl_generics: &ImplGenerics, + ty_generics: &TypeGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + buffer: &Vec<&Type>, +) -> TokenStream { + quote! { + impl #impl_generics #struct_ident #ty_generics #where_clause { + #buffer_struct_vis fn select_buffers( + #( + #field_ident: #buffer, + )* + ) -> #buffer_struct_ident #ty_generics { + #buffer_struct_ident { + #( + #field_ident, + )* + } + } + } + } + .into() +} + +fn impl_buffer_clone( + buffer_struct_ident: &Ident, + impl_generics: &ImplGenerics, + ty_generics: &TypeGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + noncopy: bool, +) -> TokenStream { + if noncopy { + // Clone impl for structs with a buffer that is not copyable + quote! 
{ + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + Self { + #( + #field_ident: self.#field_ident.clone(), + )* + } + } + } + } + } else { + // Clone and copy impl for structs with buffers that are all copyable + quote! { + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + *self + } + } + + impl #impl_generics ::std::marker::Copy for #buffer_struct_ident #ty_generics #where_clause {} + } + } +} + +/// Params: +/// buffer_struct: The struct to implement `BufferMapLayout`. +/// item_struct: The struct which `buffer_struct` is derived from. +/// settings: [`FieldSettings`] to use when parsing the field attributes +fn impl_buffer_map_layout( + buffer_struct: &ItemStruct, + field_ident: &Vec<&Ident>, + field_config: &Vec, +) -> Result { + let struct_ident = &buffer_struct.ident; + let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl(); + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let map_key: Vec = field_ident.iter().map(|v| v.to_string()).collect(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::BufferMapLayout for #struct_ident #ty_generics #where_clause { + fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result { + let mut compatibility = ::bevy_impulse::IncompatibleLayout::default(); + #( + let #field_ident = compatibility.require_buffer_for_identifier::<#buffer>(#map_key, buffers); + )* + + // Unwrap the Ok after inspecting every field so that the + // IncompatibleLayout error can include all information about + // which fields were incompatible. 
+ #( + let Ok(#field_ident) = #field_ident else { + return Err(compatibility); + }; + )* + + Ok(Self { + #( + #field_ident, + )* + }) + } + } + + impl #impl_generics ::bevy_impulse::BufferMapStruct for #struct_ident #ty_generics #where_clause { + fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> { + use smallvec::smallvec; + smallvec![#( + ::bevy_impulse::AsAnyBuffer::as_any_buffer(&self.#field_ident), + )*] + } + } + } + .into()) +} + +/// Params: +/// joined_struct: The struct to implement `Joining`. +/// item_struct: The associated `Item` type to use for the `Joining` implementation. +fn impl_joined( + joined_struct: &ItemStruct, + item_struct: &ItemStruct, + field_ident: &Vec<&Ident>, +) -> Result { + let struct_ident = &joined_struct.ident; + let item_struct_ident = &item_struct.ident; + let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::Joining for #struct_ident #ty_generics #where_clause { + type Item = #item_struct_ident #ty_generics; + + fn pull(&self, session: ::bevy_impulse::re_exports::Entity, world: &mut ::bevy_impulse::re_exports::World) -> Result { + #( + let #field_ident = self.#field_ident.pull(session, world)?; + )* + + Ok(Self::Item {#( + #field_ident, + )*}) + } + } + }.into()) +} + +fn impl_accessed( + accessed_struct: &ItemStruct, + key_struct: &ItemStruct, + field_ident: &Vec<&Ident>, + field_type: &Vec<&Type>, +) -> Result { + let struct_ident = &accessed_struct.ident; + let key_struct_ident = &key_struct.ident; + let (impl_generics, ty_generics, where_clause) = key_struct.generics.split_for_impl(); + + Ok(quote! 
{ + impl #impl_generics ::bevy_impulse::Accessing for #struct_ident #ty_generics #where_clause { + type Key = #key_struct_ident #ty_generics; + + fn add_accessor( + &self, + accessor: ::bevy_impulse::re_exports::Entity, + world: &mut ::bevy_impulse::re_exports::World, + ) -> ::bevy_impulse::OperationResult { + #( + ::bevy_impulse::Accessing::add_accessor(&self.#field_ident, accessor, world)?; + )* + Ok(()) + } + + fn create_key(&self, builder: &::bevy_impulse::BufferKeyBuilder) -> Self::Key { + Self::Key {#( + // TODO(@mxgrey): This currently does not have good support for the user + // substituting in a different key type than what the BufferKeyLifecycle expects. + // We could consider adding a .clone().into() to help support that use case, but + // this would be such a niche use case that I think we can ignore it for now. + #field_ident: <#field_type as ::bevy_impulse::BufferKeyLifecycle>::create_key(&self.#field_ident, builder), + )*} + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + Self::Key {#( + #field_ident: ::bevy_impulse::BufferKeyLifecycle::deep_clone(&key.#field_ident), + )*} + } + + fn is_key_in_use(key: &Self::Key) -> bool { + false + #( + || ::bevy_impulse::BufferKeyLifecycle::is_in_use(&key.#field_ident) + )* + } + } + }.into()) +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index d40c9309..58873049 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -15,9 +15,12 @@ * */ +mod buffer; +use buffer::{impl_buffer_key_map, impl_joined_value}; + use proc_macro::TokenStream; use quote::quote; -use syn::DeriveInput; +use syn::{parse_macro_input, DeriveInput, ItemStruct}; #[proc_macro_derive(Stream)] pub fn simple_stream_macro(item: TokenStream) -> TokenStream { @@ -58,3 +61,30 @@ pub fn delivery_label_macro(item: TokenStream) -> TokenStream { } .into() } + +/// The result error is the compiler error message to be displayed. 
+type Result = std::result::Result; + +#[proc_macro_derive(Joined, attributes(joined))] +pub fn derive_joined_value(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_joined_value(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} + +#[proc_macro_derive(Accessor, attributes(key))] +pub fn derive_buffer_key_map(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_buffer_key_map(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} diff --git a/src/buffer.rs b/src/buffer.rs index bb0fe2b8..1d33c28d 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -17,28 +17,37 @@ use bevy_ecs::{ change_detection::Mut, - prelude::{Commands, Entity, Query}, + prelude::{Commands, Entity, Query, World}, query::QueryEntityError, - system::SystemParam, + system::{SystemParam, SystemState}, }; use std::{ops::RangeBounds, sync::Arc}; +use thiserror::Error as ThisError; + use crate::{ Builder, Chain, Gate, GateState, InputSlot, NotifyBufferUpdate, OnNewBufferValue, UnusedTarget, }; +mod any_buffer; +pub use any_buffer::*; + mod buffer_access_lifecycle; +pub use buffer_access_lifecycle::BufferKeyLifecycle; pub(crate) use buffer_access_lifecycle::*; mod buffer_key_builder; -pub(crate) use buffer_key_builder::*; +pub use buffer_key_builder::*; + +mod buffer_map; +pub use buffer_map::*; mod buffer_storage; pub(crate) use buffer_storage::*; -mod buffered; -pub use buffered::*; +mod buffering; +pub use buffering::*; mod bufferable; pub use bufferable::*; @@ -46,12 +55,16 @@ pub use bufferable::*; mod manage_buffer; pub use manage_buffer::*; +#[cfg(feature = "diagram")] +mod json_buffer; +#[cfg(feature = "diagram")] +pub use json_buffer::*; + /// A buffer is a special type of node within a workflow that is able to store /// and release data. 
When a session is finished, the buffered data from the /// session will be automatically cleared. pub struct Buffer { - pub(crate) scope: Entity, - pub(crate) source: Entity, + pub(crate) location: BufferLocation, pub(crate) _ignore: std::marker::PhantomData, } @@ -61,11 +74,11 @@ impl Buffer { &self, builder: &'b mut Builder<'w, 's, 'a>, ) -> Chain<'w, 's, 'a, 'b, ()> { - assert_eq!(self.scope, builder.scope); + assert_eq!(self.scope(), builder.scope); let target = builder.commands.spawn(UnusedTarget).id(); builder .commands - .add(OnNewBufferValue::new(self.source, target)); + .add(OnNewBufferValue::new(self.id(), target)); Chain::new(target, builder) } @@ -77,24 +90,86 @@ impl Buffer { T: Clone, { CloneFromBuffer { - scope: self.scope, - source: self.source, + location: self.location, _ignore: Default::default(), } } /// Get an input slot for this buffer. pub fn input_slot(self) -> InputSlot { - InputSlot::new(self.scope, self.source) + InputSlot::new(self.scope(), self.id()) + } + + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl Clone for Buffer { + fn clone(&self) -> Self { + *self } } +impl Copy for Buffer {} + +/// The general identifying information for a buffer to locate it within the +/// world. This does not indicate anything about the type of messages that the +/// buffer can contain. +#[derive(Clone, Copy, Debug)] +pub struct BufferLocation { + /// The entity ID of the buffer. + pub scope: Entity, + /// The ID of the workflow that the buffer is associated with. 
+ pub source: Entity, +} + +#[derive(Clone)] pub struct CloneFromBuffer { - pub(crate) scope: Entity, - pub(crate) source: Entity, + pub(crate) location: BufferLocation, pub(crate) _ignore: std::marker::PhantomData, } +// +impl Copy for CloneFromBuffer {} + +impl CloneFromBuffer { + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl From> for Buffer { + fn from(value: CloneFromBuffer) -> Self { + Buffer { + location: value.location, + _ignore: Default::default(), + } + } +} + /// Settings to describe the behavior of a buffer. #[derive(Default, Clone, Copy)] pub struct BufferSettings { @@ -157,44 +232,22 @@ impl Default for RetentionPolicy { } } -impl Clone for Buffer { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for Buffer {} - -impl Clone for CloneFromBuffer { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for CloneFromBuffer {} - /// This key can unlock access to the contents of a buffer by passing it into /// [`BufferAccess`] or [`BufferAccessMut`]. /// /// To obtain a `BufferKey`, use [`Chain::with_access`][1], or [`listen`][2]. /// /// [1]: crate::Chain::with_access -/// [2]: crate::Bufferable::listen +/// [2]: crate::Accessible::listen pub struct BufferKey { - buffer: Entity, - session: Entity, - accessor: Entity, - lifecycle: Option>, + tag: BufferKeyTag, _ignore: std::marker::PhantomData, } impl Clone for BufferKey { fn clone(&self) -> Self { Self { - buffer: self.buffer, - session: self.session, - accessor: self.accessor, - lifecycle: self.lifecycle.as_ref().map(Arc::clone), + tag: self.tag.clone(), _ignore: Default::default(), } } @@ -202,28 +255,67 @@ impl Clone for BufferKey { impl BufferKey { /// The buffer ID of this key. 
- pub fn id(&self) -> Entity { - self.buffer + pub fn buffer(&self) -> Entity { + self.tag.buffer } /// The session that this key belongs to. pub fn session(&self) -> Entity { - self.session + self.tag.session + } + + pub fn tag(&self) -> &BufferKeyTag { + &self.tag + } +} + +impl BufferKeyLifecycle for BufferKey { + type TargetBuffer = Buffer; + + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self { + BufferKey { + tag: builder.make_tag(buffer.id()), + _ignore: Default::default(), + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() } - pub(crate) fn is_in_use(&self) -> bool { + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + _ignore: Default::default(), + } + } +} + +impl std::fmt::Debug for BufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferKey") + .field("message_type_name", &std::any::type_name::()) + .field("tag", &self.tag) + .finish() + } +} + +/// The identifying information for a buffer key. This does not indicate +/// anything about the type of messages that the buffer can contain. +#[derive(Clone)] +pub struct BufferKeyTag { + pub buffer: Entity, + pub session: Entity, + pub accessor: Entity, + pub lifecycle: Option>, +} + +impl BufferKeyTag { + pub fn is_in_use(&self) -> bool { self.lifecycle.as_ref().is_some_and(|l| l.is_in_use()) } - // We do a deep clone of the key when distributing it to decouple the - // lifecycle of the keys that we send out from the key that's held by the - // accessor node. - // - // The key instance held by the accessor node will never be dropped until - // the session is cleaned up, so the keys that we send out into the workflow - // need to have their own independent lifecycles or else we won't detect - // when the workflow has dropped them. 
- pub(crate) fn deep_clone(&self) -> Self { + pub fn deep_clone(&self) -> Self { let mut deep = self.clone(); deep.lifecycle = self .lifecycle @@ -233,6 +325,17 @@ impl BufferKey { } } +impl std::fmt::Debug for BufferKeyTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferKeyTag") + .field("buffer", &self.buffer) + .field("session", &self.session) + .field("accessor", &self.accessor) + .field("in_use", &self.is_in_use()) + .finish() + } +} + /// This system parameter lets you get read-only access to a buffer that exists /// within a workflow. Use a [`BufferKey`] to unlock the access. /// @@ -247,9 +350,9 @@ where impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> { pub fn get<'a>(&'a self, key: &BufferKey) -> Result, QueryEntityError> { - let session = key.session; + let session = key.session(); self.query - .get(key.buffer) + .get(key.buffer()) .map(|(storage, gate)| BufferView { storage, gate, @@ -276,9 +379,9 @@ where T: 'static + Send + Sync, { pub fn get<'a>(&'a self, key: &BufferKey) -> Result, QueryEntityError> { - let session = key.session; + let session = key.session(); self.query - .get(key.buffer) + .get(key.buffer()) .map(|(storage, gate)| BufferView { storage, gate, @@ -290,15 +393,76 @@ where &'a mut self, key: &BufferKey, ) -> Result, QueryEntityError> { - let buffer = key.buffer; - let session = key.session; - let accessor = key.accessor; - self.query.get_mut(key.buffer).map(|(storage, gate)| { + let buffer = key.buffer(); + let session = key.session(); + let accessor = key.tag.accessor; + self.query.get_mut(key.buffer()).map(|(storage, gate)| { BufferMut::new(storage, gate, buffer, session, accessor, &mut self.commands) }) } } +/// This trait allows [`World`] to give you access to any buffer using a [`BufferKey`] +pub trait BufferWorldAccess { + /// Call this to get read-only access to a buffer from a [`World`]. 
+ /// + /// Alternatively you can use [`BufferAccess`] as a regular bevy system parameter, + /// which does not need direct world access. + fn buffer_view(&self, key: &BufferKey) -> Result, BufferError> + where + T: 'static + Send + Sync; + + /// Call this to get mutable access to a buffer. + /// + /// Pass in a callback that will receive [`BufferMut`], allowing it to view + /// and modify the contents of the buffer. + fn buffer_mut( + &mut self, + key: &BufferKey, + f: impl FnOnce(BufferMut) -> U, + ) -> Result + where + T: 'static + Send + Sync; +} + +impl BufferWorldAccess for World { + fn buffer_view(&self, key: &BufferKey) -> Result, BufferError> + where + T: 'static + Send + Sync, + { + let buffer_ref = self + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(BufferView { + storage, + gate, + session: key.tag.session, + }) + } + + fn buffer_mut( + &mut self, + key: &BufferKey, + f: impl FnOnce(BufferMut) -> U, + ) -> Result + where + T: 'static + Send + Sync, + { + let mut state = SystemState::>::new(self); + let mut buffer_access_mut = state.get_mut(self); + let buffer_mut = buffer_access_mut + .get_mut(key) + .map_err(|_| BufferError::BufferMissing)?; + Ok(f(buffer_mut)) + } +} + /// Access to view a buffer that exists inside a workflow. pub struct BufferView<'a, T> where @@ -424,7 +588,7 @@ where self.len() == 0 } - /// Check whether the gate of this buffer is open or closed + /// Check whether the gate of this buffer is open or closed. pub fn gate(&self) -> Gate { self.gate .map @@ -467,7 +631,7 @@ where self.storage.drain(self.session, range) } - /// Pull the oldest item from the buffer + /// Pull the oldest item from the buffer. 
pub fn pull(&mut self) -> Option { self.modified = true; self.storage.pull(self.session) @@ -500,7 +664,7 @@ where // continuous systems with BufferAccessMut from running at the same time no // matter what the buffer type is. - /// Tell the buffer [`Gate`] to open + /// Tell the buffer [`Gate`] to open. pub fn open_gate(&mut self) { if let Some(gate) = self.gate.map.get_mut(&self.session) { if *gate != Gate::Open { @@ -510,7 +674,7 @@ where } } - /// Tell the buffer [`Gate`] to close + /// Tell the buffer [`Gate`] to close. pub fn close_gate(&mut self) { if let Some(gate) = self.gate.map.get_mut(&self.session) { *gate = Gate::Closed; @@ -519,7 +683,7 @@ where } } - /// Perform an action on the gate of the buffer + /// Perform an action on the gate of the buffer. pub fn gate_action(&mut self, action: Gate) { match action { Gate::Open => self.open_gate(), @@ -569,6 +733,12 @@ where } } +#[derive(ThisError, Debug, Clone)] +pub enum BufferError { + #[error("The key was unable to identify a buffer")] + BufferMissing, +} + #[cfg(test)] mod tests { use crate::{prelude::*, testing::*, Gate}; diff --git a/src/buffer/any_buffer.rs b/src/buffer/any_buffer.rs new file mode 100644 index 00000000..efde9907 --- /dev/null +++ b/src/buffer/any_buffer.rs @@ -0,0 +1,1327 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +// TODO(@mxgrey): Add module-level documentation describing how to use AnyBuffer + +use std::{ + any::{Any, TypeId}, + collections::{hash_map::Entry, HashMap}, + ops::RangeBounds, + sync::{Mutex, OnceLock}, +}; + +use bevy_ecs::{ + prelude::{Commands, Entity, EntityRef, EntityWorldMut, Mut, World}, + system::SystemState, +}; + +use thiserror::Error as ThisError; + +use smallvec::SmallVec; + +use crate::{ + add_listener_to_source, Accessing, Buffer, BufferAccessMut, BufferAccessors, BufferError, + BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferStorage, + Bufferable, Buffering, Builder, DrainBuffer, Gate, GateState, InspectBuffer, Joining, + ManageBuffer, NotifyBufferUpdate, OperationError, OperationResult, OperationRoster, OrBroken, +}; + +/// A [`Buffer`] whose message type has been anonymized. Joining with this buffer +/// type will yield an [`AnyMessageBox`]. +#[derive(Clone, Copy)] +pub struct AnyBuffer { + pub(crate) location: BufferLocation, + pub(crate) interface: &'static (dyn AnyBufferAccessInterface + Send + Sync), +} + +impl AnyBuffer { + /// The buffer ID for this key. + pub fn id(&self) -> Entity { + self.location.source + } + + /// ID of the workflow that this buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get the type ID of the messages that this buffer supports. + pub fn message_type_id(&self) -> TypeId { + self.interface.message_type_id() + } + + pub fn message_type_name(&self) -> &'static str { + self.interface.message_type_name() + } + + /// Get the [`AnyBufferAccessInterface`] for this specific instance of [`AnyBuffer`]. + pub fn get_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + self.interface + } + + /// Get the [`AnyBufferAccessInterface`] for a concrete message type. 
+ pub fn interface_for( + ) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + static INTERFACE_MAP: OnceLock< + Mutex>, + > = OnceLock::new(); + let interfaces = INTERFACE_MAP.get_or_init(|| Mutex::default()); + + // SAFETY: This will leak memory exactly once per type, so the leakage is bounded. + // Leaking this allows the interface to be shared freely across all instances. + let mut interfaces_mut = interfaces.lock().unwrap(); + *interfaces_mut + .entry(TypeId::of::()) + .or_insert_with(|| Box::leak(Box::new(AnyBufferAccessImpl::::new()))) + } +} + +impl std::fmt::Debug for AnyBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnyBuffer") + .field("scope", &self.location.scope) + .field("source", &self.location.source) + .field("message_type_name", &self.interface.message_type_name()) + .finish() + } +} + +impl AnyBuffer { + /// Downcast this into a concrete [`Buffer`] for the specified message type. + /// + /// To downcast into a specialized kind of buffer, use [`Self::downcast_buffer`] instead. + pub fn downcast_for_message(&self) -> Option> { + if TypeId::of::() == self.interface.message_type_id() { + Some(Buffer { + location: self.location, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different special buffer representation, such as a + /// `JsonBuffer`. + pub fn downcast_buffer(&self) -> Option { + self.interface.buffer_downcast(TypeId::of::())?(self.location) + .downcast::() + .ok() + .map(|x| *x) + } +} + +impl From> for AnyBuffer { + fn from(value: Buffer) -> Self { + let interface = AnyBuffer::interface_for::(); + AnyBuffer { + location: value.location, + interface, + } + } +} + +/// A trait for turning a buffer into an [`AnyBuffer`]. It is expected that all +/// buffer types implement this trait. +pub trait AsAnyBuffer { + /// Convert this buffer into an [`AnyBuffer`]. 
+ fn as_any_buffer(&self) -> AnyBuffer; +} + +impl AsAnyBuffer for AnyBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + *self + } +} + +impl AsAnyBuffer for Buffer { + fn as_any_buffer(&self) -> AnyBuffer { + (*self).into() + } +} + +/// Similar to a [`BufferKey`] except it can be used for any buffer without +/// knowing the buffer's message type at compile time. +/// +/// This can key be used with a [`World`][1] to directly view or manipulate the +/// contents of a buffer through the [`AnyBufferWorldAccess`] interface. +/// +/// [1]: bevy_ecs::prelude::World +#[derive(Clone)] +pub struct AnyBufferKey { + pub(crate) tag: BufferKeyTag, + pub(crate) interface: &'static (dyn AnyBufferAccessInterface + Send + Sync), +} + +impl AnyBufferKey { + /// Downcast this into a concrete [`BufferKey`] for the specified message type. + /// + /// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead. + pub fn downcast_for_message(self) -> Option> { + if TypeId::of::() == self.interface.message_type_id() { + Some(BufferKey { + tag: self.tag, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different special buffer key representation, such + /// as a `JsonBufferKey`. + pub fn downcast_buffer_key(self) -> Option { + self.interface.key_downcast(TypeId::of::())?(self.tag) + .downcast::() + .ok() + .map(|x| *x) + } + + /// The buffer ID of this key. + pub fn id(&self) -> Entity { + self.tag.buffer + } + + /// The session that this key belongs to. 
+ pub fn session(&self) -> Entity { + self.tag.session + } +} + +impl BufferKeyLifecycle for AnyBufferKey { + type TargetBuffer = AnyBuffer; + + fn create_key(buffer: &AnyBuffer, builder: &BufferKeyBuilder) -> Self { + AnyBufferKey { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() + } + + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + interface: self.interface, + } + } +} + +impl std::fmt::Debug for AnyBufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnyBufferKey") + .field("message_type_name", &self.interface.message_type_name()) + .field("tag", &self.tag) + .finish() + } +} + +impl From> for AnyBufferKey { + fn from(value: BufferKey) -> Self { + let interface = AnyBuffer::interface_for::(); + AnyBufferKey { + tag: value.tag, + interface, + } + } +} + +/// Similar to [`BufferView`][crate::BufferView], but this can be unlocked with +/// an [`AnyBufferKey`], so it can work for any buffer whose message types +/// support serialization and deserialization. +pub struct AnyBufferView<'a> { + storage: Box, + gate: &'a GateState, + session: Entity, +} + +impl<'a> AnyBufferView<'a> { + /// Look at the oldest message in the buffer. + pub fn oldest(&self) -> Option> { + self.storage.any_oldest(self.session) + } + + /// Look at the newest message in the buffer. + pub fn newest(&self) -> Option> { + self.storage.any_newest(self.session) + } + + /// Borrow a message from the buffer. Index 0 is the oldest message in the buffer + /// while the highest index is the newest message in the buffer. + pub fn get(&self, index: usize) -> Option> { + self.storage.any_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.any_count(self.session) + } + + /// Check if the buffer is empty. 
+ pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } +} + +/// Similar to [`BufferMut`][crate::BufferMut], but this can be unlocked with an +/// [`AnyBufferKey`], so it can work for any buffer regardless of the data type +/// inside. +pub struct AnyBufferMut<'w, 's, 'a> { + storage: Box, + gate: Mut<'a, GateState>, + buffer: Entity, + session: Entity, + accessor: Option, + commands: &'a mut Commands<'w, 's>, + modified: bool, +} + +impl<'w, 's, 'a> AnyBufferMut<'w, 's, 'a> { + /// Same as [BufferMut::allow_closed_loops][1]. + /// + /// [1]: crate::BufferMut::allow_closed_loops + pub fn allow_closed_loops(mut self) -> Self { + self.accessor = None; + self + } + + /// Look at the oldest message in the buffer. + pub fn oldest(&self) -> Option> { + self.storage.any_oldest(self.session) + } + + /// Look at the newest message in the buffer. + pub fn newest(&self) -> Option> { + self.storage.any_newest(self.session) + } + + /// Borrow a message from the buffer. Index 0 is the oldest message in the buffer + /// while the highest index is the newest message in the buffer. + pub fn get(&self, index: usize) -> Option> { + self.storage.any_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.any_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } + + /// Modify the oldest message in the buffer. + pub fn oldest_mut(&mut self) -> Option> { + self.modified = true; + self.storage.any_oldest_mut(self.session) + } + + /// Modify the newest message in the buffer. 
+ pub fn newest_mut(&mut self) -> Option> { + self.modified = true; + self.storage.any_newest_mut(self.session) + } + + /// Modify a message in the buffer. Index 0 is the oldest message in the buffer + /// with the highest index being the newest message in the buffer. + pub fn get_mut(&mut self, index: usize) -> Option> { + self.modified = true; + self.storage.any_get_mut(self.session, index) + } + + /// Drain a range of messages out of the buffer. + pub fn drain>(&mut self, range: R) -> DrainAnyBuffer<'_> { + self.modified = true; + DrainAnyBuffer { + interface: self.storage.any_drain(self.session, AnyRange::new(range)), + } + } + + /// Pull the oldest message from the buffer. + pub fn pull(&mut self) -> Option { + self.modified = true; + self.storage.any_pull(self.session) + } + + /// Pull the message that was most recently put into the buffer (instead of the + /// oldest, which is what [`Self::pull`] gives). + pub fn pull_newest(&mut self) -> Option { + self.modified = true; + self.storage.any_pull_newest(self.session) + } + + /// Attempt to push a new value into the buffer. + /// + /// If the input value matches the message type of the buffer, this will + /// return [`Ok`]. If the buffer is at its limit before a successful push, this + /// will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`] and give back the message that you tried to push. + pub fn push(&mut self, value: T) -> Result, T> { + if TypeId::of::() != self.storage.any_message_type() { + return Err(value); + } + + self.modified = true; + + // SAFETY: We checked that T matches the message type for this buffer, + // so pushing and downcasting should not exhibit any errors. + let removed = self + .storage + .any_push(self.session, Box::new(value)) + .unwrap() + .map(|value| *value.downcast::().unwrap()); + + Ok(removed) + } + + /// Attempt to push a new value of any message type into the buffer. 
+ /// + /// If the input value matches the message type of the buffer, this will + /// return [`Ok`]. If the buffer is at its limit before a successful push, this + /// will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`] and give back an error with the message that you + /// tried to push and the type information for the expected message type. + pub fn push_any( + &mut self, + value: AnyMessageBox, + ) -> Result, AnyMessageError> { + self.storage.any_push(self.session, value) + } + + /// Attempt to push a value into the buffer as if it is the oldest value of + /// the buffer. + /// + /// The result follows the same rules as [`Self::push`]. + pub fn push_as_oldest( + &mut self, + value: T, + ) -> Result, T> { + if TypeId::of::() != self.storage.any_message_type() { + return Err(value); + } + + self.modified = true; + + // SAFETY: We checked that T matches the message type for this buffer, + // so pushing and downcasting should not exhibit any errors. + let removed = self + .storage + .any_push_as_oldest(self.session, Box::new(value)) + .unwrap() + .map(|value| *value.downcast::().unwrap()); + + Ok(removed) + } + + /// Attempt to push a value into the buffer as if it is the oldest value of + /// the buffer. + /// + /// The result follows the same rules as [`Self::push_any`]. + pub fn push_any_as_oldest( + &mut self, + value: AnyMessageBox, + ) -> Result, AnyMessageError> { + self.storage.any_push_as_oldest(self.session, value) + } + + /// Tell the buffer [`Gate`] to open. + pub fn open_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + if *gate != Gate::Open { + *gate = Gate::Open; + self.modified = true; + } + } + } + + /// Tell the buffer [`Gate`] to close. 
+ pub fn close_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + *gate = Gate::Closed; + // There is no need to to indicate that a modification happened + // because listeners do not get notified about gates closing. + } + } + + /// Perform an action on the gate of the buffer. + pub fn gate_action(&mut self, action: Gate) { + match action { + Gate::Open => self.open_gate(), + Gate::Closed => self.close_gate(), + } + } + + /// Trigger the listeners for this buffer to wake up even if nothing in the + /// buffer has changed. This could be used for timers or timeout elements + /// in a workflow. + pub fn pulse(&mut self) { + self.modified = true; + } +} + +impl<'w, 's, 'a> Drop for AnyBufferMut<'w, 's, 'a> { + fn drop(&mut self) { + if self.modified { + self.commands.add(NotifyBufferUpdate::new( + self.buffer, + self.session, + self.accessor, + )); + } + } +} + +/// This trait allows [`World`] to give you access to any buffer using an +/// [`AnyBufferKey`]. +pub trait AnyBufferWorldAccess { + /// Call this to get read-only access to any buffer. + /// + /// For technical reasons this requires direct [`World`] access, but you can + /// do other read-only queries on the world while holding onto the + /// [`AnyBufferView`]. + fn any_buffer_view(&self, key: &AnyBufferKey) -> Result, BufferError>; + + /// Call this to get mutable access to any buffer. + /// + /// Pass in a callback that will receive a [`AnyBufferMut`], allowing it to + /// view and modify the contents of the buffer. 
+ fn any_buffer_mut( + &mut self, + key: &AnyBufferKey, + f: impl FnOnce(AnyBufferMut) -> U, + ) -> Result; +} + +impl AnyBufferWorldAccess for World { + fn any_buffer_view(&self, key: &AnyBufferKey) -> Result, BufferError> { + key.interface.create_any_buffer_view(key, self) + } + + fn any_buffer_mut( + &mut self, + key: &AnyBufferKey, + f: impl FnOnce(AnyBufferMut) -> U, + ) -> Result { + let interface = key.interface; + let mut state = interface.create_any_buffer_access_mut_state(self); + let mut access = state.get_any_buffer_access_mut(self); + let buffer_mut = access.as_any_buffer_mut(key)?; + Ok(f(buffer_mut)) + } +} + +trait AnyBufferViewing { + fn any_count(&self, session: Entity) -> usize; + fn any_oldest<'a>(&'a self, session: Entity) -> Option>; + fn any_newest<'a>(&'a self, session: Entity) -> Option>; + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option>; + fn any_message_type(&self) -> TypeId; +} + +trait AnyBufferManagement: AnyBufferViewing { + fn any_push(&mut self, session: Entity, value: AnyMessageBox) -> AnyMessagePushResult; + fn any_push_as_oldest(&mut self, session: Entity, value: AnyMessageBox) + -> AnyMessagePushResult; + fn any_pull(&mut self, session: Entity) -> Option; + fn any_pull_newest(&mut self, session: Entity) -> Option; + fn any_oldest_mut<'a>(&'a mut self, session: Entity) -> Option>; + fn any_newest_mut<'a>(&'a mut self, session: Entity) -> Option>; + fn any_get_mut<'a>(&'a mut self, session: Entity, index: usize) -> Option>; + fn any_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box; +} + +pub(crate) struct AnyRange { + start_bound: std::ops::Bound, + end_bound: std::ops::Bound, +} + +impl AnyRange { + pub(crate) fn new>(range: T) -> Self { + AnyRange { + start_bound: deref_bound(range.start_bound()), + end_bound: deref_bound(range.end_bound()), + } + } +} + +fn deref_bound(bound: std::ops::Bound<&usize>) -> std::ops::Bound { + match bound { + std::ops::Bound::Included(v) => 
std::ops::Bound::Included(*v), + std::ops::Bound::Excluded(v) => std::ops::Bound::Excluded(*v), + std::ops::Bound::Unbounded => std::ops::Bound::Unbounded, + } +} + +impl std::ops::RangeBounds for AnyRange { + fn start_bound(&self) -> std::ops::Bound<&usize> { + self.start_bound.as_ref() + } + + fn end_bound(&self) -> std::ops::Bound<&usize> { + self.end_bound.as_ref() + } + + fn contains(&self, item: &U) -> bool + where + usize: PartialOrd, + U: ?Sized + PartialOrd, + { + match self.start_bound { + std::ops::Bound::Excluded(lower) => { + if *item <= lower { + return false; + } + } + std::ops::Bound::Included(lower) => { + if *item < lower { + return false; + } + } + _ => {} + } + + match self.end_bound { + std::ops::Bound::Excluded(upper) => { + if upper <= *item { + return false; + } + } + std::ops::Bound::Included(upper) => { + if upper < *item { + return false; + } + } + _ => {} + } + + return true; + } +} + +pub type AnyMessageRef<'a> = &'a (dyn Any + 'static + Send + Sync); + +impl AnyBufferViewing for &'_ BufferStorage { + fn any_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn any_oldest<'a>(&'a self, session: Entity) -> Option> { + self.oldest(session).map(to_any_ref) + } + + fn any_newest<'a>(&'a self, session: Entity) -> Option> { + self.newest(session).map(to_any_ref) + } + + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option> { + self.get(session, index).map(to_any_ref) + } + + fn any_message_type(&self) -> TypeId { + TypeId::of::() + } +} + +impl AnyBufferViewing for Mut<'_, BufferStorage> { + fn any_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn any_oldest<'a>(&'a self, session: Entity) -> Option> { + self.oldest(session).map(to_any_ref) + } + + fn any_newest<'a>(&'a self, session: Entity) -> Option> { + self.newest(session).map(to_any_ref) + } + + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option> { + self.get(session, index).map(to_any_ref) + } + + fn 
any_message_type(&self) -> TypeId { + TypeId::of::() + } +} + +pub type AnyMessageMut<'a> = &'a mut (dyn Any + 'static + Send + Sync); + +pub type AnyMessageBox = Box; + +#[derive(ThisError, Debug)] +#[error("failed to convert a message")] +pub struct AnyMessageError { + /// The original value provided + pub value: AnyMessageBox, + /// The ID of the type expected by the buffer + pub type_id: TypeId, + /// The name of the type expected by the buffer + pub type_name: &'static str, +} + +pub type AnyMessagePushResult = Result, AnyMessageError>; + +impl AnyBufferManagement for Mut<'_, BufferStorage> { + fn any_push(&mut self, session: Entity, value: AnyMessageBox) -> AnyMessagePushResult { + let value = from_any_message::(value)?; + Ok(self.push(session, value).map(to_any_message)) + } + + fn any_push_as_oldest( + &mut self, + session: Entity, + value: AnyMessageBox, + ) -> AnyMessagePushResult { + let value = from_any_message::(value)?; + Ok(self.push_as_oldest(session, value).map(to_any_message)) + } + + fn any_pull(&mut self, session: Entity) -> Option { + self.pull(session).map(to_any_message) + } + + fn any_pull_newest(&mut self, session: Entity) -> Option { + self.pull_newest(session).map(to_any_message) + } + + fn any_oldest_mut<'a>(&'a mut self, session: Entity) -> Option> { + self.oldest_mut(session).map(to_any_mut) + } + + fn any_newest_mut<'a>(&'a mut self, session: Entity) -> Option> { + self.newest_mut(session).map(to_any_mut) + } + + fn any_get_mut<'a>(&'a mut self, session: Entity, index: usize) -> Option> { + self.get_mut(session, index).map(to_any_mut) + } + + fn any_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box { + Box::new(self.drain(session, range)) + } +} + +fn to_any_ref<'a, T: 'static + Send + Sync + Any>(x: &'a T) -> AnyMessageRef<'a> { + x +} + +fn to_any_mut<'a, T: 'static + Send + Sync + Any>(x: &'a mut T) -> AnyMessageMut<'a> { + x +} + +fn to_any_message(x: T) -> AnyMessageBox { + Box::new(x) +} + +fn 
from_any_message( + value: AnyMessageBox, +) -> Result +where + T: 'static, +{ + let value = value.downcast::().map_err(|value| AnyMessageError { + value, + type_id: TypeId::of::(), + type_name: std::any::type_name::(), + })?; + + Ok(*value) +} + +pub trait AnyBufferAccessMutState { + fn get_any_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's>; +} + +impl AnyBufferAccessMutState + for SystemState> +{ + fn get_any_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's> { + Box::new(self.get_mut(world)) + } +} + +pub trait AnyBufferAccessMut<'w, 's> { + fn as_any_buffer_mut<'a>( + &'a mut self, + key: &AnyBufferKey, + ) -> Result, BufferError>; +} + +impl<'w, 's, T: 'static + Send + Sync + Any> AnyBufferAccessMut<'w, 's> + for BufferAccessMut<'w, 's, T> +{ + fn as_any_buffer_mut<'a>( + &'a mut self, + key: &AnyBufferKey, + ) -> Result, BufferError> { + let BufferAccessMut { query, commands } = self; + let (storage, gate) = query + .get_mut(key.tag.buffer) + .map_err(|_| BufferError::BufferMissing)?; + Ok(AnyBufferMut { + storage: Box::new(storage), + gate, + buffer: key.tag.buffer, + session: key.tag.session, + accessor: Some(key.tag.accessor), + commands, + modified: false, + }) + } +} + +pub trait AnyBufferAccessInterface { + fn message_type_id(&self) -> TypeId; + + fn message_type_name(&self) -> &'static str; + + fn buffered_count(&self, entity: &EntityRef, session: Entity) -> Result; + + fn ensure_session(&self, entity_mut: &mut EntityWorldMut, session: Entity) -> OperationResult; + + fn register_buffer_downcast(&self, buffer_type: TypeId, f: BufferDowncastBox); + + fn buffer_downcast(&self, buffer_type: TypeId) -> Option; + + fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox); + + fn key_downcast(&self, key_type: TypeId) -> Option; + + fn pull( + &self, + entity_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result; + + fn create_any_buffer_view<'a>( + &self, + key: 
&AnyBufferKey, + world: &'a World, + ) -> Result, BufferError>; + + fn create_any_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box; +} + +pub type BufferDowncastBox = Box Box + Send + Sync>; +pub type BufferDowncastRef = &'static (dyn Fn(BufferLocation) -> Box + Send + Sync); +pub type KeyDowncastBox = Box Box + Send + Sync>; +pub type KeyDowncastRef = &'static (dyn Fn(BufferKeyTag) -> Box + Send + Sync); + +struct AnyBufferAccessImpl { + buffer_downcasts: Mutex>, + key_downcasts: Mutex>, + _ignore: std::marker::PhantomData, +} + +impl AnyBufferAccessImpl { + fn new() -> Self { + let mut buffer_downcasts: HashMap<_, BufferDowncastRef> = HashMap::new(); + + // SAFETY: These leaks are okay because we will only ever instantiate + // AnyBufferAccessImpl once per generic argument T, which puts a firm + // ceiling on how many of these callbacks will get leaked. + + // Automatically register a downcast into AnyBuffer + buffer_downcasts.insert( + TypeId::of::(), + Box::leak(Box::new(|location| -> Box { + Box::new(AnyBuffer { + location, + interface: AnyBuffer::interface_for::(), + }) + })), + ); + + // Allow downcasting back to the original Buffer + buffer_downcasts.insert( + TypeId::of::>(), + Box::leak(Box::new(|location| -> Box { + Box::new(Buffer:: { + location, + _ignore: Default::default(), + }) + })), + ); + + let mut key_downcasts: HashMap<_, KeyDowncastRef> = HashMap::new(); + + // Automatically register a downcast to AnyBufferKey + key_downcasts.insert( + TypeId::of::(), + Box::leak(Box::new(|tag| -> Box { + Box::new(AnyBufferKey { + tag, + interface: AnyBuffer::interface_for::(), + }) + })), + ); + + Self { + buffer_downcasts: Mutex::new(buffer_downcasts), + key_downcasts: Mutex::new(key_downcasts), + _ignore: Default::default(), + } + } +} + +impl AnyBufferAccessInterface for AnyBufferAccessImpl { + fn message_type_id(&self) -> TypeId { + TypeId::of::() + } + + fn message_type_name(&self) -> &'static str { + std::any::type_name::() + } + + fn 
buffered_count(&self, entity: &EntityRef, session: Entity) -> Result { + entity.buffered_count::(session) + } + + fn ensure_session(&self, entity_mut: &mut EntityWorldMut, session: Entity) -> OperationResult { + entity_mut.ensure_session::(session) + } + + fn register_buffer_downcast(&self, buffer_type: TypeId, f: BufferDowncastBox) { + let mut downcasts = self.buffer_downcasts.lock().unwrap(); + + if let Entry::Vacant(entry) = downcasts.entry(buffer_type) { + // SAFETY: We only leak this into the register once per type + entry.insert(Box::leak(f)); + } + } + + fn buffer_downcast(&self, buffer_type: TypeId) -> Option { + self.buffer_downcasts + .lock() + .unwrap() + .get(&buffer_type) + .copied() + } + + fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox) { + let mut downcasts = self.key_downcasts.lock().unwrap(); + + if let Entry::Vacant(entry) = downcasts.entry(key_type) { + // SAFTY: We only leak this in to the register once per type + entry.insert(Box::leak(f)); + } + } + + fn key_downcast(&self, key_type: TypeId) -> Option { + self.key_downcasts.lock().unwrap().get(&key_type).copied() + } + + fn pull( + &self, + entity_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result { + entity_mut + .pull_from_buffer::(session) + .map(to_any_message) + } + + fn create_any_buffer_view<'a>( + &self, + key: &AnyBufferKey, + world: &'a World, + ) -> Result, BufferError> { + let buffer_ref = world + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(AnyBufferView { + storage: Box::new(storage), + gate, + session: key.tag.session, + }) + } + + fn create_any_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box { + Box::new(SystemState::>::new(world)) + } +} + +pub struct DrainAnyBuffer<'a> { + interface: Box, +} + +impl<'a> Iterator for DrainAnyBuffer<'a> { + type Item = 
AnyMessageBox; + + fn next(&mut self) -> Option { + self.interface.any_next() + } +} + +trait DrainAnyBufferInterface { + fn any_next(&mut self) -> Option; +} + +impl DrainAnyBufferInterface for DrainBuffer<'_, T> { + fn any_next(&mut self) -> Option { + self.next().map(to_any_message) + } +} + +impl Bufferable for AnyBuffer { + type BufferType = Self; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + assert_eq!(self.scope(), builder.scope()); + self + } +} + +impl Buffering for AnyBuffer { + fn verify_scope(&self, scope: Entity) { + assert_eq!(scope, self.scope()); + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let entity_ref = world.get_entity(self.id()).or_broken()?; + self.interface.buffered_count(&entity_ref, session) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + add_listener_to_source(self.id(), listener, world) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + GateState::apply(self.id(), session, action, world, roster) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + SmallVec::from_iter([self.id()]) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + let mut entity_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.ensure_session(&mut entity_mut, session) + } +} + +impl Joining for AnyBuffer { + type Item = AnyMessageBox; + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.pull(&mut buffer_mut, session) + } +} + +impl Accessing for AnyBuffer { + type Key = AnyBufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + world + .get_mut::(self.id()) + .or_broken()? 
+ .add_accessor(accessor); + Ok(()) + } + + fn create_key(&self, builder: &super::BufferKeyBuilder) -> Self::Key { + AnyBufferKey { + tag: builder.make_tag(self.id()), + interface: self.interface, + } + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*}; + use bevy_ecs::prelude::World; + + #[test] + fn test_any_count() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let count = builder + .commands() + .spawn_service(get_buffer_count.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(count) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(1, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let count = promise.take().available().unwrap(); + assert_eq!(count, 5); + assert!(context.no_unhandled_errors()); + } + + fn push_multiple_times_into_buffer( + In((value, key)): In<(usize, BufferKey)>, + mut access: BufferAccessMut, + ) -> AnyBufferKey { + let mut buffer = access.get_mut(&key).unwrap(); + for _ in 0..5 { + buffer.push(value); + } + + key.into() + } + + fn get_buffer_count(In(key): In, world: &mut World) -> usize { + world.any_buffer_view(&key).unwrap().len() + } + + #[test] + fn test_modify_any_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + 
.spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(pull_each_buffer_item.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(3, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values, vec![0, 3, 6, 9, 12]); + assert!(context.no_unhandled_errors()); + } + + fn modify_buffer_content(In(key): In, world: &mut World) -> AnyBufferKey { + world + .any_buffer_mut(&key, |mut access| { + for i in 0..access.len() { + access.get_mut(i).map(|value| { + *value.downcast_mut::().unwrap() *= i; + }); + } + }) + .unwrap(); + + key + } + + fn pull_each_buffer_item(In(key): In, world: &mut World) -> Vec { + world + .any_buffer_mut(&key, |mut access| { + let mut values = Vec::new(); + while let Some(value) = access.pull() { + values.push(*value.downcast::().unwrap()); + } + values + }) + .unwrap() + } + + #[test] + fn test_drain_any_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(drain_buffer_contents.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + 
.then(drain_content) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(3, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values, vec![0, 3, 6, 9, 12]); + assert!(context.no_unhandled_errors()); + } + + fn drain_buffer_contents(In(key): In, world: &mut World) -> Vec { + world + .any_buffer_mut(&key, |mut access| { + access + .drain(..) + .map(|value| *value.downcast::().unwrap()) + .collect() + }) + .unwrap() + } + + #[test] + fn double_any_messages() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = + context.spawn_io_workflow(|scope: Scope<(u32, i32, f32), (u32, i32, f32)>, builder| { + let buffer_u32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_i32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_f32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + + let (input_u32, input_i32, input_f32) = scope.input.chain(builder).unzip(); + input_u32.chain(builder).map_block(|v| 2 * v).connect( + buffer_u32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + input_i32.chain(builder).map_block(|v| 2 * v).connect( + buffer_i32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + input_f32.chain(builder).map_block(|v| 2.0 * v).connect( + buffer_f32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + (buffer_u32, buffer_i32, buffer_f32) + .join(builder) + .map_block(|(value_u32, value_i32, value_f32)| { + ( + *value_u32.downcast::().unwrap(), + *value_i32.downcast::().unwrap(), + *value_f32.downcast::().unwrap(), + ) + }) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((1u32, 2i32, 3f32), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, 
Duration::from_secs(2)); + let (v_u32, v_i32, v_f32) = promise.take().available().unwrap(); + assert_eq!(v_u32, 2); + assert_eq!(v_i32, 4); + assert_eq!(v_f32, 6.0); + assert!(context.no_unhandled_errors()); + } +} diff --git a/src/buffer/buffer_access_lifecycle.rs b/src/buffer/buffer_access_lifecycle.rs index d368a484..b7596fac 100644 --- a/src/buffer/buffer_access_lifecycle.rs +++ b/src/buffer/buffer_access_lifecycle.rs @@ -21,7 +21,7 @@ use tokio::sync::mpsc::UnboundedSender as TokioSender; use std::sync::Arc; -use crate::{emit_disposal, ChannelItem, Disposal, OperationRoster}; +use crate::{emit_disposal, BufferKeyBuilder, ChannelItem, Disposal, OperationRoster}; /// This is used as a field inside of [`crate::BufferKey`] which keeps track of /// when a key that was sent out into the world gets fully dropped from use. We @@ -29,7 +29,7 @@ use crate::{emit_disposal, ChannelItem, Disposal, OperationRoster}; /// we would be needlessly doing a reachability check every time the key gets /// cloned. #[derive(Clone)] -pub(crate) struct BufferAccessLifecycle { +pub struct BufferAccessLifecycle { scope: Entity, accessor: Entity, session: Entity, @@ -87,3 +87,29 @@ impl Drop for BufferAccessLifecycle { } } } + +/// This trait is implemented by [`crate::BufferKey`]-like structs so their +/// lifecycles can be managed. +pub trait BufferKeyLifecycle { + /// What kind of buffer this key can unlock. + type TargetBuffer; + + /// Create a new key of this type. + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self; + + /// Check if the key is currently in use. + fn is_in_use(&self) -> bool; + + /// Create a deep clone of the key. The usage tracking of the clone will + /// be unrelated to the usage tracking of the original. + /// + /// We do a deep clone of the key when distributing it to decouple the + /// lifecycle of the keys that we send out from the key that's held by the + /// accessor node. 
+ // + /// The key instance held by the accessor node will never be dropped until + /// the session is cleaned up, so the keys that we send out into the workflow + /// need to have their own independent lifecycles or else we won't detect + /// when the workflow has dropped them. + fn deep_clone(&self) -> Self; +} diff --git a/src/buffer/buffer_key_builder.rs b/src/buffer/buffer_key_builder.rs index 02e4664d..e1866e2e 100644 --- a/src/buffer/buffer_key_builder.rs +++ b/src/buffer/buffer_key_builder.rs @@ -19,7 +19,7 @@ use bevy_ecs::prelude::Entity; use std::sync::Arc; -use crate::{BufferAccessLifecycle, BufferKey, ChannelSender}; +use crate::{BufferAccessLifecycle, BufferKeyTag, ChannelSender}; pub struct BufferKeyBuilder { scope: Entity, @@ -29,8 +29,9 @@ pub struct BufferKeyBuilder { } impl BufferKeyBuilder { - pub(crate) fn build(&self, buffer: Entity) -> BufferKey { - BufferKey { + /// Make a [`BufferKeyTag`] that can be given to a [`crate::BufferKey`]-like struct. + pub fn make_tag(&self, buffer: Entity) -> BufferKeyTag { + BufferKeyTag { buffer, session: self.session, accessor: self.accessor, @@ -44,7 +45,6 @@ impl BufferKeyBuilder { tracker.clone(), )) }), - _ignore: Default::default(), } } diff --git a/src/buffer/buffer_map.rs b/src/buffer/buffer_map.rs new file mode 100644 index 00000000..04fe4ea7 --- /dev/null +++ b/src/buffer/buffer_map.rs @@ -0,0 +1,894 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use std::{borrow::Cow, collections::HashMap}; + +use thiserror::Error as ThisError; + +use smallvec::SmallVec; + +use bevy_ecs::prelude::{Entity, World}; + +use crate::{ + add_listener_to_source, Accessing, AnyBuffer, AnyBufferKey, AnyMessageBox, AsAnyBuffer, Buffer, + BufferKeyBuilder, BufferKeyLifecycle, Bufferable, Buffering, Builder, Chain, Gate, GateState, + Joining, Node, OperationError, OperationResult, OperationRoster, +}; + +pub use bevy_impulse_derive::{Accessor, Joined}; + +/// Uniquely identify a buffer within a buffer map, either by name or by an +/// index value. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum BufferIdentifier<'a> { + /// Identify a buffer by name + Name(Cow<'a, str>), + /// Identify a buffer by an index value + Index(usize), +} + +impl BufferIdentifier<'static> { + /// Clone a name to use as an identifier. + pub fn clone_name(name: &str) -> Self { + BufferIdentifier::Name(Cow::Owned(name.to_owned())) + } + + /// Borrow a string literal name to use as an identifier. + pub fn literal_name(name: &'static str) -> Self { + BufferIdentifier::Name(Cow::Borrowed(name)) + } + + /// Use an index as an identifier. + pub fn index(index: usize) -> Self { + BufferIdentifier::Index(index) + } +} + +impl<'a> From<&'a str> for BufferIdentifier<'a> { + fn from(value: &'a str) -> Self { + BufferIdentifier::Name(Cow::Borrowed(value)) + } +} + +impl<'a> From for BufferIdentifier<'a> { + fn from(value: usize) -> Self { + BufferIdentifier::Index(value) + } +} + +pub type BufferMap = HashMap, AnyBuffer>; + +/// Extension trait that makes it more convenient to insert buffers into a [`BufferMap`]. +pub trait AddBufferToMap { + /// Convenience function for inserting items into a [`BufferMap`]. This + /// automatically takes care of converting the types. 
+ fn insert_buffer>, B: AsAnyBuffer>( + &mut self, + identifier: I, + buffer: B, + ); +} + +impl AddBufferToMap for BufferMap { + fn insert_buffer>, B: AsAnyBuffer>( + &mut self, + identifier: I, + buffer: B, + ) { + self.insert(identifier.into(), buffer.as_any_buffer()); + } +} + +/// This error is used when the buffers provided for an input are not compatible +/// with the layout. +#[derive(ThisError, Debug, Clone, Default)] +#[error("the incoming buffer map is incompatible with the layout")] +pub struct IncompatibleLayout { + /// Identities of buffers that were missing from the incoming buffer map. + pub missing_buffers: Vec>, + /// Identities of buffers in the incoming buffer map which cannot exist in + /// the target layout. + pub forbidden_buffers: Vec>, + /// Buffers whose expected type did not match the received type. + pub incompatible_buffers: Vec, +} + +impl IncompatibleLayout { + /// Convert this into an error if it has any contents inside. + pub fn as_result(self) -> Result<(), Self> { + if !self.missing_buffers.is_empty() { + return Err(self); + } + + if !self.incompatible_buffers.is_empty() { + return Err(self); + } + + Ok(()) + } + + /// Check whether the buffer associated with the identifier is compatible with + /// the required buffer type. You can pass in a `&static str` or a `usize` + /// directly as the identifier. 
+ /// + /// ``` + /// # use bevy_impulse::prelude::*; + /// + /// let buffer_map = BufferMap::default(); + /// let mut compatibility = IncompatibleLayout::default(); + /// let buffer = compatibility.require_buffer_for_identifier::>("some_field", &buffer_map); + /// assert!(buffer.is_err()); + /// assert!(compatibility.as_result().is_err()); + /// + /// let mut compatibility = IncompatibleLayout::default(); + /// let buffer = compatibility.require_buffer_for_identifier::>(10, &buffer_map); + /// assert!(buffer.is_err()); + /// assert!(compatibility.as_result().is_err()); + /// ``` + pub fn require_buffer_for_identifier( + &mut self, + identifier: impl Into>, + buffers: &BufferMap, + ) -> Result { + let identifier = identifier.into(); + if let Some(buffer) = buffers.get(&identifier) { + if let Some(buffer) = buffer.downcast_buffer::() { + return Ok(buffer); + } else { + self.incompatible_buffers.push(BufferIncompatibility { + identifier, + expected: std::any::type_name::(), + received: buffer.message_type_name(), + }); + } + } else { + self.missing_buffers.push(identifier); + } + + Err(()) + } + + /// Same as [`Self::require_buffer_for_identifier`], but can be used with + /// temporary borrows of a string slice. The string slice will be cloned if + /// an error message needs to be produced. 
+ pub fn require_buffer_for_borrowed_name( + &mut self, + expected_name: &str, + buffers: &BufferMap, + ) -> Result { + let identifier = BufferIdentifier::Name(Cow::Borrowed(expected_name)); + if let Some(buffer) = buffers.get(&identifier) { + if let Some(buffer) = buffer.downcast_buffer::() { + return Ok(buffer); + } else { + self.incompatible_buffers.push(BufferIncompatibility { + identifier: BufferIdentifier::Name(Cow::Owned(expected_name.to_owned())), + expected: std::any::type_name::(), + received: buffer.message_type_name(), + }); + } + } else { + self.missing_buffers + .push(BufferIdentifier::Name(Cow::Owned(expected_name.to_owned()))); + } + + Err(()) + } +} + +/// Difference between the expected and received types of a named buffer. +#[derive(Debug, Clone)] +pub struct BufferIncompatibility { + /// Name of the expected buffer + pub identifier: BufferIdentifier<'static>, + /// The type that was expected for this buffer + pub expected: &'static str, + /// The type that was received for this buffer + pub received: &'static str, + // TODO(@mxgrey): Replace TypeId with TypeInfo +} + +/// This trait can be implemented on structs that represent a layout of buffers. +/// You do not normally have to implement this yourself. Instead you should +/// `#[derive(Joined)]` on a struct that you want a join operation to +/// produce. +pub trait BufferMapLayout: Sized + Clone + 'static + Send + Sync { + /// Try to convert a generic [`BufferMap`] into this specific layout. + fn try_from_buffer_map(buffers: &BufferMap) -> Result; +} + +/// This trait helps auto-generated buffer map structs to implement the Buffering +/// trait. +pub trait BufferMapStruct: Sized + Clone + 'static + Send + Sync { + /// Produce a list of the buffers that exist in this layout. Implementing + /// this function alone is sufficient to implement the entire [`Buffering`] trait. 
+ fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]>; +} + +impl Bufferable for T { + type BufferType = Self; + + fn into_buffer(self, _: &mut Builder) -> Self::BufferType { + self + } +} + +impl Buffering for T { + fn verify_scope(&self, scope: Entity) { + for buffer in self.buffer_list() { + assert_eq!(buffer.scope(), scope); + } + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let mut min_count = None; + + for buffer in self.buffer_list() { + let count = buffer.buffered_count(session, world)?; + min_count = if min_count.is_some_and(|m| m < count) { + min_count + } else { + Some(count) + }; + } + + Ok(min_count.unwrap_or(0)) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self.buffer_list() { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + for buffer in self.buffer_list() { + add_listener_to_source(buffer.id(), listener, world)?; + } + Ok(()) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + for buffer in self.buffer_list() { + GateState::apply(buffer.id(), session, action, world, roster)?; + } + Ok(()) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + let mut inputs = SmallVec::new(); + for buffer in self.buffer_list() { + inputs.push(buffer.id()); + } + inputs + } +} + +/// This trait can be implemented for structs that are created by joining together +/// values from a collection of buffers. This allows [`join`][1] to produce arbitrary +/// structs. Structs with this trait can be produced by [`try_join`][2]. +/// +/// Each field in this struct needs to have the trait bounds `'static + Send + Sync`. +/// +/// This does not generally need to be implemented explicitly. 
Instead you should +/// use `#[derive(Joined)]`: +/// +/// ``` +/// use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// } +/// ``` +/// +/// The above example would allow you to join a value from an `i64` buffer with +/// a value from a `String` buffer. You can have as many fields in the struct +/// as you'd like. +/// +/// This macro will generate a struct of buffers to match the fields of the +/// struct that it's applied to. The name of that struct is anonymous by default +/// since you don't generally need to use it directly, but if you want to give +/// it a name you can use #[joined(buffers_struct_name = ...)]`: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// #[joined(buffers_struct_name = SomeBuffers)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// } +/// ``` +/// +/// By default each field of the generated buffers struct will have a type of +/// [`Buffer`], but you can override this using `#[joined(buffer = ...)]` +/// to specify a special buffer type. For example if your `Joined` struct +/// contains an [`AnyMessageBox`] then by default the macro will use `Buffer`, +/// but you probably really want it to have an [`AnyBuffer`]: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// #[joined(buffer = AnyBuffer)] +/// any: AnyMessageBox, +/// } +/// ``` +/// +/// The above method also works for joining a `JsonMessage` field from a `JsonBuffer`. +/// +/// [1]: crate::Builder::join +/// [2]: crate::Builder::try_join +pub trait Joined: 'static + Send + Sync + Sized { + /// This associated type must represent a buffer map layout that implements + /// the [`Joining`] trait. The message type yielded by [`Joining`] for this + /// associated type must match the [`Joined`] type. 
+ type Buffers: 'static + BufferMapLayout + Joining + Send + Sync; + + /// Used by [`Builder::try_join`] + fn try_join_from<'w, 's, 'a, 'b>( + buffers: &BufferMap, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.join(builder)) + } +} + +/// Trait to describe a set of buffer keys. This allows [listen][1] and [access][2] +/// to work for arbitrary structs of buffer keys. Structs with this trait can be +/// produced by [`try_listen`][3] and [`try_create_buffer_access`][4]. +/// +/// Each field in the struct must be some kind of buffer key. +/// +/// This does not generally need to be implemented explicitly. Instead you should +/// define a struct where all fields are buffer keys and then apply +/// `#[derive(Accessor)]` to it, e.g.: +/// +/// ``` +/// use bevy_impulse::prelude::*; +/// +/// #[derive(Clone, Accessor)] +/// struct SomeKeys { +/// integer: BufferKey, +/// string: BufferKey, +/// any: AnyBufferKey, +/// } +/// ``` +/// +/// The macro will generate a struct of buffers to match the keys. 
The name of +/// that struct is anonymous by default since you don't generally need to use it +/// directly, but if you want to give it a name you can use `#[key(buffers_struct_name = ...)]`: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Clone, Accessor)] +/// #[key(buffers_struct_name = SomeBuffers)] +/// struct SomeKeys { +/// integer: BufferKey, +/// string: BufferKey, +/// any: AnyBufferKey, +/// } +/// ``` +/// +/// [1]: crate::Builder::listen +/// [2]: crate::Builder::create_buffer_access +/// [3]: crate::Builder::try_listen +/// [4]: crate::Builder::try_create_buffer_access +pub trait Accessor: 'static + Send + Sync + Sized + Clone { + type Buffers: 'static + BufferMapLayout + Accessing + Send + Sync; + + fn try_listen_from<'w, 's, 'a, 'b>( + buffers: &BufferMap, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.listen(builder)) + } + + fn try_buffer_access( + buffers: &BufferMap, + builder: &mut Builder, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.access(builder)) + } +} + +impl BufferMapLayout for BufferMap { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + Ok(buffers.clone()) + } +} + +impl BufferMapStruct for BufferMap { + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]> { + self.values().cloned().collect() + } +} + +impl Joining for BufferMap { + type Item = HashMap, AnyMessageBox>; + + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut value = HashMap::new(); + for (name, buffer) in self.iter() { + value.insert(name.clone(), buffer.pull(session, world)?); + } + + Ok(value) + } +} + +impl Joined for HashMap, AnyMessageBox> { + type Buffers = BufferMap; +} + +impl Accessing for BufferMap { + type Key = HashMap, AnyBufferKey>; + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + let 
mut keys = HashMap::new(); + for (name, buffer) in self.iter() { + let key = AnyBufferKey { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + }; + keys.insert(name.clone(), key); + } + keys + } + + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + for buffer in self.values() { + buffer.add_accessor(accessor, world)?; + } + Ok(()) + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + let mut cloned_key = HashMap::new(); + for (name, key) in key.iter() { + cloned_key.insert(name.clone(), key.deep_clone()); + } + cloned_key + } + + fn is_key_in_use(key: &Self::Key) -> bool { + for k in key.values() { + if k.is_in_use() { + return true; + } + } + + return false; + } +} + +impl Joined for Vec { + type Buffers = Vec>; +} + +impl BufferMapLayout for Vec { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = Vec::new(); + let mut compatibility = IncompatibleLayout::default(); + for i in 0..buffers.len() { + if let Ok(downcast) = compatibility.require_buffer_for_identifier::(i, buffers) { + downcast_buffers.push(downcast); + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +impl Joined for SmallVec<[T; N]> { + type Buffers = SmallVec<[Buffer; N]>; +} + +impl BufferMapLayout + for SmallVec<[B; N]> +{ + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = SmallVec::new(); + let mut compatibility = IncompatibleLayout::default(); + for i in 0..buffers.len() { + if let Ok(downcast) = compatibility.require_buffer_for_identifier::(i, buffers) { + downcast_buffers.push(downcast); + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*, AddBufferToMap, BufferMap}; + + #[derive(Joined)] + struct TestJoinedValue { + integer: i64, + float: f64, + string: String, + generic: T, + #[joined(buffer = AnyBuffer)] + any: AnyMessageBox, + } + + #[test] + fn 
test_try_join() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer(BufferSettings::default()); + + let mut buffers = BufferMap::default(); + buffers.insert_buffer("integer", buffer_i64); + buffers.insert_buffer("float", buffer_f64); + buffers.insert_buffer("string", buffer_string); + buffers.insert_buffer("generic", buffer_generic); + buffers.insert_buffer("any", buffer_any); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + builder.try_join(&buffers).unwrap().connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_joined_value() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let 
buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + let buffers = TestJoinedValue::select_buffers( + buffer_i64, + buffer_f64, + buffer_string, + buffer_generic, + buffer_any.into(), + ); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[derive(Clone, Joined)] + #[joined(buffers_struct_name = FooBuffers)] + struct TestDeriveWithConfig {} + + #[test] + fn test_derive_with_config() { + // a compile test to check that the name of the generated struct is correct + fn _check_buffer_struct_name(_: FooBuffers) {} + } + + struct MultiGenericValue { + t: T, + u: U, + } + + #[derive(Joined)] + #[joined(buffers_struct_name = MultiGenericBuffers)] + struct JoinedMultiGenericValue { + #[joined(buffer = Buffer>)] + a: MultiGenericValue, + b: String, + } + + #[test] + fn test_multi_generic_joined_value() { + let mut context = 
TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow( + |scope: Scope<(i32, String), JoinedMultiGenericValue>, builder| { + let multi_generic_buffers = MultiGenericBuffers:: { + a: builder.create_buffer(BufferSettings::default()), + b: builder.create_buffer(BufferSettings::default()), + }; + + let copy = multi_generic_buffers; + + scope + .input + .chain(builder) + .map_block(|(integer, string)| { + ( + MultiGenericValue { + t: integer, + u: string.clone(), + }, + string, + ) + }) + .fork_unzip(( + |a: Chain<_>| a.connect(multi_generic_buffers.a.input_slot()), + |b: Chain<_>| b.connect(multi_generic_buffers.b.input_slot()), + )); + + multi_generic_buffers.join(builder).connect(scope.terminate); + copy.join(builder).connect(scope.terminate); + }, + ); + + let mut promise = context.command(|commands| { + commands + .request((5, "hello".to_string()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value = promise.take().available().unwrap(); + assert_eq!(value.a.t, 5); + assert_eq!(value.a.u, "hello"); + assert_eq!(value.b, "hello"); + assert!(context.no_unhandled_errors()); + } + + /// We create this struct just to verify that it is able to compile despite + /// NonCopyBuffer not being copyable. 
+ #[derive(Joined)] + #[allow(unused)] + struct JoinedValueForNonCopyBuffer { + #[joined(buffer = NonCopyBuffer, noncopy_buffer)] + _a: String, + _b: u32, + } + + #[derive(Clone, Accessor)] + #[key(buffers_struct_name = TestKeysBuffers)] + struct TestKeys { + integer: BufferKey, + float: BufferKey, + string: BufferKey, + generic: BufferKey, + any: AnyBufferKey, + } + #[test] + fn test_listen() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + let buffers = TestKeys::select_buffers( + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + buffer_any.as_any_buffer(), + ); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| chain.connect(buffers.string.input_slot()), + |chain: Chain<_>| chain.connect(buffers.generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + builder + .listen(buffers) + .then(join_via_listen.into_blocking_callback()) + .dispose_on_none() + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_try_listen() { + let mut context 
= TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer::(BufferSettings::default()); + let buffer_f64 = builder.create_buffer::(BufferSettings::default()); + let buffer_string = builder.create_buffer::(BufferSettings::default()); + let buffer_generic = builder.create_buffer::<&'static str>(BufferSettings::default()); + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + let mut buffer_map = BufferMap::new(); + buffer_map.insert_buffer("integer", buffer_i64); + buffer_map.insert_buffer("float", buffer_f64); + buffer_map.insert_buffer("string", buffer_string); + buffer_map.insert_buffer("generic", buffer_generic); + buffer_map.insert_buffer("any", buffer_any); + + builder + .try_listen(&buffer_map) + .unwrap() + .then(join_via_listen.into_blocking_callback()) + .dispose_on_none() + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + /// This macro is a manual implementation of the join operation that uses + /// the buffer listening mechanism. 
There isn't any reason to reimplement + /// join here except so we can test that listening is working correctly for + /// Accessor. + fn join_via_listen( + In(keys): In>, + world: &mut World, + ) -> Option> { + if world.buffer_view(&keys.integer).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.float).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.string).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.generic).ok()?.is_empty() { + return None; + } + if world.any_buffer_view(&keys.any).ok()?.is_empty() { + return None; + } + + let integer = world + .buffer_mut(&keys.integer, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let float = world + .buffer_mut(&keys.float, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let string = world + .buffer_mut(&keys.string, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let generic = world + .buffer_mut(&keys.generic, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let any = world + .any_buffer_mut(&keys.any, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + + Some(TestJoinedValue { + integer, + float, + string, + generic, + any, + }) + } +} diff --git a/src/buffer/buffer_storage.rs b/src/buffer/buffer_storage.rs index c7465415..e3d9bc88 100644 --- a/src/buffer/buffer_storage.rs +++ b/src/buffer/buffer_storage.rs @@ -44,23 +44,18 @@ pub(crate) struct BufferStorage { } impl BufferStorage { - pub(crate) fn force_push(&mut self, session: Entity, value: T) -> Option { - Self::impl_push( - self.reverse_queues.entry(session).or_default(), - self.settings.retention(), - value, - ) + pub(crate) fn count(&self, session: Entity) -> usize { + self.reverse_queues + .get(&session) + .map(|q| q.len()) + .unwrap_or(0) } - pub(crate) fn push(&mut self, session: Entity, value: T) -> Option { - let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { - return Some(value); - }; - - Self::impl_push(reverse_queue, self.settings.retention(), value) 
+ pub(crate) fn active_sessions(&self) -> SmallVec<[Entity; 16]> { + self.reverse_queues.keys().copied().collect() } - pub(crate) fn impl_push( + fn impl_push( reverse_queue: &mut SmallVec<[T; 16]>, retention: RetentionPolicy, value: T, @@ -92,6 +87,22 @@ impl BufferStorage { replaced } + pub(crate) fn force_push(&mut self, session: Entity, value: T) -> Option { + Self::impl_push( + self.reverse_queues.entry(session).or_default(), + self.settings.retention(), + value, + ) + } + + pub(crate) fn push(&mut self, session: Entity, value: T) -> Option { + let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { + return Some(value); + }; + + Self::impl_push(reverse_queue, self.settings.retention(), value) + } + pub(crate) fn push_as_oldest(&mut self, session: Entity, value: T) -> Option { let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { return Some(value); @@ -147,20 +158,8 @@ impl BufferStorage { self.reverse_queues.remove(&session); } - pub(crate) fn count(&self, session: Entity) -> usize { - self.reverse_queues - .get(&session) - .map(|q| q.len()) - .unwrap_or(0) - } - - pub(crate) fn iter(&self, session: Entity) -> IterBufferView<'_, T> - where - T: 'static + Send + Sync, - { - IterBufferView { - iter: self.reverse_queues.get(&session).map(|q| q.iter().rev()), - } + pub(crate) fn ensure_session(&mut self, session: Entity) { + self.reverse_queues.entry(session).or_default(); } pub(crate) fn iter_mut(&mut self, session: Entity) -> IterBufferMut<'_, T> @@ -175,6 +174,28 @@ impl BufferStorage { } } + pub(crate) fn oldest_mut(&mut self, session: Entity) -> Option<&mut T> { + self.reverse_queues + .get_mut(&session) + .and_then(|q| q.last_mut()) + } + + pub(crate) fn newest_mut(&mut self, session: Entity) -> Option<&mut T> { + self.reverse_queues + .get_mut(&session) + .and_then(|q| q.first_mut()) + } + + pub(crate) fn get_mut(&mut self, session: Entity, index: usize) -> Option<&mut T> { + let reverse_queue = 
self.reverse_queues.get_mut(&session)?; + let len = reverse_queue.len(); + if len <= index { + return None; + } + + reverse_queue.get_mut(len - index - 1) + } + pub(crate) fn drain(&mut self, session: Entity, range: R) -> DrainBuffer<'_, T> where T: 'static + Send + Sync, @@ -188,6 +209,15 @@ impl BufferStorage { } } + pub(crate) fn iter(&self, session: Entity) -> IterBufferView<'_, T> + where + T: 'static + Send + Sync, + { + IterBufferView { + iter: self.reverse_queues.get(&session).map(|q| q.iter().rev()), + } + } + pub(crate) fn oldest(&self, session: Entity) -> Option<&T> { self.reverse_queues.get(&session).and_then(|q| q.last()) } @@ -199,43 +229,13 @@ impl BufferStorage { pub(crate) fn get(&self, session: Entity, index: usize) -> Option<&T> { let reverse_queue = self.reverse_queues.get(&session)?; let len = reverse_queue.len(); - if len >= index { + if len <= index { return None; } reverse_queue.get(len - index - 1) } - pub(crate) fn oldest_mut(&mut self, session: Entity) -> Option<&mut T> { - self.reverse_queues - .get_mut(&session) - .and_then(|q| q.last_mut()) - } - - pub(crate) fn newest_mut(&mut self, session: Entity) -> Option<&mut T> { - self.reverse_queues - .get_mut(&session) - .and_then(|q| q.first_mut()) - } - - pub(crate) fn get_mut(&mut self, session: Entity, index: usize) -> Option<&mut T> { - let reverse_queue = self.reverse_queues.get_mut(&session)?; - let len = reverse_queue.len(); - if len >= index { - return None; - } - - reverse_queue.get_mut(len - index - 1) - } - - pub(crate) fn active_sessions(&self) -> SmallVec<[Entity; 16]> { - self.reverse_queues.keys().copied().collect() - } - - pub(crate) fn ensure_session(&mut self, session: Entity) { - self.reverse_queues.entry(session).or_default(); - } - pub(crate) fn new(settings: BufferSettings) -> Self { Self { settings, diff --git a/src/buffer/bufferable.rs b/src/buffer/bufferable.rs index 17f8b367..daf55738 100644 --- a/src/buffer/bufferable.rs +++ b/src/buffer/bufferable.rs @@ -19,143 
+19,25 @@ use bevy_utils::all_tuples; use smallvec::SmallVec; use crate::{ - AddOperation, Buffer, BufferSettings, Buffered, Builder, Chain, CleanupWorkflowConditions, - CloneFromBuffer, Join, Listen, Output, Scope, ScopeSettings, UnusedTarget, + Accessing, AddOperation, Buffer, BufferSettings, Buffering, Builder, Chain, CloneFromBuffer, + Join, Joining, Output, UnusedTarget, }; -pub type BufferKeys = <::BufferType as Buffered>::Key; -pub type BufferItem = <::BufferType as Buffered>::Item; +pub type BufferKeys = <::BufferType as Accessing>::Key; +pub type JoinedItem = <::BufferType as Joining>::Item; pub trait Bufferable { - type BufferType: Buffered; + type BufferType: Buffering; /// Convert these bufferable workflow elements into buffers if they are not /// buffers already. fn into_buffer(self, builder: &mut Builder) -> Self::BufferType; - - /// Join these bufferable workflow elements. Each time every buffer contains - /// at least one element, this will pull the oldest element from each buffer - /// and join them into a tuple that gets sent to the target. - /// - /// If you need a more general way to get access to one or more buffers, - /// use [`listen`](Self::listen) instead. - fn join<'w, 's, 'a, 'b>( - self, - builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, BufferItem> - where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferItem: 'static + Send + Sync, - { - let scope = builder.scope(); - let buffers = self.into_buffer(builder); - buffers.verify_scope(scope); - - let join = builder.commands.spawn(()).id(); - let target = builder.commands.spawn(UnusedTarget).id(); - builder.commands.add(AddOperation::new( - Some(scope), - join, - Join::new(buffers, target), - )); - - Output::new(scope, target).chain(builder) - } - - /// Create an operation that will output buffer access keys each time any - /// one of the buffers is modified. 
This can be used to create a node in a - /// workflow that wakes up every time one or more buffers change, and then - /// operates on those buffers. - /// - /// For an operation that simply joins the contents of two or more outputs - /// or buffers, use [`join`](Self::join) instead. - fn listen<'w, 's, 'a, 'b>( - self, - builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, BufferKeys> - where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - { - let scope = builder.scope(); - let buffers = self.into_buffer(builder); - buffers.verify_scope(scope); - - let listen = builder.commands.spawn(()).id(); - let target = builder.commands.spawn(UnusedTarget).id(); - builder.commands.add(AddOperation::new( - Some(scope), - listen, - Listen::new(buffers, target), - )); - - Output::new(scope, target).chain(builder) - } - - /// Alternative way to call [`Builder::on_cleanup`]. - fn on_cleanup( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cleanup(self, build) - } - - /// Alternative way to call [`Builder::on_cancel`]. - fn on_cancel( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cancel(self, build) - } - - /// Alternative way to call [`Builder::on_terminate`]. - fn on_terminate( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_terminate(self, build) - } - - /// Alternative way to call [`Builder::on_cleanup_if`]. 
- fn on_cleanup_if( - self, - builder: &mut Builder, - conditions: CleanupWorkflowConditions, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cleanup_if(conditions, self, build) - } } impl Bufferable for Buffer { type BufferType = Self; fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { - assert_eq!(self.scope, builder.scope()); + assert_eq!(self.scope(), builder.scope()); self } } @@ -163,7 +45,7 @@ impl Bufferable for Buffer { impl Bufferable for CloneFromBuffer { type BufferType = Self; fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { - assert_eq!(self.scope, builder.scope()); + assert_eq!(self.scope(), builder.scope()); self } } @@ -178,6 +60,70 @@ impl Bufferable for Output { } } +pub trait Joinable: Bufferable { + type Item: 'static + Send + Sync; + + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item>; +} + +/// This trait is used to create join operations that pull exactly one value +/// from multiple buffers or outputs simultaneously. +impl Joinable for B +where + B: Bufferable, + B::BufferType: Joining, +{ + type Item = JoinedItem; + + /// Join these bufferable workflow elements. Each time every buffer contains + /// at least one element, this will pull the oldest element from each buffer + /// and join them into a tuple that gets sent to the target. + /// + /// If you need a more general way to get access to one or more buffers, + /// use [`listen`](Accessible::listen) instead. + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item> { + self.into_buffer(builder).join(builder) + } +} + +/// This trait is used to create operations that access buffers or outputs. 
+pub trait Accessible: Bufferable { + type Keys: 'static + Send + Sync; + + /// Create an operation that will output buffer access keys each time any + /// one of the buffers is modified. This can be used to create a node in a + /// workflow that wakes up every time one or more buffers change, and then + /// operates on those buffers. + /// + /// For an operation that simply joins the contents of two or more outputs + /// or buffers, use [`join`](Joinable::join) instead. + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Keys>; +} + +impl Accessible for B +where + B: Bufferable, + B::BufferType: Accessing, +{ + type Keys = BufferKeys; + + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Keys> { + self.into_buffer(builder).listen(builder) + } +} + macro_rules! impl_bufferable_for_tuple { ($($T:ident),*) => { #[allow(non_snake_case)] @@ -206,8 +152,15 @@ impl Bufferable for [T; N] { } } +impl Bufferable for Vec { + type BufferType = Vec; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + self.into_iter().map(|b| b.into_buffer(builder)).collect() + } +} + pub trait IterBufferable { - type BufferElement: Buffered; + type BufferElement: Buffering + Joining; /// Convert an iterable collection of bufferable workflow elements into /// buffers if they are not buffers already. 
@@ -224,11 +177,11 @@ pub trait IterBufferable { fn join_vec<'w, 's, 'a, 'b, const N: usize>( self, builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, SmallVec<[::Item; N]>> + ) -> Chain<'w, 's, 'a, 'b, SmallVec<[::Item; N]>> where Self: Sized, Self::BufferElement: 'static + Send + Sync, - ::Item: 'static + Send + Sync, + ::Item: 'static + Send + Sync, { let buffers = self.into_buffer_vec::(builder); let join = builder.commands.spawn(()).id(); @@ -247,6 +200,7 @@ impl IterBufferable for T where T: IntoIterator, T::Item: Bufferable, + ::BufferType: Joining, { type BufferElement = ::BufferType; diff --git a/src/buffer/buffered.rs b/src/buffer/buffering.rs similarity index 54% rename from src/buffer/buffered.rs rename to src/buffer/buffering.rs index 084acb13..81eea5f1 100644 --- a/src/buffer/buffered.rs +++ b/src/buffer/buffering.rs @@ -16,24 +16,24 @@ */ use bevy_ecs::prelude::{Entity, World}; +use bevy_hierarchy::BuildChildren; use bevy_utils::all_tuples; use smallvec::SmallVec; use crate::{ - Buffer, BufferAccessors, BufferKey, BufferKeyBuilder, BufferStorage, CloneFromBuffer, - ForkTargetStorage, Gate, GateState, InspectBuffer, ManageBuffer, OperationError, - OperationResult, OperationRoster, OrBroken, SingleInputStorage, + AddOperation, BeginCleanupWorkflow, Buffer, BufferAccessors, BufferKey, BufferKeyBuilder, + BufferKeyLifecycle, BufferStorage, Builder, Chain, CleanupWorkflowConditions, CloneFromBuffer, + ForkTargetStorage, Gate, GateState, InputSlot, InspectBuffer, Join, Listen, ManageBuffer, Node, + OperateBufferAccess, OperationError, OperationResult, OperationRoster, OrBroken, Output, Scope, + ScopeSettings, SingleInputStorage, UnusedTarget, }; -pub trait Buffered: Clone { +pub trait Buffering: 'static + Send + Sync + Clone { fn verify_scope(&self, scope: Entity); fn buffered_count(&self, session: Entity, world: &World) -> Result; - type Item; - fn pull(&self, session: Entity, world: &mut World) -> Result; - fn add_listener(&self, 
listener: Entity, world: &mut World) -> OperationResult; fn gate_action( @@ -46,56 +46,177 @@ pub trait Buffered: Clone { fn as_input(&self) -> SmallVec<[Entity; 8]>; - type Key: Clone; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult; + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult; +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key; +pub trait Joining: Buffering { + type Item: 'static + Send + Sync; + fn pull(&self, session: Entity, world: &mut World) -> Result; - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult; + /// Join these bufferable workflow elements. Each time every buffer contains + /// at least one element, this will pull the oldest element from each buffer + /// and join them into a tuple that gets sent to the target. + /// + /// If you need a more general way to get access to one or more buffers, + /// use [`listen`](Accessing::listen) instead. + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item> { + let scope = builder.scope(); + self.verify_scope(scope); + + let join = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(scope), + join, + Join::new(self, target), + )); + + Output::new(scope, target).chain(builder) + } +} +pub trait Accessing: Buffering { + type Key: 'static + Send + Sync + Clone; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult; + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key; fn deep_clone_key(key: &Self::Key) -> Self::Key; - fn is_key_in_use(key: &Self::Key) -> bool; + + /// Create an operation that will output buffer access keys each time any + /// one of the buffers is modified. 
This can be used to create a node in a + /// workflow that wakes up every time one or more buffers change, and then + /// operates on those buffers. + /// + /// For an operation that simply joins the contents of two or more outputs + /// or buffers, use [`join`](Joining::join) instead. + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Key> { + let scope = builder.scope(); + self.verify_scope(scope); + + let listen = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(scope), + listen, + Listen::new(self, target), + )); + + Output::new(scope, target).chain(builder) + } + + fn access(self, builder: &mut Builder) -> Node { + let source = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(builder.scope), + source, + OperateBufferAccess::::new(self, target), + )); + + Node { + input: InputSlot::new(builder.scope, source), + output: Output::new(builder.scope, target), + streams: (), + } + } + + /// Alternative way to call [`Builder::on_cleanup`]. + fn on_cleanup( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(true, true), + build, + ) + } + + /// Alternative way to call [`Builder::on_cancel`]. + fn on_cancel( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(false, true), + build, + ) + } + + /// Alternative way to call [`Builder::on_terminate`]. 
+ fn on_terminate( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(true, false), + build, + ) + } + + /// Alternative way to call [`Builder::on_cleanup_if`]. + fn on_cleanup_if( + self, + builder: &mut Builder, + conditions: CleanupWorkflowConditions, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + let cancelling_scope_id = builder.commands.spawn(()).id(); + let _ = builder.create_scope_impl::( + cancelling_scope_id, + builder.finish_scope_cancel, + build, + ); + + let begin_cancel = builder.commands.spawn(()).set_parent(builder.scope).id(); + self.verify_scope(builder.scope); + builder.commands.add(AddOperation::new( + None, + begin_cancel, + BeginCleanupWorkflow::::new( + builder.scope, + self, + cancelling_scope_id, + conditions.run_on_terminate, + conditions.run_on_cancel, + ), + )); + } } -impl Buffered for Buffer { +impl Buffering for Buffer { fn verify_scope(&self, scope: Entity) { - assert_eq!(scope, self.scope); + assert_eq!(scope, self.scope()); } fn buffered_count(&self, session: Entity, world: &World) -> Result { world - .get_entity(self.source) + .get_entity(self.id()) .or_broken()? .buffered_count::(session) } - type Item = T; - fn pull(&self, session: Entity, world: &mut World) -> Result { - world - .get_entity_mut(self.source) - .or_broken()? - .pull_from_buffer::(session) - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { - let mut targets = world - .get_mut::(self.source) - .or_broken()?; - if !targets.0.contains(&listener) { - targets.0.push(listener); - } - - if let Some(mut input_storage) = world.get_mut::(listener) { - input_storage.add(self.source); - } else { - world - .get_entity_mut(listener) - .or_broken()? 
- .insert(SingleInputStorage::new(self.source)); - } - - Ok(()) + add_listener_to_source(self.id(), listener, world) } fn gate_action( @@ -105,35 +226,46 @@ impl Buffered for Buffer { world: &mut World, roster: &mut OperationRoster, ) -> OperationResult { - GateState::apply(self.source, session, action, world, roster) + GateState::apply(self.id(), session, action, world, roster) } fn as_input(&self) -> SmallVec<[Entity; 8]> { - SmallVec::from_iter([self.source]) + SmallVec::from_iter([self.id()]) } - type Key = BufferKey; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { - let mut accessors = world.get_mut::(self.source).or_broken()?; - - accessors.0.push(accessor); - accessors.0.sort(); - accessors.0.dedup(); + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + world + .get_mut::>(self.id()) + .or_broken()? + .ensure_session(session); Ok(()) } +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { - builder.build(self.source) +impl Joining for Buffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + world + .get_entity_mut(self.id()) + .or_broken()? + .pull_from_buffer::(session) } +} - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { +impl Accessing for Buffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { world - .get_mut::>(self.source) + .get_mut::(self.id()) .or_broken()? 
- .ensure_session(session); + .add_accessor(accessor); Ok(()) } + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + Self::Key::create_key(&self, builder) + } + fn deep_clone_key(key: &Self::Key) -> Self::Key { key.deep_clone() } @@ -143,44 +275,20 @@ impl Buffered for Buffer { } } -impl Buffered for CloneFromBuffer { +impl Buffering for CloneFromBuffer { fn verify_scope(&self, scope: Entity) { - assert_eq!(scope, self.scope); + assert_eq!(scope, self.scope()); } fn buffered_count(&self, session: Entity, world: &World) -> Result { world - .get_entity(self.source) + .get_entity(self.id()) .or_broken()? .buffered_count::(session) } - type Item = T; - fn pull(&self, session: Entity, world: &mut World) -> Result { - world - .get_entity(self.source) - .or_broken()? - .try_clone_from_buffer(session) - .and_then(|r| r.or_broken()) - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { - let mut targets = world - .get_mut::(self.source) - .or_broken()?; - if !targets.0.contains(&listener) { - targets.0.push(listener); - } - - if let Some(mut input_storage) = world.get_mut::(listener) { - input_storage.add(self.source); - } else { - world - .get_entity_mut(listener) - .or_broken()? 
- .insert(SingleInputStorage::new(self.source)); - } - Ok(()) + add_listener_to_source(self.id(), listener, world) } fn gate_action( @@ -190,35 +298,46 @@ impl Buffered for CloneFromBuffer { world: &mut World, roster: &mut OperationRoster, ) -> OperationResult { - GateState::apply(self.source, session, action, world, roster) + GateState::apply(self.id(), session, action, world, roster) } fn as_input(&self) -> SmallVec<[Entity; 8]> { - SmallVec::from_iter([self.source]) + SmallVec::from_iter([self.id()]) } - type Key = BufferKey; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { - let mut accessors = world.get_mut::(self.source).or_broken()?; - - accessors.0.push(accessor); - accessors.0.sort(); - accessors.0.dedup(); - Ok(()) + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + world + .get_entity_mut(self.id()) + .or_broken()? + .ensure_session::(session) } +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { - builder.build(self.source) +impl Joining for CloneFromBuffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + world + .get_entity(self.id()) + .or_broken()? + .try_clone_from_buffer(session) + .and_then(|r| r.or_broken()) } +} - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { +impl Accessing for CloneFromBuffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { world - .get_mut::>(self.source) + .get_mut::(self.id()) .or_broken()? - .ensure_session(session); + .add_accessor(accessor); Ok(()) } + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + Self::Key::create_key(&(*self).into(), builder) + } + fn deep_clone_key(key: &Self::Key) -> Self::Key { key.deep_clone() } @@ -231,7 +350,7 @@ impl Buffered for CloneFromBuffer { macro_rules! 
impl_buffered_for_tuple { ($(($T:ident, $K:ident)),*) => { #[allow(non_snake_case)] - impl<$($T: Buffered),*> Buffered for ($($T,)*) + impl<$($T: Buffering),*> Buffering for ($($T,)*) { fn verify_scope(&self, scope: Entity) { let ($($T,)*) = self; @@ -253,18 +372,6 @@ macro_rules! impl_buffered_for_tuple { ].iter().copied().min().unwrap_or(0)) } - type Item = ($($T::Item),*); - fn pull( - &self, - session: Entity, - world: &mut World, - ) -> Result { - let ($($T,)*) = self; - Ok(($( - $T.pull(session, world)?, - )*)) - } - fn add_listener( &self, listener: Entity, @@ -300,6 +407,38 @@ macro_rules! impl_buffered_for_tuple { inputs } + fn ensure_active_session( + &self, + session: Entity, + world: &mut World, + ) -> OperationResult { + let ($($T,)*) = self; + $( + $T.ensure_active_session(session, world)?; + )* + Ok(()) + } + } + + #[allow(non_snake_case)] + impl<$($T: Joining),*> Joining for ($($T,)*) + { + type Item = ($($T::Item),*); + fn pull( + &self, + session: Entity, + world: &mut World, + ) -> Result { + let ($($T,)*) = self; + Ok(($( + $T.pull(session, world)?, + )*)) + } + } + + #[allow(non_snake_case)] + impl<$($T: Accessing),*> Accessing for ($($T,)*) + { type Key = ($($T::Key), *); fn add_accessor( &self, @@ -323,18 +462,6 @@ macro_rules! impl_buffered_for_tuple { )*) } - fn ensure_active_session( - &self, - session: Entity, - world: &mut World, - ) -> OperationResult { - let ($($T,)*) = self; - $( - $T.ensure_active_session(session, world)?; - )* - Ok(()) - } - fn deep_clone_key(key: &Self::Key) -> Self::Key { let ($($K,)*) = key; ($( @@ -352,11 +479,11 @@ macro_rules! 
impl_buffered_for_tuple { } } -// Implements the `Buffered` trait for all tuples between size 2 and 12 -// (inclusive) made of types that implement `Buffered` +// Implements the `Buffering` trait for all tuples between size 2 and 12 +// (inclusive) made of types that implement `Buffering` all_tuples!(impl_buffered_for_tuple, 2, 12, T, K); -impl Buffered for [T; N] { +impl Buffering for [T; N] { fn verify_scope(&self, scope: Entity) { for buffer in self.iter() { buffer.verify_scope(scope); @@ -375,15 +502,6 @@ impl Buffered for [T; N] { Ok(min_count.unwrap_or(0)) } - // TODO(@mxgrey) We may be able to use [T::Item; N] here instead of SmallVec - // when try_map is stabilized: https://github.com/rust-lang/rust/issues/79711 - type Item = SmallVec<[T::Item; N]>; - fn pull(&self, session: Entity, world: &mut World) -> Result { - self.iter() - .map(|buffer| buffer.pull(session, world)) - .collect() - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.add_listener(listener, world)?; @@ -408,6 +526,27 @@ impl Buffered for [T; N] { self.iter().flat_map(|buffer| buffer.as_input()).collect() } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } +} + +impl Joining for [T; N] { + // TODO(@mxgrey) We may be able to use [T::Item; N] here instead of SmallVec + // when try_map is stabilized: https://github.com/rust-lang/rust/issues/79711 + type Item = SmallVec<[T::Item; N]>; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for [T; N] { type Key = SmallVec<[T::Key; N]>; fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { for buffer in self { @@ -424,14 +563,6 @@ impl Buffered for [T; N] { keys } - fn ensure_active_session(&self, session: Entity, world: 
&mut World) -> OperationResult { - for buffer in self { - buffer.ensure_active_session(session, world)?; - } - - Ok(()) - } - fn deep_clone_key(key: &Self::Key) -> Self::Key { let mut keys = SmallVec::new(); for k in key { @@ -451,7 +582,7 @@ impl Buffered for [T; N] { } } -impl Buffered for SmallVec<[T; N]> { +impl Buffering for SmallVec<[T; N]> { fn verify_scope(&self, scope: Entity) { for buffer in self.iter() { buffer.verify_scope(scope); @@ -470,13 +601,6 @@ impl Buffered for SmallVec<[T; N]> { Ok(min_count.unwrap_or(0)) } - type Item = SmallVec<[T::Item; N]>; - fn pull(&self, session: Entity, world: &mut World) -> Result { - self.iter() - .map(|buffer| buffer.pull(session, world)) - .collect() - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.add_listener(listener, world)?; @@ -501,6 +625,25 @@ impl Buffered for SmallVec<[T; N]> { self.iter().flat_map(|buffer| buffer.as_input()).collect() } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } +} + +impl Joining for SmallVec<[T; N]> { + type Item = SmallVec<[T::Item; N]>; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for SmallVec<[T; N]> { type Key = SmallVec<[T::Key; N]>; fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { for buffer in self { @@ -517,6 +660,68 @@ impl Buffered for SmallVec<[T; N]> { keys } + fn deep_clone_key(key: &Self::Key) -> Self::Key { + let mut keys = SmallVec::new(); + for k in key { + keys.push(T::deep_clone_key(k)); + } + keys + } + + fn is_key_in_use(key: &Self::Key) -> bool { + for k in key { + if T::is_key_in_use(k) { + return true; + } + } + + false + } +} + +impl Buffering for Vec { + fn verify_scope(&self, scope: Entity) { + for buffer in self 
{ + buffer.verify_scope(scope); + } + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let mut min_count = None; + for buffer in self { + let count = buffer.buffered_count(session, world)?; + if !min_count.is_some_and(|min| min < count) { + min_count = Some(count); + } + } + + Ok(min_count.unwrap_or(0)) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.add_listener(listener, world)?; + } + Ok(()) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + for buffer in self { + buffer.gate_action(session, action, world, roster)?; + } + Ok(()) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + self.iter().flat_map(|buffer| buffer.as_input()).collect() + } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.ensure_active_session(session, world)?; @@ -524,18 +729,45 @@ impl Buffered for SmallVec<[T; N]> { Ok(()) } +} + +impl Joining for Vec { + type Item = Vec; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for Vec { + type Key = Vec; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.add_accessor(accessor, world)?; + } + Ok(()) + } + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + let mut keys = Vec::new(); + for buffer in self { + keys.push(buffer.create_key(builder)); + } + keys + } fn deep_clone_key(key: &Self::Key) -> Self::Key { - let mut keys = SmallVec::new(); + let mut keys = Vec::new(); for k in key { - keys.push(T::deep_clone_key(k)); + keys.push(B::deep_clone_key(k)); } keys } fn is_key_in_use(key: &Self::Key) -> bool { for k in key { - if T::is_key_in_use(k) { + if B::is_key_in_use(k) { return true; } } @@ 
-543,3 +775,25 @@ impl Buffered for SmallVec<[T; N]> { false } } + +pub(crate) fn add_listener_to_source( + source: Entity, + listener: Entity, + world: &mut World, +) -> OperationResult { + let mut targets = world.get_mut::(source).or_broken()?; + if !targets.0.contains(&listener) { + targets.0.push(listener); + } + + if let Some(mut input_storage) = world.get_mut::(listener) { + input_storage.add(source); + } else { + world + .get_entity_mut(listener) + .or_broken()? + .insert(SingleInputStorage::new(source)); + } + + Ok(()) +} diff --git a/src/buffer/json_buffer.rs b/src/buffer/json_buffer.rs new file mode 100644 index 00000000..a3eba1ff --- /dev/null +++ b/src/buffer/json_buffer.rs @@ -0,0 +1,1598 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +// TODO(@mxgrey): Add module-level documentation describing how to use JsonBuffer + +use std::{ + any::TypeId, + collections::HashMap, + ops::RangeBounds, + sync::{Mutex, OnceLock}, +}; + +use bevy_ecs::{ + prelude::{Commands, Entity, EntityRef, EntityWorldMut, Mut, World}, + system::SystemState, +}; + +use serde::{de::DeserializeOwned, Serialize}; + +pub use serde_json::Value as JsonMessage; + +use smallvec::SmallVec; + +use crate::{ + add_listener_to_source, Accessing, AnyBuffer, AnyBufferAccessInterface, AnyBufferKey, AnyRange, + AsAnyBuffer, Buffer, BufferAccessMut, BufferAccessors, BufferError, BufferIdentifier, + BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferMap, + BufferMapLayout, BufferMapStruct, BufferStorage, Bufferable, Buffering, Builder, DrainBuffer, + Gate, GateState, IncompatibleLayout, InspectBuffer, Joined, Joining, ManageBuffer, + NotifyBufferUpdate, OperationError, OperationResult, OrBroken, +}; + +/// A [`Buffer`] whose message type has been anonymized, but which is known to +/// support serialization and deserialization. Joining this buffer type will +/// yield a [`JsonMessage`]. +#[derive(Clone, Copy, Debug)] +pub struct JsonBuffer { + location: BufferLocation, + interface: &'static (dyn JsonBufferAccessInterface + Send + Sync), +} + +impl JsonBuffer { + /// Downcast this into a concerete [`Buffer`] for the specific message type. + /// + /// To downcast this into a specialized kind of buffer, use [`Self::downcast_buffer`] instead. + pub fn downcast_for_message(&self) -> Option> { + if TypeId::of::() == self.interface.any_access_interface().message_type_id() { + Some(Buffer { + location: self.location, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different specialized buffer representation. 
+ pub fn downcast_buffer(&self) -> Option { + self.as_any_buffer().downcast_buffer::() + } + + /// Register the ability to cast into [`JsonBuffer`] and [`JsonBufferKey`] + /// for buffers containing messages of type `T`. This only needs to be done + /// once in the entire lifespan of a program. + /// + /// Note that this will take effect automatically any time you create an + /// instance of [`JsonBuffer`] or [`JsonBufferKey`] for a buffer with + /// messages of type `T`. + pub fn register_for() + where + T: 'static + Serialize + DeserializeOwned + Send + Sync, + { + // We just need to ensure that this function gets called so that the + // downcast callback gets registered. Nothing more needs to be done. + JsonBufferAccessImpl::::get_interface(); + } + + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl From> for JsonBuffer { + fn from(value: Buffer) -> Self { + Self { + location: value.location, + interface: JsonBufferAccessImpl::::get_interface(), + } + } +} + +impl From for AnyBuffer { + fn from(value: JsonBuffer) -> Self { + Self { + location: value.location, + interface: value.interface.any_access_interface(), + } + } +} + +impl AsAnyBuffer for JsonBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + (*self).into() + } +} + +/// Similar to a [`BufferKey`] except it can be used for any buffer that supports +/// serialization and deserialization without knowing the buffer's specific +/// message type at compile time. +/// +/// This can key be used with a [`World`][1] to directly view or manipulate the +/// contents of a buffer through the [`JsonBufferWorldAccess`] interface. 
+/// +/// [1]: bevy_ecs::prelude::World +#[derive(Clone)] +pub struct JsonBufferKey { + tag: BufferKeyTag, + interface: &'static (dyn JsonBufferAccessInterface + Send + Sync), +} + +impl JsonBufferKey { + /// Downcast this into a concrete [`BufferKey`] for the specified message type. + /// + /// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead. + pub fn downcast_for_message(self) -> Option> { + self.as_any_buffer_key().downcast_for_message() + } + + pub fn downcast_buffer_key(self) -> Option { + self.as_any_buffer_key().downcast_buffer_key() + } + + /// Cast this into an [`AnyBufferKey`] + pub fn as_any_buffer_key(self) -> AnyBufferKey { + self.into() + } +} + +impl BufferKeyLifecycle for JsonBufferKey { + type TargetBuffer = JsonBuffer; + + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self { + Self { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() + } + + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + interface: self.interface, + } + } +} + +impl std::fmt::Debug for JsonBufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JsonBufferKey") + .field( + "message_type_name", + &self.interface.any_access_interface().message_type_name(), + ) + .field("tag", &self.tag) + .finish() + } +} + +impl From> for JsonBufferKey { + fn from(value: BufferKey) -> Self { + let interface = JsonBufferAccessImpl::::get_interface(); + JsonBufferKey { + tag: value.tag, + interface, + } + } +} + +impl From for AnyBufferKey { + fn from(value: JsonBufferKey) -> Self { + AnyBufferKey { + tag: value.tag, + interface: value.interface.any_access_interface(), + } + } +} + +/// Similar to [`BufferView`][crate::BufferView], but this can be unlocked with +/// a [`JsonBufferKey`], so it can work for any buffer whose message types +/// support serialization and deserialization. 
+pub struct JsonBufferView<'a> { + storage: Box, + gate: &'a GateState, + session: Entity, +} + +impl<'a> JsonBufferView<'a> { + /// Get a serialized copy of the oldest message in the buffer. + pub fn oldest(&self) -> JsonMessageViewResult { + self.storage.json_oldest(self.session) + } + + /// Get a serialized copy of the newest message in the buffer. + pub fn newest(&self) -> JsonMessageViewResult { + self.storage.json_newest(self.session) + } + + /// Get a serialized copy of a message in the buffer. + pub fn get(&self, index: usize) -> JsonMessageViewResult { + self.storage.json_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.json_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } +} + +/// Similar to [`BufferMut`][crate::BufferMut], but this can be unlocked with a +/// [`JsonBufferKey`], so it can work for any buffer whose message types support +/// serialization and deserialization. +pub struct JsonBufferMut<'w, 's, 'a> { + storage: Box, + gate: Mut<'a, GateState>, + buffer: Entity, + session: Entity, + accessor: Option, + commands: &'a mut Commands<'w, 's>, + modified: bool, +} + +impl<'w, 's, 'a> JsonBufferMut<'w, 's, 'a> { + /// Same as [BufferMut::allow_closed_loops][1]. + /// + /// [1]: crate::BufferMut::allow_closed_loops + pub fn allow_closed_loops(mut self) -> Self { + self.accessor = None; + self + } + + /// Get a serialized copy of the oldest message in the buffer. + pub fn oldest(&self) -> JsonMessageViewResult { + self.storage.json_oldest(self.session) + } + + /// Get a serialized copy of the newest message in the buffer. 
+ pub fn newest(&self) -> JsonMessageViewResult { + self.storage.json_newest(self.session) + } + + /// Get a serialized copy of a message in the buffer. + pub fn get(&self, index: usize) -> JsonMessageViewResult { + self.storage.json_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.json_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } + + /// Modify the oldest message in the buffer. + pub fn oldest_mut(&mut self) -> Option> { + self.storage + .json_oldest_mut(self.session, &mut self.modified) + } + + /// Modify the newest message in the buffer. + pub fn newest_mut(&mut self) -> Option> { + self.storage + .json_newest_mut(self.session, &mut self.modified) + } + + /// Modify a message in the buffer. + pub fn get_mut(&mut self, index: usize) -> Option> { + self.storage + .json_get_mut(self.session, index, &mut self.modified) + } + + /// Drain a range of messages out of the buffer. + pub fn drain>(&mut self, range: R) -> DrainJsonBuffer<'_> { + self.modified = true; + DrainJsonBuffer { + interface: self.storage.json_drain(self.session, AnyRange::new(range)), + } + } + + /// Pull the oldest message from the buffer as a JSON value. Unlike + /// [`Self::oldest`] this will remove the message from the buffer. + pub fn pull(&mut self) -> JsonMessageViewResult { + self.modified = true; + self.storage.json_pull(self.session) + } + + /// Pull the oldest message from the buffer and attempt to deserialize it + /// into the target type. + pub fn pull_as(&mut self) -> Result, serde_json::Error> { + self.pull()?.map(|m| serde_json::from_value(m)).transpose() + } + + /// Pull the newest message from the buffer as a JSON value. 
Unlike + /// [`Self::newest`] this will remove the message from the buffer. + pub fn pull_newest(&mut self) -> JsonMessageViewResult { + self.modified = true; + self.storage.json_pull_newest(self.session) + } + + /// Pull the newest message from the buffer and attempt to deserialize it + /// into the target type. + pub fn pull_newest_as(&mut self) -> Result, serde_json::Error> { + self.pull_newest()? + .map(|m| serde_json::from_value(m)) + .transpose() + } + + /// Attempt to push a new value into the buffer. + /// + /// If the input value is compatible with the message type of the buffer, + /// this will return [`Ok`]. If the buffer is at its limit before a successful + /// push, this will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`]. This may also return [`Err`] if the message coming + /// out of the buffer failed to serialize. + // TODO(@mxgrey): Consider having an error type that differentiates the + // various possible error modes. + pub fn push( + &mut self, + value: T, + ) -> Result, serde_json::Error> { + let message = serde_json::to_value(&value)?; + self.modified = true; + self.storage.json_push(self.session, message) + } + + /// Same as [`Self::push`] but no serialization step is needed for the incoming + /// message. + pub fn push_json( + &mut self, + message: JsonMessage, + ) -> Result, serde_json::Error> { + self.modified = true; + self.storage.json_push(self.session, message) + } + + /// Same as [`Self::push`] but the message will be interpreted as the oldest + /// message in the buffer. + pub fn push_as_oldest( + &mut self, + value: T, + ) -> Result, serde_json::Error> { + let message = serde_json::to_value(&value)?; + self.modified = true; + self.storage.json_push_as_oldest(self.session, message) + } + + /// Same as [`Self::push_as_oldest`] but no serialization step is needed for + /// the incoming message. 
+ pub fn push_json_as_oldest( + &mut self, + message: JsonMessage, + ) -> Result, serde_json::Error> { + self.modified = true; + self.storage.json_push_as_oldest(self.session, message) + } + + /// Tell the buffer [`Gate`] to open. + pub fn open_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + if *gate != Gate::Open { + *gate = Gate::Open; + self.modified = true; + } + } + } + + /// Tell the buffer [`Gate`] to close. + pub fn close_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + *gate = Gate::Closed; + // There is no need to to indicate that a modification happened + // because listeners do not get notified about gates closing. + } + } + + /// Perform an action on the gate of the buffer. + pub fn gate_action(&mut self, action: Gate) { + match action { + Gate::Open => self.open_gate(), + Gate::Closed => self.close_gate(), + } + } + + /// Trigger the listeners for this buffer to wake up even if nothing in the + /// buffer has changed. This could be used for timers or timeout elements + /// in a workflow. + pub fn pulse(&mut self) { + self.modified = true; + } +} + +impl<'w, 's, 'a> Drop for JsonBufferMut<'w, 's, 'a> { + fn drop(&mut self) { + if self.modified { + self.commands.add(NotifyBufferUpdate::new( + self.buffer, + self.session, + self.accessor, + )); + } + } +} + +pub trait JsonBufferWorldAccess { + /// Call this to get read-only access to any buffer whose message type is + /// serializable and deserializable. + /// + /// For technical reasons this requires direct [`World`] access, but you can + /// do other read-only queries on the world while holding onto the + /// [`JsonBufferView`]. + fn json_buffer_view(&self, key: &JsonBufferKey) -> Result, BufferError>; + + /// Call this to get mutable access to any buffer whose message type is + /// serializable and deserializable. 
+ /// + /// Pass in a callback that will receive a [`JsonBufferMut`], allowing it to + /// view and modify the contents of the buffer. + fn json_buffer_mut( + &mut self, + key: &JsonBufferKey, + f: impl FnOnce(JsonBufferMut) -> U, + ) -> Result; +} + +impl JsonBufferWorldAccess for World { + fn json_buffer_view(&self, key: &JsonBufferKey) -> Result, BufferError> { + key.interface.create_json_buffer_view(key, self) + } + + fn json_buffer_mut( + &mut self, + key: &JsonBufferKey, + f: impl FnOnce(JsonBufferMut) -> U, + ) -> Result { + let interface = key.interface; + let mut state = interface.create_json_buffer_access_mut_state(self); + let mut access = state.get_json_buffer_access_mut(self); + let buffer_mut = access.as_json_buffer_mut(key)?; + Ok(f(buffer_mut)) + } +} + +/// View or modify a buffer message in terms of JSON values. +pub struct JsonMut<'a> { + interface: &'a mut dyn JsonMutInterface, + modified: &'a mut bool, +} + +impl<'a> JsonMut<'a> { + /// Serialize the message within the buffer into JSON. + /// + /// This new [`JsonMessage`] will be a duplicate of the data of the message + /// inside the buffer, effectively meaning this function clones the data. + pub fn serialize(&self) -> Result { + self.interface.serialize() + } + + /// This will first serialize the message within the buffer into JSON and + /// then attempt to deserialize it into the target type. + /// + /// The target type does not need to match the message type inside the buffer, + /// as long as the target type can be deserialized from a serialized value + /// of the buffer's message type. + /// + /// The returned value will duplicate the data of the message inside the + /// buffer, effectively meaning this function clones the data. + pub fn deserialize_into(&self) -> Result { + serde_json::from_value::(self.serialize()?) + } + + /// Replace the underlying message with new data, and receive its original + /// data as JSON. 
+ #[must_use = "if you are going to discard the returned message, use insert instead"] + pub fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult { + *self.modified = true; + self.interface.replace(message) + } + + /// Insert new data into the underyling message. This is the same as replace + /// except it is more efficient if you don't care about the original data, + /// because it will discard the original data instead of serializing it. + pub fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error> { + *self.modified = true; + self.interface.insert(message) + } + + /// Modify the data of the underlying message. This is equivalent to calling + /// [`Self::serialize`], modifying the value, and then calling [`Self::insert`]. + /// The benefit of this function is that you do not need to remember to + /// insert after you have finished your modifications. + pub fn modify(&mut self, f: impl FnOnce(&mut JsonMessage)) -> Result<(), serde_json::Error> { + let mut message = self.serialize()?; + f(&mut message); + self.insert(message) + } +} + +/// The return type for functions that give a JSON view of a message in a buffer. +/// If an error occurs while attempting to serialize the message, this will return +/// [`Err`]. +/// +/// If this returns [`Ok`] then [`None`] means there was no message available at +/// the requested location while [`Some`] will contain a serialized copy of the +/// message. +pub type JsonMessageViewResult = Result, serde_json::Error>; + +/// The return type for functions that push a new message into a buffer. If an +/// error occurs while deserializing the message into the buffer's message type +/// then this will return [`Err`]. +/// +/// If this returns [`Ok`] then [`None`] means the new message was added and all +/// prior messages have been retained in the buffer. [`Some`] will contain an +/// old message which has now been removed from the buffer. 
+pub type JsonMessagePushResult = Result, serde_json::Error>; + +/// The return type for functions that replace (swap out) one message with +/// another. If an error occurs while serializing or deserializing either +/// message to/from the buffer's message type then this will return [`Err`]. +/// +/// If this returns [`Ok`] then the message was successfully replaced, and the +/// value inside [`Ok`] is the message that was previously in the buffer. +pub type JsonMessageReplaceResult = Result; + +trait JsonBufferViewing { + fn json_count(&self, session: Entity) -> usize; + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult; + fn json_newest<'a>(&'a self, session: Entity) -> JsonMessageViewResult; + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult; +} + +trait JsonBufferManagement: JsonBufferViewing { + fn json_push(&mut self, session: Entity, value: JsonMessage) -> JsonMessagePushResult; + fn json_push_as_oldest(&mut self, session: Entity, value: JsonMessage) + -> JsonMessagePushResult; + fn json_pull(&mut self, session: Entity) -> JsonMessageViewResult; + fn json_pull_newest(&mut self, session: Entity) -> JsonMessageViewResult; + fn json_oldest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option>; + fn json_newest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option>; + fn json_get_mut<'a>( + &'a mut self, + session: Entity, + index: usize, + modified: &'a mut bool, + ) -> Option>; + fn json_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box; +} + +impl JsonBufferViewing for &'_ BufferStorage +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.oldest(session).map(serde_json::to_value).transpose() + } + + fn json_newest<'a>(&'a self, session: Entity) -> 
JsonMessageViewResult { + self.newest(session).map(serde_json::to_value).transpose() + } + + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult { + self.get(session, index) + .map(serde_json::to_value) + .transpose() + } +} + +impl JsonBufferViewing for Mut<'_, BufferStorage> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.oldest(session).map(serde_json::to_value).transpose() + } + + fn json_newest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.newest(session).map(serde_json::to_value).transpose() + } + + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult { + self.get(session, index) + .map(serde_json::to_value) + .transpose() + } +} + +impl JsonBufferManagement for Mut<'_, BufferStorage> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_push(&mut self, session: Entity, value: JsonMessage) -> JsonMessagePushResult { + let value: T = serde_json::from_value(value)?; + self.push(session, value) + .map(serde_json::to_value) + .transpose() + } + + fn json_push_as_oldest( + &mut self, + session: Entity, + value: JsonMessage, + ) -> JsonMessagePushResult { + let value: T = serde_json::from_value(value)?; + self.push(session, value) + .map(serde_json::to_value) + .transpose() + } + + fn json_pull(&mut self, session: Entity) -> JsonMessageViewResult { + self.pull(session).map(serde_json::to_value).transpose() + } + + fn json_pull_newest(&mut self, session: Entity) -> JsonMessageViewResult { + self.pull_newest(session) + .map(serde_json::to_value) + .transpose() + } + + fn json_oldest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option> { + self.oldest_mut(session).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_newest_mut<'a>( + &'a 
mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option> { + self.newest_mut(session).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_get_mut<'a>( + &'a mut self, + session: Entity, + index: usize, + modified: &'a mut bool, + ) -> Option> { + self.get_mut(session, index).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box { + Box::new(self.drain(session, range)) + } +} + +trait JsonMutInterface { + /// Serialize the underlying message into JSON + fn serialize(&self) -> Result; + /// Replace the underlying message with new data, and receive its original + /// data as JSON + fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult; + /// Insert new data into the underyling message. This is the same as replace + /// except it is more efficient if you don't care about the original data, + /// because it will discard the original data instead of serializing it. 
+ fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error>; +} + +impl JsonMutInterface for T { + fn serialize(&self) -> Result { + serde_json::to_value(self) + } + + fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult { + let new_message: T = serde_json::from_value(message)?; + let old_message = serde_json::to_value(&self)?; + *self = new_message; + Ok(old_message) + } + + fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error> { + let new_message: T = serde_json::from_value(message)?; + *self = new_message; + Ok(()) + } +} + +trait JsonBufferAccessInterface { + fn any_access_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync); + + fn buffered_count( + &self, + buffer_ref: &EntityRef, + session: Entity, + ) -> Result; + + fn ensure_session(&self, buffer_mut: &mut EntityWorldMut, session: Entity) -> OperationResult; + + fn pull( + &self, + buffer_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result; + + fn create_json_buffer_view<'a>( + &self, + key: &JsonBufferKey, + world: &'a World, + ) -> Result, BufferError>; + + fn create_json_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box; +} + +impl<'a> std::fmt::Debug for &'a (dyn JsonBufferAccessInterface + Send + Sync) { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Message Properties") + .field("type", &self.any_access_interface().message_type_name()) + .finish() + } +} + +struct JsonBufferAccessImpl(std::marker::PhantomData); + +impl JsonBufferAccessImpl { + pub(crate) fn get_interface() -> &'static (dyn JsonBufferAccessInterface + Send + Sync) { + // Create and cache the json buffer access interface + static INTERFACE_MAP: OnceLock< + Mutex>, + > = OnceLock::new(); + let interfaces = INTERFACE_MAP.get_or_init(|| Mutex::default()); + + let mut interfaces_mut = interfaces.lock().unwrap(); + *interfaces_mut.entry(TypeId::of::()).or_insert_with(|| { + // Register downcasting 
for JsonBuffer and JsonBufferKey the + // first time that we retrieve an interface for this type. + let any_interface = AnyBuffer::interface_for::(); + any_interface.register_buffer_downcast( + TypeId::of::(), + Box::new(|location| { + Box::new(JsonBuffer { + location, + interface: Self::get_interface(), + }) + }), + ); + + any_interface.register_key_downcast( + TypeId::of::(), + Box::new(|tag| { + Box::new(JsonBufferKey { + tag, + interface: Self::get_interface(), + }) + }), + ); + + // SAFETY: This will leak memory exactly once per type, so the leakage is bounded. + // Leaking this allows the interface to be shared freely across all instances. + Box::leak(Box::new(JsonBufferAccessImpl::(Default::default()))) + }) + } +} + +impl JsonBufferAccessInterface + for JsonBufferAccessImpl +{ + fn any_access_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + AnyBuffer::interface_for::() + } + + fn buffered_count( + &self, + buffer_ref: &EntityRef, + session: Entity, + ) -> Result { + buffer_ref.buffered_count::(session) + } + + fn ensure_session(&self, buffer_mut: &mut EntityWorldMut, session: Entity) -> OperationResult { + buffer_mut.ensure_session::(session) + } + + fn pull( + &self, + buffer_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result { + let value = buffer_mut.pull_from_buffer::(session)?; + serde_json::to_value(value).or_broken() + } + + fn create_json_buffer_view<'a>( + &self, + key: &JsonBufferKey, + world: &'a World, + ) -> Result, BufferError> { + let buffer_ref = world + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(JsonBufferView { + storage: Box::new(storage), + gate, + session: key.tag.session, + }) + } + + fn create_json_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box { + Box::new(SystemState::>::new(world)) + } +} + +trait 
JsonBufferAccessMutState { + fn get_json_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's>; +} + +impl JsonBufferAccessMutState for SystemState> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn get_json_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's> { + Box::new(self.get_mut(world)) + } +} + +trait JsonBufferAccessMut<'w, 's> { + fn as_json_buffer_mut<'a>( + &'a mut self, + key: &JsonBufferKey, + ) -> Result, BufferError>; +} + +impl<'w, 's, T> JsonBufferAccessMut<'w, 's> for BufferAccessMut<'w, 's, T> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn as_json_buffer_mut<'a>( + &'a mut self, + key: &JsonBufferKey, + ) -> Result, BufferError> { + let BufferAccessMut { query, commands } = self; + let (storage, gate) = query + .get_mut(key.tag.buffer) + .map_err(|_| BufferError::BufferMissing)?; + Ok(JsonBufferMut { + storage: Box::new(storage), + gate, + buffer: key.tag.buffer, + session: key.tag.session, + accessor: Some(key.tag.accessor), + commands, + modified: false, + }) + } +} + +pub struct DrainJsonBuffer<'a> { + interface: Box, +} + +impl<'a> Iterator for DrainJsonBuffer<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + self.interface.json_next() + } +} + +trait DrainJsonBufferInterface { + fn json_next(&mut self) -> Option>; +} + +impl DrainJsonBufferInterface for DrainBuffer<'_, T> { + fn json_next(&mut self) -> Option> { + self.next().map(serde_json::to_value) + } +} + +impl Bufferable for JsonBuffer { + type BufferType = Self; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + assert_eq!(self.scope(), builder.scope()); + self + } +} + +impl Buffering for JsonBuffer { + fn verify_scope(&self, scope: Entity) { + assert_eq!(scope, self.scope()); + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let buffer_ref = world.get_entity(self.id()).or_broken()?; + 
self.interface.buffered_count(&buffer_ref, session) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + add_listener_to_source(self.id(), listener, world) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut crate::OperationRoster, + ) -> OperationResult { + GateState::apply(self.id(), session, action, world, roster) + } + + fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> { + SmallVec::from_iter([self.id()]) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.ensure_session(&mut buffer_mut, session) + } +} + +impl Joining for JsonBuffer { + type Item = JsonMessage; + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.pull(&mut buffer_mut, session) + } +} + +impl Accessing for JsonBuffer { + type Key = JsonBufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + world + .get_mut::(self.id()) + .or_broken()? 
+ .add_accessor(accessor); + Ok(()) + } + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + JsonBufferKey { + tag: builder.make_tag(self.id()), + interface: self.interface, + } + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +} + +impl Joined for serde_json::Map { + type Buffers = HashMap; +} + +impl BufferMapLayout for HashMap { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = HashMap::new(); + let mut compatibility = IncompatibleLayout::default(); + for name in buffers.keys() { + match name { + BufferIdentifier::Name(name) => { + if let Ok(downcast) = + compatibility.require_buffer_for_borrowed_name::(&name, buffers) + { + downcast_buffers.insert(name.clone().into_owned(), downcast); + } + } + BufferIdentifier::Index(index) => { + compatibility + .forbidden_buffers + .push(BufferIdentifier::Index(*index)); + } + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +impl BufferMapStruct for HashMap { + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]> { + self.values().map(|b| b.as_any_buffer()).collect() + } +} + +impl Joining for HashMap { + type Item = serde_json::Map; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|(key, value)| value.pull(session, world).map(|v| (key.clone(), v))) + .collect() + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*, AddBufferToMap}; + use bevy_ecs::prelude::World; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] + struct TestMessage { + v_i32: i32, + v_u32: u32, + v_string: String, + } + + impl TestMessage { + fn new() -> Self { + Self { + v_i32: 1, + v_u32: 2, + v_string: "hello".to_string(), + } + } + } + + #[test] + fn test_json_count() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = 
context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let count = builder + .commands() + .spawn_service(get_buffer_count.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(count) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let count = promise.take().available().unwrap(); + assert_eq!(count, 5); + assert!(context.no_unhandled_errors()); + } + + fn push_multiple_times_into_buffer( + In((value, key)): In<(TestMessage, BufferKey)>, + mut access: BufferAccessMut, + ) -> JsonBufferKey { + let mut buffer = access.get_mut(&key).unwrap(); + for _ in 0..5 { + buffer.push(value.clone()); + } + + key.into() + } + + fn get_buffer_count(In(key): In, world: &mut World) -> usize { + world.json_buffer_view(&key).unwrap().len() + } + + #[test] + fn test_modify_json_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(pull_each_buffer_item.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + 
context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 5); + for i in 0..values.len() { + let v_i32 = values[i].get("v_i32").unwrap().as_i64().unwrap(); + assert_eq!(v_i32, i as i64); + } + assert!(context.no_unhandled_errors()); + } + + fn modify_buffer_content(In(key): In, world: &mut World) -> JsonBufferKey { + world + .json_buffer_mut(&key, |mut access| { + for i in 0..access.len() { + access + .get_mut(i) + .unwrap() + .modify(|value| { + let v_i32 = value.get_mut("v_i32").unwrap(); + let modified_v_i32 = i as i64 * v_i32.as_i64().unwrap(); + *v_i32 = modified_v_i32.into(); + }) + .unwrap(); + } + }) + .unwrap(); + + key + } + + fn pull_each_buffer_item(In(key): In, world: &mut World) -> Vec { + world + .json_buffer_mut(&key, |mut access| { + let mut values = Vec::new(); + while let Ok(Some(value)) = access.pull() { + values.push(value); + } + values + }) + .unwrap() + } + + #[test] + fn test_drain_json_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(drain_buffer_contents.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, 
Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 5); + for i in 0..values.len() { + let v_i32 = values[i].get("v_i32").unwrap().as_i64().unwrap(); + assert_eq!(v_i32, i as i64); + } + assert!(context.no_unhandled_errors()); + } + + fn drain_buffer_contents(In(key): In, world: &mut World) -> Vec { + world + .json_buffer_mut(&key, |mut access| { + access.drain(..).collect::, _>>() + }) + .unwrap() + .unwrap() + } + + #[test] + fn double_json_messages() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_double_u32: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_double_i32: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_double_string: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + + scope.input.chain(builder).fork_clone(( + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_u32 *= 2; + msg + }) + .connect( + buffer_double_u32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_i32 *= 2; + msg + }) + .connect( + buffer_double_i32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_string = msg.v_string.clone() + &msg.v_string; + msg + }) + .connect( + buffer_double_string + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + )); + + (buffer_double_u32, buffer_double_i32, buffer_double_string) + .join(builder) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let (double_u32, double_i32, double_string) = 
promise.take().available().unwrap(); + assert_eq!(4, double_u32.get("v_u32").unwrap().as_i64().unwrap()); + assert_eq!(2, double_i32.get("v_i32").unwrap().as_i64().unwrap()); + assert_eq!( + "hellohello", + double_string.get("v_string").unwrap().as_str().unwrap() + ); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_buffer_downcast() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + // We just need to test that these buffers can be downcast without + // a panic occurring. + JsonBuffer::register_for::(); + let buffer = builder.create_buffer::(BufferSettings::keep_all()); + let any_buffer: AnyBuffer = buffer.into(); + let json_buffer: JsonBuffer = any_buffer.downcast_buffer().unwrap(); + let _original_from_any: Buffer = + any_buffer.downcast_for_message().unwrap(); + let _original_from_json: Buffer = + json_buffer.downcast_for_message().unwrap(); + + scope + .input + .chain(builder) + .with_access(buffer) + .map_block(|(data, key)| { + let any_key: AnyBufferKey = key.clone().into(); + let json_key: JsonBufferKey = any_key.clone().downcast_buffer_key().unwrap(); + let _original_from_any: BufferKey = + any_key.downcast_for_message().unwrap(); + let _original_from_json: BufferKey = + json_key.downcast_for_message().unwrap(); + + data + }) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(1, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let response = promise.take().available().unwrap(); + assert_eq!(1, response); + assert!(context.no_unhandled_errors()); + } + + #[derive(Clone, Joined)] + #[joined(buffers_struct_name = TestJoinedValueJsonBuffers)] + struct TestJoinedValueJson { + integer: i64, + float: f64, + #[joined(buffer = JsonBuffer)] + json: JsonMessage, + } + + #[test] + fn test_try_join_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = 
context.spawn_io_workflow(|scope, builder| { + JsonBuffer::register_for::(); + + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_json = builder.create_buffer(BufferSettings::default()); + + let mut buffers = BufferMap::default(); + buffers.insert_buffer("integer", buffer_i64); + buffers.insert_buffer("float", buffer_f64); + buffers.insert_buffer("json", buffer_json); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_json.input_slot()), + )); + + builder.try_join(&buffers).unwrap().connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } + + #[test] + fn test_joined_value_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + JsonBuffer::register_for::(); + + let json_buffer = builder.create_buffer::(BufferSettings::default()); + let buffers = TestJoinedValueJsonBuffers { + integer: builder.create_buffer(BufferSettings::default()), + float: builder.create_buffer(BufferSettings::default()), + json: json_buffer.into(), + }; + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| 
chain.connect(json_buffer.input_slot()), + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } + + #[test] + fn test_select_buffers_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_integer = builder.create_buffer::(BufferSettings::default()); + let buffer_float = builder.create_buffer::(BufferSettings::default()); + let buffer_json = + JsonBuffer::from(builder.create_buffer::(BufferSettings::default())); + + let buffers = + TestJoinedValueJson::select_buffers(buffer_integer, buffer_float, buffer_json); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| { + chain.connect(buffers.json.downcast_for_message().unwrap().input_slot()) + }, + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, 
expected_json); + } + + #[test] + fn test_join_json_buffer_vec() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_u32 = builder.create_buffer::(BufferSettings::default()); + let buffer_i32 = builder.create_buffer::(BufferSettings::default()); + let buffer_string = builder.create_buffer::(BufferSettings::default()); + let buffer_msg = builder.create_buffer::(BufferSettings::default()); + let buffers: Vec = vec![ + buffer_i32.into(), + buffer_u32.into(), + buffer_string.into(), + buffer_msg.into(), + ]; + + scope + .input + .chain(builder) + .map_block(|msg: TestMessage| (msg.v_u32, msg.v_i32, msg.v_string.clone(), msg)) + .fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_u32.input_slot()), + |chain: Chain<_>| chain.connect(buffer_i32.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_msg.input_slot()), + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request(TestMessage::new(), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 4); + assert_eq!(values[0], serde_json::Value::Number(1.into())); + assert_eq!(values[1], serde_json::Value::Number(2.into())); + assert_eq!(values[2], serde_json::Value::String("hello".to_string())); + assert_eq!(values[3], serde_json::to_value(TestMessage::new()).unwrap()); + } + + // We define this struct just to make sure the Accessor macro successfully + // compiles with JsonBufferKey. 
+ #[derive(Clone, Accessor)] + #[allow(unused)] + struct TestJsonKeyMap { + integer: BufferKey, + string: BufferKey, + json: JsonBufferKey, + any: AnyBufferKey, + } +} diff --git a/src/buffer/manage_buffer.rs b/src/buffer/manage_buffer.rs index f0a6705a..af9cf65b 100644 --- a/src/buffer/manage_buffer.rs +++ b/src/buffer/manage_buffer.rs @@ -89,6 +89,8 @@ pub trait ManageBuffer { ) -> Result, OperationError>; fn clear_buffer(&mut self, session: Entity) -> OperationResult; + + fn ensure_session(&mut self, session: Entity) -> OperationResult; } impl<'w> ManageBuffer for EntityWorldMut<'w> { @@ -114,4 +116,11 @@ impl<'w> ManageBuffer for EntityWorldMut<'w> { .clear_session(session); Ok(()) } + + fn ensure_session(&mut self, session: Entity) -> OperationResult { + self.get_mut::>() + .or_broken()? + .ensure_session(session); + Ok(()) + } } diff --git a/src/builder.rs b/src/builder.rs index 4476d060..45727868 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -16,20 +16,19 @@ */ use bevy_ecs::prelude::{Commands, Entity}; -use bevy_hierarchy::prelude::BuildChildren; use std::future::Future; use smallvec::SmallVec; use crate::{ - AddOperation, AsMap, BeginCleanupWorkflow, Buffer, BufferItem, BufferKeys, BufferSettings, - Bufferable, Buffered, Chain, Collect, ForkClone, ForkCloneOutput, ForkTargetStorage, Gate, - GateRequest, Injection, InputSlot, IntoAsyncMap, IntoBlockingMap, Node, OperateBuffer, - OperateBufferAccess, OperateDynamicGate, OperateScope, OperateSplit, OperateStaticGate, Output, - Provider, RequestOfMap, ResponseOfMap, Scope, ScopeEndpoints, ScopeSettings, - ScopeSettingsStorage, Sendish, Service, SplitOutputs, Splittable, StreamPack, StreamTargetMap, - StreamsOfMap, Trim, TrimBranch, UnusedTarget, + Accessible, Accessing, Accessor, AddOperation, AsMap, Buffer, BufferKeys, BufferLocation, + BufferMap, BufferSettings, Bufferable, Buffering, Chain, Collect, ForkClone, ForkCloneOutput, + ForkTargetStorage, Gate, GateRequest, IncompatibleLayout, Injection, 
InputSlot, IntoAsyncMap, + IntoBlockingMap, Joinable, Joined, Node, OperateBuffer, OperateDynamicGate, OperateScope, + OperateSplit, OperateStaticGate, Output, Provider, RequestOfMap, ResponseOfMap, Scope, + ScopeEndpoints, ScopeSettings, ScopeSettingsStorage, Sendish, Service, SplitOutputs, + Splittable, StreamPack, StreamTargetMap, StreamsOfMap, Trim, TrimBranch, UnusedTarget, }; pub(crate) mod connect; @@ -165,8 +164,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { )); Buffer { - scope: self.scope, - source, + location: BufferLocation { + scope: self.scope(), + source, + }, _ignore: Default::default(), } } @@ -230,27 +231,32 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { ) } - /// Alternative way of calling [`Bufferable::join`] - pub fn join<'b, B: Bufferable>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, BufferItem> - where - B::BufferType: 'static + Send + Sync, - BufferItem: 'static + Send + Sync, - { + /// Alternative way of calling [`Joinable::join`] + pub fn join<'b, B: Joinable>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, B::Item> { buffers.join(self) } - /// Alternative way of calling [`Bufferable::listen`]. - pub fn listen<'b, B: Bufferable>( + /// Try joining a map of buffers into a single value. + pub fn try_join<'b, J: Joined>( &'b mut self, - buffers: B, - ) -> Chain<'w, 's, 'a, 'b, BufferKeys> - where - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - { + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> { + J::try_join_from(buffers, self) + } + + /// Alternative way of calling [`Accessible::listen`]. + pub fn listen<'b, B: Accessible>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, B::Keys> { buffers.listen(self) } + /// Try listening to a map of buffers. + pub fn try_listen<'b, Keys: Accessor>( + &'b mut self, + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> { + Keys::try_listen_from(buffers, self) + } + /// Create a node that combines its inputs with access to some buffers. 
You /// must specify one ore more buffers to access. FOr multiple buffers, /// combine then into a tuple or an [`Iterator`]. Tuples of buffers can be @@ -258,27 +264,29 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { /// /// Other [outputs](Output) can also be passed in as buffers. These outputs /// will be transformed into a buffer with default buffer settings. - pub fn create_buffer_access(&mut self, buffers: B) -> Node)> + pub fn create_buffer_access( + &mut self, + buffers: B, + ) -> Node)> where + B::BufferType: Accessing, T: 'static + Send + Sync, - B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); - let source = self.commands.spawn(()).id(); - let target = self.commands.spawn(UnusedTarget).id(); - self.commands.add(AddOperation::new( - Some(self.scope), - source, - OperateBufferAccess::::new(buffers, target), - )); + buffers.access(self) + } - Node { - input: InputSlot::new(self.scope, source), - output: Output::new(self.scope, target), - streams: (), - } + /// Try to create access to some buffers. Same as [`Self::create_buffer_access`] + /// except it will return an error if the buffers in the [`BufferMap`] are not + /// compatible with the keys that are being asked for. + pub fn try_create_buffer_access( + &mut self, + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> + where + T: 'static + Send + Sync, + { + Keys::try_buffer_access(buffers, self) } /// Collect incoming workflow threads into a container. 
@@ -385,15 +393,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(true, true), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_cleanup(self, build); } /// Define a cleanup workflow that only gets run if the scope was cancelled. @@ -415,15 +418,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(false, true), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_cancel(self, build); } /// Define a cleanup workflow that only gets run if the scope was successfully @@ -439,15 +437,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(true, false), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_terminate(self, build); } /// Define a sub-workflow that will be run when this workflow is being cleaned @@ -460,31 +453,12 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - let cancelling_scope_id = self.commands.spawn(()).id(); - let _ = self.create_scope_impl::, (), (), Settings>( - cancelling_scope_id, - self.finish_scope_cancel, - build, - ); - - let 
begin_cancel = self.commands.spawn(()).set_parent(self.scope).id(); - let buffers = from_buffers.into_buffer(self); - buffers.verify_scope(self.scope); - self.commands.add(AddOperation::new( - None, - begin_cancel, - BeginCleanupWorkflow::::new( - self.scope, - buffers, - cancelling_scope_id, - conditions.run_on_terminate, - conditions.run_on_cancel, - ), - )); + from_buffers + .into_buffer(self) + .on_cleanup_if(self, conditions, build); } /// Create a node that trims (cancels) other nodes in the workflow when it @@ -525,7 +499,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate(&mut self, buffers: B) -> Node, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); @@ -555,7 +528,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_action(&mut self, action: Gate, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); @@ -582,7 +554,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_open(&mut self, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { self.create_gate_action(Gate::Open, buffers) @@ -594,7 +565,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_close(&mut self, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { self.create_gate_action(Gate::Closed, buffers) @@ -696,8 +666,8 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { /// later without breaking API. 
#[derive(Clone)] pub struct CleanupWorkflowConditions { - run_on_terminate: bool, - run_on_cancel: bool, + pub(crate) run_on_terminate: bool, + pub(crate) run_on_cancel: bool, } impl CleanupWorkflowConditions { diff --git a/src/chain.rs b/src/chain.rs index 7fa5ffa5..fbdee9f6 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -24,12 +24,12 @@ use smallvec::SmallVec; use std::error::Error; use crate::{ - make_option_branching, make_result_branching, AddOperation, AsMap, Buffer, BufferKey, - BufferKeys, Bufferable, Buffered, Builder, Collect, CreateCancelFilter, CreateDisposalFilter, - ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, IntoBlockingCallback, - IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, OperateSplit, - OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, Service, - Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, + make_option_branching, make_result_branching, Accessing, AddOperation, AsMap, Buffer, + BufferKey, BufferKeys, Bufferable, Buffering, Builder, Collect, CreateCancelFilter, + CreateDisposalFilter, ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, + IntoBlockingCallback, IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, + OperateSplit, OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, + Service, Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, }; pub mod fork_clone_builder; @@ -298,12 +298,11 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// will be transformed into a buffer with default buffer settings. /// /// To obtain a set of buffer keys each time a buffer is modified, use - /// [`listen`](crate::Bufferable::listen). + /// [`listen`](crate::Accessible::listen). 
pub fn with_access(self, buffers: B) -> Chain<'w, 's, 'a, 'b, (T, BufferKeys)> where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, { let buffers = buffers.into_buffer(self.builder); buffers.verify_scope(self.builder.scope); @@ -324,8 +323,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_access(self, buffers: B) -> Chain<'w, 's, 'a, 'b, BufferKeys> where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, { self.with_access(buffers).map_block(|(_, key)| key) } @@ -393,7 +391,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// The return values of the individual chain builders will be zipped into /// one tuple return value by this function. If all of the builders return /// [`Output`] then you can easily continue chaining more operations using - /// [`join`](crate::Bufferable::join), or destructure them into individual + /// [`join`](crate::Joinable::join), or destructure them into individual /// outputs that you can continue to build with. pub fn fork_clone>(self, build: Build) -> Build::Outputs where @@ -546,7 +544,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// If the buffer is broken (e.g. its operation has been despawned) the /// workflow will be cancelled. 
pub fn then_push(self, buffer: Buffer) -> Chain<'w, 's, 'a, 'b, ()> { - assert_eq!(self.scope(), buffer.scope); + assert_eq!(self.scope(), buffer.scope()); self.with_access(buffer) .then(push_into_buffer.into_blocking_callback()) .cancel_on_err() @@ -557,7 +555,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_action(self, action: Gate, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { let buffers = buffers.into_buffer(self.builder); buffers.verify_scope(self.builder.scope); @@ -578,7 +575,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_open(self, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { self.then_gate_action(Gate::Open, buffers) } @@ -588,7 +584,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_close(self, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { self.then_gate_action(Gate::Closed, buffers) } diff --git a/src/diagram.rs b/src/diagram.rs index f6b70203..b4822a36 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -375,37 +375,7 @@ pub struct Diagram { } impl Diagram { - /// Spawns a workflow from this diagram. 
- /// - /// # Examples - /// - /// ``` - /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; - /// - /// let mut app = bevy_app::App::new(); - /// let mut registry = DiagramElementRegistry::new(); - /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { - /// builder.create_map_block(|msg: String| msg) - /// }); - /// - /// let json_str = r#" - /// { - /// "version": "0.1.0", - /// "start": "echo", - /// "ops": { - /// "echo": { - /// "type": "node", - /// "builder": "echo", - /// "next": { "builtin": "terminate" } - /// } - /// } - /// } - /// "#; - /// - /// let diagram = Diagram::from_json_str(json_str)?; - /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; - /// # Ok::<_, DiagramError>(()) - /// ``` + /// Implementation for [Self::spawn_io_workflow]. // TODO(koonpeng): Support streams other than `()` #43. /* pub */ fn spawn_workflow( @@ -447,7 +417,37 @@ impl Diagram { Ok(w) } - /// Wrapper to [spawn_workflow::<()>](Self::spawn_workflow). + /// Spawns a workflow from this diagram. 
+ /// + /// # Examples + /// + /// ``` + /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; + /// + /// let mut app = bevy_app::App::new(); + /// let mut registry = DiagramElementRegistry::new(); + /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { + /// builder.create_map_block(|msg: String| msg) + /// }); + /// + /// let json_str = r#" + /// { + /// "version": "0.1.0", + /// "start": "echo", + /// "ops": { + /// "echo": { + /// "type": "node", + /// "builder": "echo", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// "#; + /// + /// let diagram = Diagram::from_json_str(json_str)?; + /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; + /// # Ok::<_, DiagramError>(()) + /// ``` pub fn spawn_io_workflow( &self, cmds: &mut Commands, diff --git a/src/diagram/registration.rs b/src/diagram/registration.rs index 66afb59f..d0ce66e9 100644 --- a/src/diagram/registration.rs +++ b/src/diagram/registration.rs @@ -985,7 +985,7 @@ impl DiagramElementRegistry { /// Register a node builder with all the common operations (deserialize the /// request, serialize the response, and clone the response) enabled. /// - /// You will receive a [`RegistrationBuilder`] which you can then use to + /// You will receive a [`NodeRegistrationBuilder`] which you can then use to /// enable more operations around your node, such as fork result, split, /// or unzip. The data types of your node need to be suitable for those /// operations or else the compiler will not allow you to enable them. diff --git a/src/gate.rs b/src/gate.rs index 8d34c757..03797337 100644 --- a/src/gate.rs +++ b/src/gate.rs @@ -23,14 +23,14 @@ pub enum Gate { /// receive a wakeup immediately when a gate switches from closed to open, /// even if none of the data inside the buffer has changed. 
/// - /// [1]: crate::Bufferable::join + /// [1]: crate::Joinable::join Open, /// Close the buffer gate so that listeners (including [join][1] operations) /// will not be woken up when the data in the buffer gets modified. This /// effectively blocks the workflow nodes that are downstream of the buffer. /// Data will build up in the buffer according to its [`BufferSettings`][2]. /// - /// [1]: crate::Bufferable::join + /// [1]: crate::Joinable::join /// [2]: crate::BufferSettings Closed, } diff --git a/src/lib.rs b/src/lib.rs index d7e27ab8..ceab07f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,6 +72,8 @@ pub use async_execution::Sendish; pub mod buffer; pub use buffer::*; +pub mod re_exports; + pub mod builder; pub use builder::*; @@ -148,6 +150,8 @@ pub use trim::*; use bevy_app::prelude::{App, Plugin, Update}; use bevy_ecs::prelude::{Entity, In}; +extern crate self as bevy_impulse; + /// Use `BlockingService` to indicate that your system is a blocking [`Service`]. /// /// A blocking service will have exclusive world access while it runs, which @@ -336,8 +340,10 @@ impl Plugin for ImpulsePlugin { pub mod prelude { pub use crate::{ buffer::{ - Buffer, BufferAccess, BufferAccessMut, BufferKey, BufferSettings, Bufferable, Buffered, - IterBufferable, RetentionPolicy, + Accessible, Accessor, AnyBuffer, AnyBufferKey, AnyBufferMut, AnyBufferWorldAccess, + AnyMessageBox, AsAnyBuffer, Buffer, BufferAccess, BufferAccessMut, BufferKey, + BufferMap, BufferMapLayout, BufferSettings, BufferWorldAccess, Bufferable, Buffering, + IncompatibleLayout, IterBufferable, Joinable, Joined, RetentionPolicy, }, builder::Builder, callback::{AsCallback, Callback, IntoAsyncCallback, IntoBlockingCallback}, @@ -362,4 +368,9 @@ pub mod prelude { BlockingCallback, BlockingCallbackInput, BlockingMap, BlockingService, BlockingServiceInput, ContinuousQuery, ContinuousService, ContinuousServiceInput, }; + + #[cfg(feature = "diagram")] + pub use crate::buffer::{ + JsonBuffer, JsonBufferKey, 
JsonBufferMut, JsonBufferWorldAccess, JsonMessage, + }; } diff --git a/src/operation/cleanup.rs b/src/operation/cleanup.rs index f4231825..8bac864e 100644 --- a/src/operation/cleanup.rs +++ b/src/operation/cleanup.rs @@ -16,7 +16,7 @@ */ use crate::{ - BufferAccessStorage, Buffered, ManageDisposal, ManageInput, MiscellaneousFailure, + Accessing, BufferAccessStorage, ManageDisposal, ManageInput, MiscellaneousFailure, OperationError, OperationResult, OperationRoster, OrBroken, ScopeStorage, UnhandledErrors, }; @@ -98,7 +98,7 @@ impl<'a> OperationCleanup<'a> { pub fn cleanup_buffer_access(&mut self) -> OperationResult where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let scope = self diff --git a/src/operation/join.rs b/src/operation/join.rs index 91d5f3e9..314e64b8 100644 --- a/src/operation/join.rs +++ b/src/operation/join.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::{Component, Entity}; use crate::{ - Buffered, FunnelInputStorage, Input, InputBundle, ManageInput, Operation, OperationCleanup, + FunnelInputStorage, Input, InputBundle, Joining, ManageInput, Operation, OperationCleanup, OperationError, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, }; @@ -37,7 +37,7 @@ impl Join { #[derive(Component)] struct BufferStorage(Buffers); -impl Operation for Join +impl Operation for Join where Buffers::Item: 'static + Send + Sync, { diff --git a/src/operation/listen.rs b/src/operation/listen.rs index 3a15ca82..378fe9f5 100644 --- a/src/operation/listen.rs +++ b/src/operation/listen.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::Entity; use crate::{ - buffer_key_usage, get_access_keys, BufferAccessStorage, BufferKeyUsage, Buffered, + buffer_key_usage, get_access_keys, Accessing, BufferAccessStorage, BufferKeyUsage, FunnelInputStorage, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationReachability, 
OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, @@ -37,7 +37,7 @@ impl Listen { impl Operation for Listen where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/operation/operate_buffer_access.rs b/src/operation/operate_buffer_access.rs index b2936f01..07208410 100644 --- a/src/operation/operate_buffer_access.rs +++ b/src/operation/operate_buffer_access.rs @@ -25,7 +25,7 @@ use std::{ use smallvec::SmallVec; use crate::{ - BufferKeyBuilder, Buffered, ChannelQueue, Input, InputBundle, ManageInput, Operation, + Accessing, BufferKeyBuilder, ChannelQueue, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationError, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, ScopeStorage, SingleInputStorage, SingleTargetStorage, @@ -34,7 +34,7 @@ use crate::{ pub(crate) struct OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered, + B: Accessing, { buffers: B, target: Entity, @@ -44,7 +44,7 @@ where impl OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered, + B: Accessing, { pub(crate) fn new(buffers: B, target: Entity) -> Self { Self { @@ -59,12 +59,12 @@ where pub struct BufferKeyUsage(pub(crate) fn(Entity, Entity, &World) -> ReachabilityResult); #[derive(Component)] -pub(crate) struct BufferAccessStorage { +pub(crate) struct BufferAccessStorage { pub(crate) buffers: B, pub(crate) keys: HashMap, } -impl BufferAccessStorage { +impl BufferAccessStorage { pub(crate) fn new(buffers: B) -> Self { Self { buffers, @@ -76,7 +76,7 @@ impl BufferAccessStorage { impl Operation for OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn setup(self, 
OperationSetup { source, world }: OperationSetup) -> OperationResult { @@ -138,7 +138,7 @@ pub(crate) fn get_access_keys( world: &mut World, ) -> Result where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let scope = world.get::(source).or_broken()?.get(); @@ -180,7 +180,7 @@ pub(crate) fn buffer_key_usage( world: &World, ) -> ReachabilityResult where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let key = world @@ -206,6 +206,12 @@ where pub(crate) struct BufferAccessors(pub(crate) SmallVec<[Entity; 8]>); impl BufferAccessors { + pub(crate) fn add_accessor(&mut self, accessor: Entity) { + self.0.push(accessor); + self.0.sort(); + self.0.dedup(); + } + pub(crate) fn is_reachable(r: &mut OperationReachability) -> ReachabilityResult { let Some(accessors) = r.world.get::(r.source) else { return Ok(false); diff --git a/src/operation/operate_gate.rs b/src/operation/operate_gate.rs index 9a677e37..0a45d449 100644 --- a/src/operation/operate_gate.rs +++ b/src/operation/operate_gate.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::{Component, Entity}; use crate::{ - emit_disposal, Buffered, Disposal, Gate, GateRequest, Input, InputBundle, ManageInput, + emit_disposal, Buffering, Disposal, Gate, GateRequest, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, }; @@ -48,7 +48,7 @@ impl OperateDynamicGate { impl Operation for OperateDynamicGate where T: 'static + Send + Sync, - B: Buffered + 'static + Send + Sync, + B: Buffering + 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { world @@ -144,7 +144,7 @@ impl OperateStaticGate { impl Operation for OperateStaticGate where - B: Buffered + 'static + Send + Sync, + B: Buffering + 'static + 
Send + Sync, T: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/operation/scope.rs b/src/operation/scope.rs index fe18821f..f757653b 100644 --- a/src/operation/scope.rs +++ b/src/operation/scope.rs @@ -16,15 +16,14 @@ */ use crate::{ - check_reachability, execute_operation, is_downstream_of, AddOperation, Blocker, - BufferKeyBuilder, Buffered, Cancel, Cancellable, Cancellation, Cleanup, CleanupContents, - ClearBufferFn, CollectMarker, DisposalListener, DisposalUpdate, FinalizeCleanup, - FinalizeCleanupRequest, Input, InputBundle, InspectDisposals, ManageCancellation, ManageInput, - Operation, OperationCancel, OperationCleanup, OperationError, OperationReachability, - OperationRequest, OperationResult, OperationRoster, OperationSetup, OrBroken, - ReachabilityResult, ScopeSettings, SingleInputStorage, SingleTargetStorage, Stream, StreamPack, - StreamRequest, StreamTargetMap, StreamTargetStorage, UnhandledErrors, Unreachability, - UnusedTarget, + check_reachability, execute_operation, is_downstream_of, Accessing, AddOperation, Blocker, + BufferKeyBuilder, Cancel, Cancellable, Cancellation, Cleanup, CleanupContents, ClearBufferFn, + CollectMarker, DisposalListener, DisposalUpdate, FinalizeCleanup, FinalizeCleanupRequest, + Input, InputBundle, InspectDisposals, ManageCancellation, ManageInput, Operation, + OperationCancel, OperationCleanup, OperationError, OperationReachability, OperationRequest, + OperationResult, OperationRoster, OperationSetup, OrBroken, ReachabilityResult, ScopeSettings, + SingleInputStorage, SingleTargetStorage, Stream, StreamPack, StreamRequest, StreamTargetMap, + StreamTargetStorage, UnhandledErrors, Unreachability, UnusedTarget, }; use backtrace::Backtrace; @@ -1125,7 +1124,7 @@ impl BeginCleanupWorkflow { impl Operation for BeginCleanupWorkflow where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn 
setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/re_exports.rs b/src/re_exports.rs new file mode 100644 index 00000000..84f22076 --- /dev/null +++ b/src/re_exports.rs @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +//! This module contains symbols that are being re-exported so they can be used +//! by bevy_impulse_derive. + +pub use bevy_ecs::prelude::{Entity, World}; diff --git a/src/testing.rs b/src/testing.rs index 83d21dd9..f7ba4664 100644 --- a/src/testing.rs +++ b/src/testing.rs @@ -19,7 +19,7 @@ use bevy_app::ScheduleRunnerPlugin; pub use bevy_app::{App, Update}; use bevy_core::{FrameCountPlugin, TaskPoolPlugin, TypeRegistrationPlugin}; pub use bevy_ecs::{ - prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource}, + prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource, World}, system::{CommandQueue, IntoSystem}, }; use bevy_time::TimePlugin; @@ -32,10 +32,12 @@ pub use std::time::{Duration, Instant}; use smallvec::SmallVec; use crate::{ - flush_impulses, AddContinuousServicesExt, AsyncServiceInput, BlockingMap, BlockingServiceInput, - Builder, ContinuousQuery, ContinuousQueueView, ContinuousService, FlushParameters, - GetBufferedSessionsFn, Promise, RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, - StreamOf, StreamPack, UnhandledErrors, WorkflowSettings, + 
flush_impulses, Accessing, AddContinuousServicesExt, AnyBuffer, AsAnyBuffer, AsyncServiceInput, + BlockingMap, BlockingServiceInput, Buffer, BufferKey, BufferKeyLifecycle, Bufferable, + Buffering, Builder, ContinuousQuery, ContinuousQueueView, ContinuousService, FlushParameters, + GetBufferedSessionsFn, Joining, OperationError, OperationResult, OperationRoster, Promise, + RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, StreamOf, StreamPack, UnhandledErrors, + WorkflowSettings, }; pub struct TestingContext { @@ -478,3 +480,104 @@ pub struct TestComponent; pub struct Integer { pub value: i32, } + +/// This is an ordinary buffer newtype whose only purpose is to test the +/// #[joined(noncopy_buffer)] feature. We intentionally do not implement +/// the Copy trait for it. +pub struct NonCopyBuffer { + inner: Buffer, +} + +impl NonCopyBuffer { + pub fn register_downcast() { + let any_interface = AnyBuffer::interface_for::(); + any_interface.register_buffer_downcast( + std::any::TypeId::of::>(), + Box::new(|location| { + Box::new(NonCopyBuffer:: { + inner: Buffer { + location, + _ignore: Default::default(), + }, + }) + }), + ); + } +} + +impl Clone for NonCopyBuffer { + fn clone(&self) -> Self { + Self { inner: self.inner } + } +} + +impl AsAnyBuffer for NonCopyBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + self.inner.as_any_buffer() + } +} + +impl Bufferable for NonCopyBuffer { + type BufferType = Self; + fn into_buffer(self, _builder: &mut Builder) -> Self::BufferType { + self + } +} + +impl Buffering for NonCopyBuffer { + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + self.inner.add_listener(listener, world) + } + + fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> { + self.inner.as_input() + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + self.inner.buffered_count(session, world) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + 
self.inner.ensure_active_session(session, world) + } + + fn gate_action( + &self, + session: Entity, + action: crate::Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + self.inner.gate_action(session, action, world, roster) + } + + fn verify_scope(&self, scope: Entity) { + self.inner.verify_scope(scope); + } +} + +impl Joining for NonCopyBuffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.inner.pull(session, world) + } +} + +impl Accessing for NonCopyBuffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + self.inner.add_accessor(accessor, world) + } + + fn create_key(&self, builder: &crate::BufferKeyBuilder) -> Self::Key { + self.inner.create_key(builder) + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +} From 7f3de4ff1bf0d73d91021204ddf927d85c9afa62 Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Thu, 27 Mar 2025 10:02:42 +0800 Subject: [PATCH 14/20] Support buffers and loops in diagrams (#51) (#63) Signed-off-by: Michael X. Grey Signed-off-by: Teo Koon Peng Co-authored-by: Michael X. 
Grey --- diagram.schema.json | 263 +++- macros/src/buffer.rs | 2 +- src/buffer.rs | 14 +- src/buffer/any_buffer.rs | 26 +- src/buffer/buffer_map.rs | 30 + src/buffer/json_buffer.rs | 138 +- src/builder.rs | 123 +- src/cancel.rs | 26 + src/chain.rs | 61 +- src/diagram.rs | 799 +++++++++-- src/diagram/buffer_schema.rs | 1008 ++++++++++++++ src/diagram/fork_clone.rs | 134 -- src/diagram/fork_clone_schema.rs | 192 +++ src/diagram/fork_result.rs | 133 -- src/diagram/fork_result_schema.rs | 179 +++ src/diagram/impls.rs | 25 - src/diagram/join.rs | 306 ----- src/diagram/join_schema.rs | 535 ++++++++ src/diagram/node_schema.rs | 51 + src/diagram/registration.rs | 1199 +++++++++++------ src/diagram/serialization.rs | 395 +++++- .../{split_serialized.rs => split_schema.rs} | 168 ++- src/diagram/supported.rs | 32 + src/diagram/testing.rs | 48 +- .../{transform.rs => transform_schema.rs} | 122 +- src/diagram/type_info.rs | 52 + src/diagram/{unzip.rs => unzip_schema.rs} | 150 ++- src/diagram/workflow_builder.rs | 1094 +++++++++------ src/disposal.rs | 12 +- src/node.rs | 18 + src/operation.rs | 12 + src/operation/operate_cancel.rs | 121 ++ 32 files changed, 5530 insertions(+), 1938 deletions(-) create mode 100644 src/diagram/buffer_schema.rs delete mode 100644 src/diagram/fork_clone.rs create mode 100644 src/diagram/fork_clone_schema.rs delete mode 100644 src/diagram/fork_result.rs create mode 100644 src/diagram/fork_result_schema.rs delete mode 100644 src/diagram/impls.rs delete mode 100644 src/diagram/join.rs create mode 100644 src/diagram/join_schema.rs create mode 100644 src/diagram/node_schema.rs rename src/diagram/{split_serialized.rs => split_schema.rs} (84%) create mode 100644 src/diagram/supported.rs rename src/diagram/{transform.rs => transform_schema.rs} (55%) create mode 100644 src/diagram/type_info.rs rename src/diagram/{unzip.rs => unzip_schema.rs} (60%) create mode 100644 src/operation/operate_cancel.rs diff --git a/diagram.schema.json b/diagram.schema.json index 
8a8c9a9b..ae9b7796 100644 --- a/diagram.schema.json +++ b/diagram.schema.json @@ -8,14 +8,27 @@ "version" ], "properties": { + "on_implicit_error": { + "description": "To simplify diagram definitions, the diagram workflow builder will sometimes insert implicit operations into the workflow, such as implicit serializing and deserializing. These implicit operations may be fallible.\n\nThis field indicates how a failed implicit operation should be handled. If left unspecified, an implicit error will cause the entire workflow to be cancelled.", + "default": null, + "anyOf": [ + { + "$ref": "#/definitions/NextOperation" + }, + { + "type": "null" + } + ] + }, "ops": { + "description": "Operations that define the workflow", "type": "object", "additionalProperties": { "$ref": "#/definitions/DiagramOperation" } }, "start": { - "description": "Signifies the start of a workflow.", + "description": "Indicates where the workflow should start running.", "allOf": [ { "$ref": "#/definitions/NextOperation" @@ -28,16 +41,41 @@ } }, "definitions": { - "BuiltinSource": { - "type": "string", - "enum": [ - "start" + "BufferInputs": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + { + "type": "array", + "items": { + "type": "string" + } + } ] }, + "BufferSettings": { + "description": "Settings to describe the behavior of a buffer.", + "type": "object", + "required": [ + "retention" + ], + "properties": { + "retention": { + "$ref": "#/definitions/RetentionPolicy" + } + } + }, "BuiltinTarget": { "oneOf": [ { - "description": "Use the output to terminate the workflow. This will be the return value of the workflow.", + "description": "Use the output to terminate the current scope. The value passed into this operation will be the return value of the scope.", "type": "string", "enum": [ "terminate" @@ -49,13 +87,20 @@ "enum": [ "dispose" ] + }, + { + "description": "When triggered, cancel the current scope. 
If this is an inner scope of a workflow then the parent scope will see a disposal happen. If this is the root scope of a workflow then the whole workflow will cancel.", + "type": "string", + "enum": [ + "cancel" + ] } ] }, "DiagramOperation": { "oneOf": [ { - "description": "Connect the request to a registered node.\n\n``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"node_op\", \"ops\": { \"node_op\": { \"type\": \"node\", \"builder\": \"my_node_builder\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "Create an operation that takes an input message and produces an output message.\n\nThe behavior is determined by the choice of node `builder` and optionally the `config` that you provide. Each type of node builder has its own schema for the config.\n\nThe output message will be sent to the operation specified by `next`.\n\nTODO(@mxgrey): [Support stream outputs](https://github.com/open-rmf/bevy_impulse/issues/43)\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"cutting_board\", \"ops\": { \"cutting_board\": { \"type\": \"node\", \"builder\": \"chop\", \"config\": \"diced\", \"next\": \"bowl\" }, \"bowl\": { \"type\": \"node\", \"builder\": \"stir\", \"next\": \"oven\" }, \"oven\": { \"type\": \"node\", \"builder\": \"bake\", \"config\": { \"temperature\": 200, \"duration\": 120 }, \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", "type": "object", "required": [ "builder", @@ -81,7 +126,7 @@ } }, { - "description": "If the request is cloneable, clone it into multiple responses.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"terminate\"] } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "If the request is cloneable, clone it
into multiple responses that can each be sent to a different operation. The `next` property is an array.\n\nThis creates multiple simultaneous branches of execution within the workflow. Usually when you have multiple branches you will either * race - connect all branches to `terminate` and the first branch to finish \"wins\" the race and gets to be the output * join - connect each branch into a buffer and then use the `join` operation to reunite them * collect - TODO(@mxgrey): [add the collect operation](https://github.com/open-rmf/bevy_impulse/issues/59)\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"begin_race\", \"ops\": { \"begin_race\": { \"type\": \"fork_clone\", \"next\": [ \"ferrari\", \"mustang\" ] }, \"ferrari\": { \"type\": \"node\", \"builder\": \"drive\", \"config\": \"ferrari\", \"next\": { \"builtin\": \"terminate\" } }, \"mustang\": { \"type\": \"node\", \"builder\": \"drive\", \"config\": \"mustang\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", "type": "object", "required": [ "next", @@ -103,7 +148,7 @@ } }, { - "description": "If the request is a tuple of (T1, T2, T3, ...), unzip it into multiple responses of T1, T2, T3, ...\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"unzip\", \"ops\": { \"unzip\": { \"type\": \"unzip\", \"next\": [{ \"builtin\": \"terminate\" }] } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "If the input message is a tuple of (T1, T2, T3, ...), unzip it into multiple output messages of T1, T2, T3, ...\n\nEach output message may have a different type and can be sent to a different operation. This creates multiple simultaneous branches of execution within the workflow.
See [`DiagramOperation::ForkClone`] for more information on parallel branches.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"name_phone_address\", \"ops\": { \"name_phone_address\": { \"type\": \"unzip\", \"next\": [ \"process_name\", \"process_phone_number\", \"process_address\" ] }, \"process_name\": { \"type\": \"node\", \"builder\": \"process_name\", \"next\": \"name_processed\" }, \"process_phone_number\": { \"type\": \"node\", \"builder\": \"process_phone_number\", \"next\": \"phone_number_processed\" }, \"process_address\": { \"type\": \"node\", \"builder\": \"process_address\", \"next\": \"address_processed\" }, \"name_processed\": { \"type\": \"buffer\" }, \"phone_number_processed\": { \"type\": \"buffer\" }, \"address_processed\": { \"type\": \"buffer\" }, \"finished\": { \"type\": \"join\", \"buffers\": [ \"name_processed\", \"phone_number_processed\", \"address_processed\" ], \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", "type": "object", "required": [ "next", @@ -125,7 +170,7 @@ } }, { - "description": "If the request is a `Result<_, _>`, branch it to `Ok` and `Err`.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_result\", \"ops\": { \"fork_result\": { \"type\": \"fork_result\", \"ok\": { \"builtin\": \"terminate\" }, \"err\": { \"builtin\": \"dispose\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "If the request is a [`Result`], send the output message down an `ok` branch or down an `err` branch depending on whether the result has an [`Ok`] or [`Err`] value. 
The `ok` branch will receive a `T` while the `err` branch will receive an `E`.\n\nOnly one branch will be activated by each input message that enters the operation.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_result\", \"ops\": { \"fork_result\": { \"type\": \"fork_result\", \"ok\": { \"builtin\": \"terminate\" }, \"err\": { \"builtin\": \"dispose\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", "type": "object", "required": [ "err", @@ -148,7 +193,7 @@ } }, { - "description": "If the request is a list-like or map-like object, split it into multiple responses. Note that the split output is a tuple of `(KeyOrIndex, Value)`, nodes receiving a split output should have request of that type instead of just the value type.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"split\", \"ops\": { \"split\": { \"type\": \"split\", \"index\": [{ \"builtin\": \"terminate\" }] } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "description": "If the input message is a list-like or map-like object, split it into multiple output messages.\n\nNote that the type of output message from the split depends on how the input message implements the [`Splittable`][1] trait. In many cases this will be a tuple of `(key, value)`.\n\nThere are three ways to specify where the split output messages should go, and all can be used at the same time: * `sequential` - For array-like collections, send the \"first\" element of the collection to the first operation listed in the `sequential` array. The \"second\" element of the collection goes to the second operation listed in the `sequential` array. And so on for all elements in the collection. If one of the elements in the collection is mentioned in the `keyed` set, then the sequence will pass over it as if the element does not exist at all. 
* `keyed` - For map-like collections, send the split element associated with the specified key to its associated output. * `remaining` - Any elements that were not captured by `sequential` or by `keyed` will be sent to this.\n\n[1]: crate::Splittable\n\n# Examples\n\nSuppose I am an animal rescuer sorting through a new collection of animals that need rescuing. My home has space for three exotic animals plus any number of dogs and cats.\n\nI have a custom `SpeciesCollection` data structure that implements [`Splittable`][1] by allowing you to key on the type of animal.\n\nIn the workflow below, we send all cats and dogs to `home`, and we also send the first three non-dog and non-cat species to `home`. All remaining animals go to the zoo.\n\n``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"select_animals\", \"ops\": { \"select_animals\": { \"type\": \"split\", \"sequential\": [ \"home\", \"home\", \"home\" ], \"keyed\": { \"cat\": \"home\", \"dog\": \"home\" }, \"remaining\": \"zoo\" } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```\n\nIf we input `[\"frog\", \"cat\", \"bear\", \"beaver\", \"dog\", \"rabbit\", \"dog\", \"monkey\"]` then `frog`, `bear`, and `beaver` will be sent to `home` since those are the first three animals that are not `dog` or `cat`, and we will also send one `cat` and two `dog` home.
`rabbit` and `monkey` will be sent to the zoo.", "type": "object", "required": [ "type" @@ -187,28 +232,29 @@ } }, { - "description": "Wait for an item to be emitted from each of the inputs, then combined the oldest of each into an array.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"split\", \"ops\": { \"split\": { \"type\": \"split\", \"index\": [\"op1\", \"op2\"] }, \"op1\": { \"type\": \"node\", \"builder\": \"foo\", \"next\": \"join\" }, \"op2\": { \"type\": \"node\", \"builder\": \"bar\", \"next\": \"join\" }, \"join\": { \"type\": \"join\", \"inputs\": [\"op1\", \"op2\"], \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "description": "Wait for exactly one item to be available in each buffer listed in `buffers`, then join each of those items into a single output message that gets sent to `next`.\n\nIf the `next` operation is not a `node` type (e.g. `fork_clone`) then you must specify a `target_node` so that the diagram knows what data structure to join the values into.\n\nThe output message type must be registered as joinable at compile time. 
If you want to join into a dynamic data structure then you should use [`DiagramOperation::SerializedJoin`] instead.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_measuring\", \"ops\": { \"fork_measuring\": { \"type\": \"fork_clone\", \"next\": [\"localize\", \"imu\"] }, \"localize\": { \"type\": \"node\", \"builder\": \"localize\", \"next\": \"estimated_position\" }, \"imu\": { \"type\": \"node\", \"builder\": \"imu\", \"config\": \"velocity\", \"next\": \"estimated_velocity\" }, \"estimated_position\": { \"type\": \"buffer\" }, \"estimated_velocity\": { \"type\": \"buffer\" }, \"gather_state\": { \"type\": \"join\", \"buffers\": { \"position\": \"estimated_position\", \"velocity\": \"estimated_velocity\" }, \"next\": \"report_state\" }, \"report_state\": { \"type\": \"node\", \"builder\": \"publish_state\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ - "inputs", + "buffers", "next", "type" ], "properties": { - "inputs": { - "description": "Controls the order of the resulting join. Each item must be an operation id of one of the incoming outputs.", - "type": "array", - "items": { - "$ref": "#/definitions/SourceOperation" - } + "buffers": { + "description": "Map of buffer keys and buffers.", + "allOf": [ + { + "$ref": "#/definitions/BufferInputs" + } + ] }, "next": { "$ref": "#/definitions/NextOperation" }, - "no_serialize": { - "description": "Do not serialize before performing the join. If true, joins can only be done on outputs of the same type.", + "target_node": { + "description": "The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation.", "type": [ - "boolean", + "string", "null" ] }, @@ -221,7 +267,35 @@ } }, { - "description": "If the request is serializable, transform it by running it through a [CEL](https://cel.dev/) program.
The context includes a \"request\" variable which contains the request.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"request.name\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```\n\nNote that due to how `serde_json` performs serialization, positive integers are always serialized as unsigned. In CEL, You can't do an operation between unsigned and signed so it is recommended to always perform explicit casts.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"int(request.score) * 3\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "description": "Same as [`DiagramOperation::Join`] but all input messages must be serializable, and the output message will always be [`serde_json::Value`].\n\nIf you use an array for `buffers` then the output message will be a [`serde_json::Value::Array`]. If you use a map for `buffers` then the output message will be a [`serde_json::Value::Object`].\n\nUnlike [`DiagramOperation::Join`], the `target_node` property does not exist for this schema.", + "type": "object", + "required": [ + "buffers", + "next", + "type" + ], + "properties": { + "buffers": { + "description": "Map of buffer keys and buffers.", + "allOf": [ + { + "$ref": "#/definitions/BufferInputs" + } + ] + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "type": { + "type": "string", + "enum": [ + "serialized_join" + ] + } + } + }, + { + "description": "If the request is serializable, transform it by running it through a [CEL](https://cel.dev/) program. 
The context includes a \"request\" variable which contains the input message.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"request.name\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```\n\nNote that due to how `serde_json` performs serialization, positive integers are always serialized as unsigned. In CEL, You can't do an operation between unsigned and signed so it is recommended to always perform explicit casts.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"transform\", \"ops\": { \"transform\": { \"type\": \"transform\", \"cel\": \"int(request.score) * 3\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "cel", @@ -235,6 +309,18 @@ "next": { "$ref": "#/definitions/NextOperation" }, + "on_error": { + "description": "Specify what happens if an error occurs during the transformation. If you specify a target for on_error, then an error message will be sent to that target. You can set this to `{ \"builtin\": \"dispose\" }` to simply ignore errors.\n\nIf left unspecified, a failure will be treated like an implicit operation failure and behave according to `on_implicit_error`.", + "default": null, + "anyOf": [ + { + "$ref": "#/definitions/NextOperation" + }, + { + "type": "null" + } + ] + }, "type": { "type": "string", "enum": [ @@ -244,16 +330,105 @@ } }, { - "description": "Drop the request, equivalent to a no-op.", + "description": "Create a [`Buffer`][1] which can be used to store and pull data within a scope.\n\nBy default the [`BufferSettings`][2] will keep the single last message pushed to the buffer. 
You can change that with the optional `settings` property.\n\nUse the `\"serialize\": true` option to serialize the messages into [`JsonMessage`] before they are inserted into the buffer. This allows any serializable message type to be pushed into the buffer. If left unspecified, the buffer will store the specific data type that gets pushed into it. If the buffer inputs are not being serialized, then all incoming messages being pushed into the buffer must have the same type.\n\n[1]: crate::Buffer [2]: crate::BufferSettings\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"num_output\", \"string_output\", \"all_num_buffer\", \"serialized_num_buffer\"] }, \"num_output\": { \"type\": \"node\", \"builder\": \"num_output\", \"next\": \"buffer_access\" }, \"string_output\": { \"type\": \"node\", \"builder\": \"string_output\", \"next\": \"string_buffer\" }, \"string_buffer\": { \"type\": \"buffer\", \"settings\": { \"retention\": { \"keep_last\": 10 } } }, \"all_num_buffer\": { \"type\": \"buffer\", \"settings\": { \"retention\": \"keep_all\" } }, \"serialized_num_buffer\": { \"type\": \"buffer\", \"serialize\": true }, \"buffer_access\": { \"type\": \"buffer_access\", \"buffers\": [\"string_buffer\"], \"target_node\": \"with_buffer_access\", \"next\": \"with_buffer_access\" }, \"with_buffer_access\": { \"type\": \"node\", \"builder\": \"with_buffer_access\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "type" ], "properties": { + "serialize": { + "description": "If true, messages will be serialized before sending into the buffer.", + "type": [ + "boolean", + "null" + ] + }, + "settings": { + "default": { + "retention": { + "keep_last": 1 + } + }, + "allOf": [ + { + "$ref": "#/definitions/BufferSettings" + } + ] + }, + "type": { + "type": "string", + "enum": [ 
+ "buffer" + ] + } + } + }, + { + "description": "Zip a message together with access to one or more buffers.\n\nThe receiving node must have an input type of `(Message, Keys)` where `Keys` implements the [`Accessor`][1] trait.\n\n[1]: crate::Accessor\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"num_output\", \"string_output\"] }, \"num_output\": { \"type\": \"node\", \"builder\": \"num_output\", \"next\": \"buffer_access\" }, \"string_output\": { \"type\": \"node\", \"builder\": \"string_output\", \"next\": \"string_buffer\" }, \"string_buffer\": { \"type\": \"buffer\" }, \"buffer_access\": { \"type\": \"buffer_access\", \"buffers\": [\"string_buffer\"], \"target_node\": \"with_buffer_access\", \"next\": \"with_buffer_access\" }, \"with_buffer_access\": { \"type\": \"node\", \"builder\": \"with_buffer_access\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "buffers", + "next", + "type" + ], + "properties": { + "buffers": { + "description": "Map of buffer keys and buffers.", + "allOf": [ + { + "$ref": "#/definitions/BufferInputs" + } + ] + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "target_node": { + "description": "The id of an operation that this operation is for. The id must be a `node` operation. 
Optional if `next` is a node operation.", + "type": [ + "string", + "null" + ] + }, + "type": { + "type": "string", + "enum": [ + "buffer_access" + ] + } + } + }, + { + "description": "Listen on a buffer.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"num_output\", \"ops\": { \"buffer\": { \"type\": \"buffer\" }, \"num_output\": { \"type\": \"node\", \"builder\": \"num_output\", \"next\": \"buffer\" }, \"listen\": { \"type\": \"listen\", \"buffers\": [\"buffer\"], \"target_node\": \"listen_buffer\", \"next\": \"listen_buffer\" }, \"listen_buffer\": { \"type\": \"node\", \"builder\": \"listen_buffer\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "type": "object", + "required": [ + "buffers", + "next", + "type" + ], + "properties": { + "buffers": { + "description": "Map of buffer keys and buffers.", + "allOf": [ + { + "$ref": "#/definitions/BufferInputs" + } + ] + }, + "next": { + "$ref": "#/definitions/NextOperation" + }, + "target_node": { + "description": "The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation.", + "type": [ + "string", + "null" + ] + }, "type": { "type": "string", "enum": [ - "dispose" + "listen" ] } } @@ -278,21 +453,45 @@ } ] }, - "SourceOperation": { - "anyOf": [ + "RetentionPolicy": { + "description": "Describe how data within a buffer gets retained. Most mechanisms that pull data from a buffer will remove the oldest item in the buffer, so this policy is for dealing with situations where items are being stored faster than they are being pulled.\n\nThe default value is KeepLast(1).", + "oneOf": [ { - "type": "string" + "description": "Keep the last N items that were stored into the buffer. 
Once the limit is reached, the oldest item will be removed any time a new item arrives.", + "type": "object", + "required": [ + "keep_last" + ], + "properties": { + "keep_last": { + "type": "integer", + "format": "uint", + "minimum": 0.0 + } + }, + "additionalProperties": false }, { + "description": "Keep the first N items that are stored into the buffer. Once the limit is reached, any new item that arrives will be discarded.", "type": "object", "required": [ - "builtin" + "keep_first" ], "properties": { - "builtin": { - "$ref": "#/definitions/BuiltinSource" + "keep_first": { + "type": "integer", + "format": "uint", + "minimum": 0.0 } - } + }, + "additionalProperties": false + }, + { + "description": "Do not limit how many items can be stored in the buffer.", + "type": "string", + "enum": [ + "keep_all" + ] } ] } diff --git a/macros/src/buffer.rs b/macros/src/buffer.rs index b231da09..26af4c74 100644 --- a/macros/src/buffer.rs +++ b/macros/src/buffer.rs @@ -401,7 +401,7 @@ fn impl_buffer_map_layout( } impl #impl_generics ::bevy_impulse::BufferMapStruct for #struct_ident #ty_generics #where_clause { - fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> { + fn buffer_list(&self) -> ::smallvec::SmallVec<[::bevy_impulse::AnyBuffer; 8]> { use smallvec::smallvec; smallvec![#( ::bevy_impulse::AsAnyBuffer::as_any_buffer(&self.#field_ident), diff --git a/src/buffer.rs b/src/buffer.rs index 1d33c28d..1c962681 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -171,7 +171,12 @@ impl From> for Buffer { } /// Settings to describe the behavior of a buffer. -#[derive(Default, Clone, Copy)] +#[cfg_attr( + feature = "diagram", + derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema), + serde(rename_all = "snake_case") +)] +#[derive(Default, Clone, Copy, Debug)] pub struct BufferSettings { retention: RetentionPolicy, } @@ -214,7 +219,12 @@ impl BufferSettings { /// are being pulled. /// /// The default value is KeepLast(1). 
-#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr( + feature = "diagram", + derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema), + serde(rename_all = "snake_case") +)] +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] pub enum RetentionPolicy { /// Keep the last N items that were stored into the buffer. Once the limit /// is reached, the oldest item will be removed any time a new item arrives. diff --git a/src/buffer/any_buffer.rs b/src/buffer/any_buffer.rs index efde9907..ebfd8943 100644 --- a/src/buffer/any_buffer.rs +++ b/src/buffer/any_buffer.rs @@ -34,10 +34,11 @@ use thiserror::Error as ThisError; use smallvec::SmallVec; use crate::{ - add_listener_to_source, Accessing, Buffer, BufferAccessMut, BufferAccessors, BufferError, - BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferStorage, - Bufferable, Buffering, Builder, DrainBuffer, Gate, GateState, InspectBuffer, Joining, - ManageBuffer, NotifyBufferUpdate, OperationError, OperationResult, OperationRoster, OrBroken, + add_listener_to_source, Accessing, Accessor, Buffer, BufferAccessMut, BufferAccessors, + BufferError, BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, + BufferMap, BufferMapLayout, BufferStorage, Bufferable, Buffering, Builder, DrainBuffer, Gate, + GateState, IncompatibleLayout, InspectBuffer, Joining, ManageBuffer, NotifyBufferUpdate, + OperationError, OperationResult, OperationRoster, OrBroken, }; /// A [`Buffer`] whose message type has been anonymized. 
Joining with this buffer @@ -243,6 +244,23 @@ impl From> for AnyBufferKey { } } +impl Accessor for AnyBufferKey { + type Buffers = AnyBuffer; +} + +impl BufferMapLayout for AnyBuffer { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut compatibility = IncompatibleLayout::default(); + + if let Ok(any_buffer) = compatibility.require_buffer_for_identifier::(0, buffers) + { + return Ok(any_buffer); + } + + Err(compatibility) + } +} + /// Similar to [`BufferView`][crate::BufferView], but this can be unlocked with /// an [`AnyBufferKey`], so it can work for any buffer whose message types /// support serialization and deserialization. diff --git a/src/buffer/buffer_map.rs b/src/buffer/buffer_map.rs index 04fe4ea7..0204488b 100644 --- a/src/buffer/buffer_map.rs +++ b/src/buffer/buffer_map.rs @@ -31,6 +31,8 @@ use crate::{ pub use bevy_impulse_derive::{Accessor, Joined}; +use super::BufferKey; + /// Uniquely identify a buffer within a buffer map, either by name or by an /// index value. 
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -424,6 +426,20 @@ pub trait Accessor: 'static + Send + Sync + Sized + Clone { } } +impl Accessor for BufferKey +where + T: Send + Sync + 'static, +{ + type Buffers = Buffer; +} + +impl Accessor for Vec> +where + T: Send + Sync + 'static, +{ + type Buffers = Vec>; +} + impl BufferMapLayout for BufferMap { fn try_from_buffer_map(buffers: &BufferMap) -> Result { Ok(buffers.clone()) @@ -498,6 +514,20 @@ impl Joined for Vec { type Buffers = Vec>; } +impl BufferMapLayout for Buffer { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut compatibility = IncompatibleLayout::default(); + + if let Ok(downcast_buffer) = + compatibility.require_buffer_for_identifier::>(0, buffers) + { + return Ok(downcast_buffer); + } + + Err(compatibility) + } +} + impl BufferMapLayout for Vec { fn try_from_buffer_map(buffers: &BufferMap) -> Result { let mut downcast_buffers = Vec::new(); diff --git a/src/buffer/json_buffer.rs b/src/buffer/json_buffer.rs index a3eba1ff..68cb1c95 100644 --- a/src/buffer/json_buffer.rs +++ b/src/buffer/json_buffer.rs @@ -36,8 +36,8 @@ pub use serde_json::Value as JsonMessage; use smallvec::SmallVec; use crate::{ - add_listener_to_source, Accessing, AnyBuffer, AnyBufferAccessInterface, AnyBufferKey, AnyRange, - AsAnyBuffer, Buffer, BufferAccessMut, BufferAccessors, BufferError, BufferIdentifier, + add_listener_to_source, Accessing, Accessor, AnyBuffer, AnyBufferAccessInterface, AnyBufferKey, + AnyRange, AsAnyBuffer, Buffer, BufferAccessMut, BufferAccessors, BufferError, BufferIdentifier, BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferMap, BufferMapLayout, BufferMapStruct, BufferStorage, Bufferable, Buffering, Builder, DrainBuffer, Gate, GateState, IncompatibleLayout, InspectBuffer, Joined, Joining, ManageBuffer, @@ -820,10 +820,20 @@ impl JsonBufferAccessIm static INTERFACE_MAP: OnceLock< Mutex>, > = OnceLock::new(); - let interfaces = 
INTERFACE_MAP.get_or_init(|| Mutex::default()); + let interfaces = INTERFACE_MAP.get_or_init(|| { + let mut interfaces = HashMap::new(); + register_basic_types(&mut interfaces); + Mutex::new(interfaces) + }); let mut interfaces_mut = interfaces.lock().unwrap(); - *interfaces_mut.entry(TypeId::of::()).or_insert_with(|| { + Self::get_or_register_type(&mut *interfaces_mut) + } + + fn get_or_register_type( + interfaces: &mut HashMap, + ) -> &'static (dyn JsonBufferAccessInterface + Send + Sync) { + *interfaces.entry(TypeId::of::()).or_insert_with(|| { // Register downcasting for JsonBuffer and JsonBufferKey the // first time that we retrieve an interface for this type. let any_interface = AnyBuffer::interface_for::(); @@ -854,6 +864,29 @@ impl JsonBufferAccessIm } } +fn register_basic_types( + interfaces: &mut HashMap, +) { + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::>::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::::get_or_register_type(interfaces); + JsonBufferAccessImpl::<()>::get_or_register_type(interfaces); +} + impl JsonBufferAccessInterface for JsonBufferAccessImpl { @@ -1059,6 +1092,24 
@@ impl Accessing for JsonBuffer { } } +impl Accessor for JsonBufferKey { + type Buffers = JsonBuffer; +} + +impl BufferMapLayout for JsonBuffer { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut compatibility = IncompatibleLayout::default(); + + if let Ok(downcast_buffer) = + compatibility.require_buffer_for_identifier::(0, buffers) + { + return Ok(downcast_buffer); + } + + Err(compatibility) + } +} + impl Joined for serde_json::Map { type Buffers = HashMap; } @@ -1067,8 +1118,8 @@ impl BufferMapLayout for HashMap { fn try_from_buffer_map(buffers: &BufferMap) -> Result { let mut downcast_buffers = HashMap::new(); let mut compatibility = IncompatibleLayout::default(); - for name in buffers.keys() { - match name { + for identifier in buffers.keys() { + match identifier { BufferIdentifier::Name(name) => { if let Ok(downcast) = compatibility.require_buffer_for_borrowed_name::(&name, buffers) @@ -1104,6 +1155,81 @@ impl Joining for HashMap { } } +impl Joined for JsonMessage { + type Buffers = HashMap, JsonBuffer>; +} + +impl BufferMapLayout for HashMap, JsonBuffer> { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = HashMap::new(); + let mut compatibility = IncompatibleLayout::default(); + for identifier in buffers.keys() { + if let Ok(downcast) = compatibility + .require_buffer_for_identifier::(identifier.clone(), buffers) + { + downcast_buffers.insert(identifier.clone(), downcast); + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +impl BufferMapStruct for HashMap, JsonBuffer> { + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]> { + self.values().map(|b| b.as_any_buffer()).collect() + } +} + +impl Joining for HashMap, JsonBuffer> { + type Item = JsonMessage; + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut object = serde_json::Map::::new(); + let mut array = Vec::::new(); + + for (identifier, buffer) in self.iter() { + match identifier { + 
BufferIdentifier::Index(index) => { + if *index >= array.len() { + // Ensure we have enough items in the array to reach the + // specified index. + array.resize(*index + 1, JsonMessage::Null); + } + + array[*index] = buffer.pull(session, world)?; + } + BufferIdentifier::Name(name) => { + object.insert(name.as_ref().to_owned(), buffer.pull(session, world)?); + } + } + } + + let value = if !object.is_empty() && !array.is_empty() { + // There are keyed buffers as well as arrayed buffers, so we need to + // organize them into two different fields in a json object. + JsonMessage::Object(serde_json::Map::from_iter([ + ("array".to_owned(), JsonMessage::Array(array)), + ("object".to_owned(), JsonMessage::Object(object)), + ])) + } else if !object.is_empty() { + // There are only object entries, so we will join them into a + // top-level object. + JsonMessage::Object(object) + } else if !array.is_empty() { + // There are only array entries, so we will join them into a + // top-level array. + JsonMessage::Array(array) + } else { + // There are no entries at all. This shouldn't happen, but we will + // handle it by returning a null value. 
+ JsonMessage::Null + }; + + Ok(value) + } +} + #[cfg(test)] mod tests { use crate::{prelude::*, testing::*, AddBufferToMap}; diff --git a/src/builder.rs b/src/builder.rs index 45727868..a4009aab 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -22,13 +22,15 @@ use std::future::Future; use smallvec::SmallVec; use crate::{ - Accessible, Accessing, Accessor, AddOperation, AsMap, Buffer, BufferKeys, BufferLocation, - BufferMap, BufferSettings, Bufferable, Buffering, Chain, Collect, ForkClone, ForkCloneOutput, + make_option_branching, make_result_branching, Accessible, Accessing, Accessor, AddOperation, + AsMap, Buffer, BufferKeys, BufferLocation, BufferMap, BufferSettings, Bufferable, Buffering, + Chain, Collect, ForkClone, ForkCloneOutput, ForkOptionOutput, ForkResultOutput, ForkTargetStorage, Gate, GateRequest, IncompatibleLayout, Injection, InputSlot, IntoAsyncMap, - IntoBlockingMap, Joinable, Joined, Node, OperateBuffer, OperateDynamicGate, OperateScope, - OperateSplit, OperateStaticGate, Output, Provider, RequestOfMap, ResponseOfMap, Scope, - ScopeEndpoints, ScopeSettings, ScopeSettingsStorage, Sendish, Service, SplitOutputs, - Splittable, StreamPack, StreamTargetMap, StreamsOfMap, Trim, TrimBranch, UnusedTarget, + IntoBlockingMap, Joinable, Joined, Node, OperateBuffer, OperateCancel, OperateDynamicGate, + OperateQuietCancel, OperateScope, OperateSplit, OperateStaticGate, Output, Provider, + RequestOfMap, ResponseOfMap, Scope, ScopeEndpoints, ScopeSettings, ScopeSettingsStorage, + Sendish, Service, SplitOutputs, Splittable, StreamPack, StreamTargetMap, StreamsOfMap, Trim, + TrimBranch, UnusedTarget, Unzippable, }; pub(crate) mod connect; @@ -213,8 +215,8 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { self.create_scope::(build) } - /// Create a node that clones its inputs and sends them off to any number of - /// targets. + /// Create an operation that clones its inputs and sends them off to any + /// number of targets. 
pub fn create_fork_clone(&mut self) -> (InputSlot, ForkCloneOutput) where T: Clone + 'static + Send + Sync, @@ -231,6 +233,75 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { ) } + /// Create an operation that unzips its inputs and sends each element off to + /// a different output. + pub fn create_unzip(&mut self) -> (InputSlot, T::Unzipped) + where + T: Unzippable + 'static + Send + Sync, + { + let source = self.commands.spawn(()).id(); + ( + InputSlot::new(self.scope, source), + T::unzip_output(Output::::new(self.scope, source), self), + ) + } + + /// Create an operation that creates a fork for a [`Result`] input. The value + /// inside the [`Result`] will be unpacked and sent down a different branch + /// depending on whether it was in the [`Ok`] or [`Err`] variant. + pub fn create_fork_result(&mut self) -> (InputSlot>, ForkResultOutput) + where + T: 'static + Send + Sync, + E: 'static + Send + Sync, + { + let source = self.commands.spawn(()).id(); + let target_ok = self.commands.spawn(UnusedTarget).id(); + let target_err = self.commands.spawn(UnusedTarget).id(); + + self.commands.add(AddOperation::new( + Some(self.scope), + source, + make_result_branching::(ForkTargetStorage::from_iter([target_ok, target_err])), + )); + + ( + InputSlot::new(self.scope, source), + ForkResultOutput { + ok: Output::new(self.scope, target_ok), + err: Output::new(self.scope, target_err), + }, + ) + } + + /// Create an operation that creates a fork for an [`Option`] input. The value + /// inside the [`Option`] will be unpacked and sent down a different branch + /// depending on whether it was in the [`Some`] or [`None`] variant. + /// + /// For the [`None`] variant a unit `()` output will be sent, also called + /// a trigger. 
+ pub fn create_fork_option(&mut self) -> (InputSlot>, ForkOptionOutput) + where + T: 'static + Send + Sync, + { + let source = self.commands.spawn(()).id(); + let target_some = self.commands.spawn(UnusedTarget).id(); + let target_none = self.commands.spawn(UnusedTarget).id(); + + self.commands.add(AddOperation::new( + Some(self.scope), + source, + make_option_branching::(ForkTargetStorage::from_iter([target_some, target_none])), + )); + + ( + InputSlot::new(self.scope, source), + ForkOptionOutput { + some: Output::new(self.scope, target_some), + none: Output::new(self.scope, target_none), + }, + ) + } + /// Alternative way of calling [`Joinable::join`] pub fn join<'b, B: Joinable>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, B::Item> { buffers.join(self) @@ -368,6 +439,42 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { ) } + /// Create an input slot that will cancel the current scope when it gets + /// triggered. This can be used on types that implement [`ToString`]. + /// + /// If you need to cancel for a type that does not implement [`ToString`] + /// then convert it to a trigger `()` and then connect it to + /// [`Self::create_quiet_cancel`]. + pub fn create_cancel(&mut self) -> InputSlot + where + T: 'static + Send + Sync + ToString, + { + let source = self.commands.spawn(()).id(); + self.commands.add(AddOperation::new( + Some(self.scope), + source, + OperateCancel::::new(), + )); + + InputSlot::new(self.scope, source) + } + + /// Create an input slot that will cancel that current scope when it gets + /// triggered. + /// + /// If you want the cancellation message to include information about the + /// input value that triggered it, use [`Self::create_cancel`]. 
+ pub fn create_quiet_cancel(&mut self) -> InputSlot<()> { + let source = self.commands.spawn(()).id(); + self.commands.add(AddOperation::new( + Some(self.scope), + source, + OperateQuietCancel, + )); + + InputSlot::new(self.scope, source) + } + /// This method allows you to define a cleanup workflow that branches off of /// this scope that will activate during the scope's cleanup phase. The /// input to the cleanup workflow will be a key to access to one or more diff --git a/src/cancel.rs b/src/cancel.rs index 632b8ac3..b255b545 100644 --- a/src/cancel.rs +++ b/src/cancel.rs @@ -67,6 +67,14 @@ impl Cancellation { .into() } + pub fn triggered(cancelled_at_node: Entity, value: Option) -> Self { + TriggeredCancellation { + cancelled_at_node, + value, + } + .into() + } + pub fn supplanted( supplanted_at_node: Entity, supplanted_by_node: Entity, @@ -119,6 +127,9 @@ pub enum CancellationCause { /// A filtering node has triggered a cancellation. Filtered(Filtered), + /// The workflow triggered its own cancellation. + Triggered(TriggeredCancellation), + /// Some workflows will queue up requests to deliver them one at a time. /// Depending on the label of the incoming requests, a new request might /// supplant an earlier one, causing the earlier request to be cancelled. @@ -163,6 +174,21 @@ pub enum CancellationCause { Broken(Broken), } +/// A variant of [`CancellationCause`] +#[derive(Debug)] +pub struct TriggeredCancellation { + /// The cancellation node that was triggered. + pub cancelled_at_node: Entity, + /// The value that triggered the cancellation, if one was provided. 
+ pub value: Option, +} + +impl From for CancellationCause { + fn from(value: TriggeredCancellation) -> Self { + CancellationCause::Triggered(value) + } +} + impl From for CancellationCause { fn from(value: Filtered) -> Self { CancellationCause::Filtered(value) diff --git a/src/chain.rs b/src/chain.rs index fbdee9f6..ccc8f7a4 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -27,9 +27,10 @@ use crate::{ make_option_branching, make_result_branching, Accessing, AddOperation, AsMap, Buffer, BufferKey, BufferKeys, Bufferable, Buffering, Builder, Collect, CreateCancelFilter, CreateDisposalFilter, ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, - IntoBlockingCallback, IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, - OperateSplit, OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, - Service, Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, + IntoBlockingCallback, IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateCancel, + OperateDynamicGate, OperateQuietCancel, OperateSplit, OperateStaticGate, Output, ProvideOnce, + Provider, Scope, ScopeSettings, Sendish, Service, Spread, StreamOf, StreamPack, + StreamTargetMap, Trim, TrimBranch, UnusedTarget, }; pub mod fork_clone_builder; @@ -349,6 +350,24 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { self.then(filter_provider).cancel_on_none() } + /// When the chain reaches this point, cancel the workflow and include + /// information about the value that triggered the cancellation. The input + /// type must implement [`ToString`]. + /// + /// If you want to trigger a cancellation with a type that does not + /// implement [`ToString`] then use [`Self::trigger`] and then + /// [`Self::then_quiet_cancel`]. 
+ pub fn then_cancel(self) + where + T: ToString, + { + self.builder.commands.add(AddOperation::new( + Some(self.scope()), + self.target, + OperateCancel::::new(), + )); + } + /// Same as [`Chain::cancellation_filter`] but the chain will be disposed /// instead of cancelled, so the workflow may continue if the termination /// node can still be reached. @@ -1151,6 +1170,20 @@ where } } +impl<'w, 's, 'a, 'b> Chain<'w, 's, 'a, 'b, ()> { + /// When the chain reaches this point, cancel the workflow. + /// + /// If you want to include information about the value that triggered the + /// cancellation, use [`Self::then_cancel`]. + pub fn then_quiet_cancel(self) { + self.builder.commands.add(AddOperation::new( + Some(self.scope()), + self.target, + OperateQuietCancel, + )); + } +} + #[cfg(test)] mod tests { use crate::{prelude::*, testing::*}; @@ -1585,4 +1618,26 @@ mod tests { } assert!(context.no_unhandled_errors()); } + + #[test] + fn test_unused_branch() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = + context.spawn_io_workflow(|scope: Scope>, i64>, builder| { + scope + .input + .chain(builder) + .spread() + .fork_result(|ok| ok.connect(scope.terminate), |err| err.unused()); + }); + + let test_set = vec![Err(()), Err(()), Ok(5), Err(()), Ok(10)]; + let mut promise = + context.command(|commands| commands.request(test_set, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + assert!(context.no_unhandled_errors()); + assert_eq!(promise.take().available().unwrap(), 5); + } } diff --git a/src/diagram.rs b/src/diagram.rs index b4822a36..1c85aa21 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -1,32 +1,58 @@ -mod fork_clone; -mod fork_result; -mod impls; -mod join; +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +mod buffer_schema; +mod fork_clone_schema; +mod fork_result_schema; +mod join_schema; +mod node_schema; mod registration; mod serialization; -mod split_serialized; -mod transform; -mod unzip; +mod split_schema; +mod supported; +mod transform_schema; +mod type_info; +mod unzip_schema; mod workflow_builder; use bevy_ecs::system::Commands; -use fork_clone::ForkCloneOp; -use fork_result::ForkResultOp; -use join::JoinOp; -pub use join::JoinOutput; +use buffer_schema::{BufferAccessSchema, BufferSchema, ListenSchema}; +use fork_clone_schema::{DynForkClone, ForkCloneSchema, PerformForkClone}; +use fork_result_schema::{DynForkResult, ForkResultSchema}; +pub use join_schema::JoinOutput; +use join_schema::{JoinSchema, SerializedJoinSchema}; +pub use node_schema::NodeSchema; pub use registration::*; pub use serialization::*; -pub use split_serialized::*; +pub use split_schema::*; use tracing::debug; -use transform::{TransformError, TransformOp}; -use unzip::UnzipOp; -use workflow_builder::create_workflow; +use transform_schema::{TransformError, TransformSchema}; +use type_info::TypeInfo; +use unzip_schema::UnzipSchema; +use workflow_builder::{create_workflow, BuildDiagramOperation, BuildStatus, DiagramContext}; // ---------- -use std::{collections::HashMap, fmt::Display, io::Read}; +use std::{borrow::Cow, collections::HashMap, fmt::Display, io::Read}; -use crate::{Builder, Scope, Service, SpawnWorkflowExt, SplitConnectionError, StreamPack}; +use crate::{ + Builder, IncompatibleLayout, JsonMessage, Scope, Service, SpawnWorkflowExt, + 
SplitConnectionError, StreamPack, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -53,6 +79,24 @@ impl Display for NextOperation { } } +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case", untagged)] +pub enum BufferInputs { + Single(OperationId), + Dict(HashMap), + Array(Vec), +} + +impl BufferInputs { + pub fn is_empty(&self) -> bool { + match self { + Self::Single(_) => false, + Self::Dict(d) => d.is_empty(), + Self::Array(a) => a.is_empty(), + } + } +} + #[derive( Debug, Clone, @@ -69,12 +113,17 @@ impl Display for NextOperation { #[serde(rename_all = "snake_case")] #[strum(serialize_all = "snake_case")] pub enum BuiltinTarget { - /// Use the output to terminate the workflow. This will be the return value - /// of the workflow. + /// Use the output to terminate the current scope. The value passed into + /// this operation will be the return value of the scope. Terminate, /// Dispose of the output. Dispose, + + /// When triggered, cancel the current scope. If this is an inner scope of a + /// workflow then the parent scope will see a disposal happen. If this is + /// the root scope of a workflow then the whole workflow will cancel. + Cancel, } #[derive( @@ -122,79 +171,162 @@ pub enum BuiltinSource { #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] -pub struct TerminateOp {} +pub struct TerminateSchema {} -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub struct NodeOp { - builder: BuilderId, - #[serde(default)] - config: serde_json::Value, - next: NextOperation, -} - -#[derive(Debug, JsonSchema, Serialize, Deserialize)] +#[derive(Clone, strum::Display, Debug, JsonSchema, Serialize, Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] +#[strum(serialize_all = "snake_case")] pub enum DiagramOperation { - /// Connect the request to a registered node. 
+ /// Create an operation that that takes an input message and produces an + /// output message. /// + /// The behavior is determined by the choice of node `builder` and + /// optioanlly the `config` that you provide. Each type of node builder has + /// its own schema for the config. + /// + /// The output message will be sent to the operation specified by `next`. + /// + /// TODO(@mxgrey): [Support stream outputs](https://github.com/open-rmf/bevy_impulse/issues/43) + /// + /// # Examples /// ``` /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "node_op", + /// "start": "cutting_board", /// "ops": { - /// "node_op": { + /// "cutting_board": { /// "type": "node", - /// "builder": "my_node_builder", + /// "builder": "chop", + /// "config": "diced", + /// "next": "bowl" + /// }, + /// "bowl": { + /// "type": "node", + /// "builder": "stir", + /// "next": "oven" + /// }, + /// "oven": { + /// "type": "node", + /// "builder": "bake", + /// "config": { + /// "temperature": 200, + /// "duration": 120 + /// }, /// "next": { "builtin": "terminate" } /// } /// } /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) - Node(NodeOp), + Node(NodeSchema), - /// If the request is cloneable, clone it into multiple responses. + /// If the request is cloneable, clone it into multiple responses that can + /// each be sent to a different operation. The `next` property is an array. + /// + /// This creates multiple simultaneous branches of execution within the + /// workflow. 
Usually when you have multiple branches you will either + /// * race - connect all branches to `terminate` and the first branch to + /// finish "wins" the race and gets to the be output + /// * join - connect each branch into a buffer and then use the `join` + /// operation to reunite them + /// * collect - TODO(@mxgrey): [add the collect operation](https://github.com/open-rmf/bevy_impulse/issues/59) /// /// # Examples /// ``` /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "fork_clone", + /// "start": "begin_race", /// "ops": { - /// "fork_clone": { + /// "begin_race": { /// "type": "fork_clone", - /// "next": ["terminate"] + /// "next": [ + /// "ferrari", + /// "mustang" + /// ] + /// }, + /// "ferrari": { + /// "type": "node", + /// "builder": "drive", + /// "config": "ferrari", + /// "next": { "builtin": "terminate" } + /// }, + /// "mustang": { + /// "type": "node", + /// "builder": "drive", + /// "config": "mustang", + /// "next": { "builtin": "terminate" } /// } /// } /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) - ForkClone(ForkCloneOp), + ForkClone(ForkCloneSchema), - /// If the request is a tuple of (T1, T2, T3, ...), unzip it into multiple responses - /// of T1, T2, T3, ... + /// If the input message is a tuple of (T1, T2, T3, ...), unzip it into + /// multiple output messages of T1, T2, T3, ... + /// + /// Each output message may have a different type and can be sent to a + /// different operation. This creates multiple simultaneous branches of + /// execution within the workflow. See [`DiagramOperation::ForkClone`] for + /// more information on parallel branches. 
/// /// # Examples /// ``` /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "unzip", + /// "start": "name_phone_address", /// "ops": { - /// "unzip": { + /// "name_phone_address": { /// "type": "unzip", - /// "next": [{ "builtin": "terminate" }] + /// "next": [ + /// "process_name", + /// "process_phone_number", + /// "process_address" + /// ] + /// }, + /// "process_name": { + /// "type": "node", + /// "builder": "process_name", + /// "next": "name_processed" + /// }, + /// "process_phone_number": { + /// "type": "node", + /// "builder": "process_phone_number", + /// "next": "phone_number_processed" + /// }, + /// "process_address": { + /// "type": "node", + /// "builder": "process_address", + /// "next": "address_processed" + /// }, + /// "name_processed": { "type": "buffer" }, + /// "phone_number_processed": { "type": "buffer" }, + /// "address_processed": { "type": "buffer" }, + /// "finished": { + /// "type": "join", + /// "buffers": [ + /// "name_processed", + /// "phone_number_processed", + /// "address_processed" + /// ], + /// "next": { "builtin": "terminate" } /// } /// } /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) - Unzip(UnzipOp), + Unzip(UnzipSchema), - /// If the request is a `Result<_, _>`, branch it to `Ok` and `Err`. + /// If the request is a [`Result`], send the output message down an + /// `ok` branch or down an `err` branch depending on whether the result has + /// an [`Ok`] or [`Err`] value. The `ok` branch will receive a `T` while the + /// `err` branch will receive an `E`. + /// + /// Only one branch will be activated by each input message that enters the + /// operation. /// /// # Examples /// ``` @@ -212,57 +344,122 @@ pub enum DiagramOperation { /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) - ForkResult(ForkResultOp), + ForkResult(ForkResultSchema), - /// If the request is a list-like or map-like object, split it into multiple responses. 
- /// Note that the split output is a tuple of `(KeyOrIndex, Value)`, nodes receiving a split - /// output should have request of that type instead of just the value type. + /// If the input message is a list-like or map-like object, split it into + /// multiple output messages. + /// + /// Note that the type of output message from the split depends on how the + /// input message implements the [`Splittable`][1] trait. In many cases this + /// will be a tuple of `(key, value)`. + /// + /// There are three ways to specify where the split output messages should + /// go, and all can be used at the same time: + /// * `sequential` - For array-like collections, send the "first" element of + /// the collection to the first operation listed in the `sequential` array. + /// The "second" element of the collection goes to the second operation + /// listed in the `sequential` array. And so on for all elements in the + /// collection. If one of the elements in the collection is mentioned in + /// the `keyed` set, then the sequence will pass over it as if the element + /// does not exist at all. + /// * `keyed` - For map-like collections, send the split element associated + /// with the specified key to its associated output. + /// * `remaining` - Any elements that are were not captured by `sequential` + /// or by `keyed` will be sent to this. + /// + /// [1]: crate::Splittable /// /// # Examples + /// + /// Suppose I am an animal rescuer sorting through a new collection of + /// animals that need recuing. My home has space for three exotic animals + /// plus any number of dogs and cats. + /// + /// I have a custom `SpeciesCollection` data structure that implements + /// [`Splittable`][1] by allowing you to key on the type of animal. + /// + /// In the workflow below, we send all cats and dogs to `home`, and we also + /// send the first three non-dog and non-cat species to `home`. All + /// remaining animals go to the zoo. 
+ /// /// ``` /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "split", + /// "start": "select_animals", /// "ops": { - /// "split": { + /// "select_animals": { /// "type": "split", - /// "index": [{ "builtin": "terminate" }] + /// "sequential": [ + /// "home", + /// "home", + /// "home" + /// ], + /// "keyed": { + /// "cat": "home", + /// "dog": "home" + /// }, + /// "remaining": "zoo" /// } /// } /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) /// ``` - Split(SplitOp), - - /// Wait for an item to be emitted from each of the inputs, then combined the - /// oldest of each into an array. + /// + /// If we input `["frog", "cat", "bear", "beaver", "dog", "rabbit", "dog", "monkey"]` + /// then `frog`, `bear`, and `beaver` will be sent to `home` since those are + /// the first three animals that are not `dog` or `cat`, and we will also + /// send one `cat` and two `dog` home. `rabbit` and `monkey` will be sent to the zoo. + Split(SplitSchema), + + /// Wait for exactly one item to be available in each buffer listed in + /// `buffers`, then join each of those items into a single output message + /// that gets sent to `next`. + /// + /// If the `next` operation is not a `node` type (e.g. `fork_clone`) then + /// you must specify a `target_node` so that the diagram knows what data + /// structure to join the values into. + /// + /// The output message type must be registered as joinable at compile time. + /// If you want to join into a dynamic data structure then you should use + /// [`DiagramOperation::SerializedJoin`] instead. 
/// /// # Examples /// ``` /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "split", + /// "start": "fork_measuring", /// "ops": { - /// "split": { - /// "type": "split", - /// "index": ["op1", "op2"] + /// "fork_measuring": { + /// "type": "fork_clone", + /// "next": ["localize", "imu"] /// }, - /// "op1": { + /// "localize": { /// "type": "node", - /// "builder": "foo", - /// "next": "join" + /// "builder": "localize", + /// "next": "estimated_position" /// }, - /// "op2": { + /// "imu": { /// "type": "node", - /// "builder": "bar", - /// "next": "join" + /// "builder": "imu", + /// "config": "velocity", + /// "next": "estimated_velocity" /// }, - /// "join": { + /// "estimated_position": { "type": "buffer" }, + /// "estimated_velocity": { "type": "buffer" }, + /// "gather_state": { /// "type": "join", - /// "inputs": ["op1", "op2"], + /// "buffers": { + /// "position": "estimate_position", + /// "velocity": "estimate_velocity" + /// }, + /// "next": "report_state" + /// }, + /// "report_state": { + /// "type": "node", + /// "builder": "publish_state", /// "next": { "builtin": "terminate" } /// } /// } @@ -270,10 +467,21 @@ pub enum DiagramOperation { /// # "#)?; /// # Ok::<_, serde_json::Error>(()) /// ``` - Join(JoinOp), + Join(JoinSchema), + + /// Same as [`DiagramOperation::Join`] but all input messages must be + /// serializable, and the output message will always be [`serde_json::Value`]. + /// + /// If you use an array for `buffers` then the output message will be a + /// [`serde_json::Value::Array`]. If you use a map for `buffers` then the + /// output message will be a [`serde_json::Value::Object`]. + /// + /// Unlike [`DiagramOperation::Join`], the `target_node` property does not + /// exist for this schema. + SerializedJoin(SerializedJoinSchema), /// If the request is serializable, transform it by running it through a [CEL](https://cel.dev/) program. 
- /// The context includes a "request" variable which contains the request. + /// The context includes a "request" variable which contains the input message. /// /// # Examples /// ``` @@ -314,15 +522,185 @@ pub enum DiagramOperation { /// # "#)?; /// # Ok::<_, serde_json::Error>(()) /// ``` - Transform(TransformOp), + Transform(TransformSchema), - /// Drop the request, equivalent to a no-op. - Dispose, + /// Create a [`Buffer`][1] which can be used to store and pull data within + /// a scope. + /// + /// By default the [`BufferSettings`][2] will keep the single last message + /// pushed to the buffer. You can change that with the optional `settings` + /// property. + /// + /// Use the `"serialize": true` option to serialize the messages into + /// [`JsonMessage`] before they are inserted into the buffer. This + /// allows any serializable message type to be pushed into the buffer. If + /// left unspecified, the buffer will store the specific data type that gets + /// pushed into it. If the buffer inputs are not being serialized, then all + /// incoming messages being pushed into the buffer must have the same type. 
+ /// + /// [1]: crate::Buffer + /// [2]: crate::BufferSettings + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "fork_clone", + /// "ops": { + /// "fork_clone": { + /// "type": "fork_clone", + /// "next": ["num_output", "string_output", "all_num_buffer", "serialized_num_buffer"] + /// }, + /// "num_output": { + /// "type": "node", + /// "builder": "num_output", + /// "next": "buffer_access" + /// }, + /// "string_output": { + /// "type": "node", + /// "builder": "string_output", + /// "next": "string_buffer" + /// }, + /// "string_buffer": { + /// "type": "buffer", + /// "settings": { + /// "retention": { "keep_last": 10 } + /// } + /// }, + /// "all_num_buffer": { + /// "type": "buffer", + /// "settings": { + /// "retention": "keep_all" + /// } + /// }, + /// "serialized_num_buffer": { + /// "type": "buffer", + /// "serialize": true + /// }, + /// "buffer_access": { + /// "type": "buffer_access", + /// "buffers": ["string_buffer"], + /// "target_node": "with_buffer_access", + /// "next": "with_buffer_access" + /// }, + /// "with_buffer_access": { + /// "type": "node", + /// "builder": "with_buffer_access", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + Buffer(BufferSchema), + + /// Zip a message together with access to one or more buffers. + /// + /// The receiving node must have an input type of `(Message, Keys)` + /// where `Keys` implements the [`Accessor`][1] trait. 
+ /// + /// [1]: crate::Accessor + /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "fork_clone", + /// "ops": { + /// "fork_clone": { + /// "type": "fork_clone", + /// "next": ["num_output", "string_output"] + /// }, + /// "num_output": { + /// "type": "node", + /// "builder": "num_output", + /// "next": "buffer_access" + /// }, + /// "string_output": { + /// "type": "node", + /// "builder": "string_output", + /// "next": "string_buffer" + /// }, + /// "string_buffer": { + /// "type": "buffer" + /// }, + /// "buffer_access": { + /// "type": "buffer_access", + /// "buffers": ["string_buffer"], + /// "target_node": "with_buffer_access", + /// "next": "with_buffer_access" + /// }, + /// "with_buffer_access": { + /// "type": "node", + /// "builder": "with_buffer_access", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + BufferAccess(BufferAccessSchema), + + /// Listen on a buffer. 
+ /// + /// # Examples + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "num_output", + /// "ops": { + /// "buffer": { + /// "type": "buffer" + /// }, + /// "num_output": { + /// "type": "node", + /// "builder": "num_output", + /// "next": "buffer" + /// }, + /// "listen": { + /// "type": "listen", + /// "buffers": ["buffer"], + /// "target_node": "listen_buffer", + /// "next": "listen_buffer" + /// }, + /// "listen_buffer": { + /// "type": "node", + /// "builder": "listen_buffer", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + Listen(ListenSchema), } -type DiagramStart = serde_json::Value; -type DiagramTerminate = serde_json::Value; -type DiagramScope = Scope; +impl BuildDiagramOperation for DiagramOperation { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + match self { + Self::Buffer(op) => op.build_diagram_operation(id, builder, ctx), + Self::BufferAccess(op) => op.build_diagram_operation(id, builder, ctx), + Self::ForkClone(op) => op.build_diagram_operation(id, builder, ctx), + Self::ForkResult(op) => op.build_diagram_operation(id, builder, ctx), + Self::Join(op) => op.build_diagram_operation(id, builder, ctx), + Self::Listen(op) => op.build_diagram_operation(id, builder, ctx), + Self::Node(op) => op.build_diagram_operation(id, builder, ctx), + Self::SerializedJoin(op) => op.build_diagram_operation(id, builder, ctx), + Self::Split(op) => op.build_diagram_operation(id, builder, ctx), + Self::Transform(op) => op.build_diagram_operation(id, builder, ctx), + Self::Unzip(op) => op.build_diagram_operation(id, builder, ctx), + } + } +} /// Returns the schema for [`String`] fn schema_with_string(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { @@ -368,22 +746,65 @@ pub struct Diagram { #[schemars(schema_with = "schema_with_string")] 
version: semver::Version, - /// Signifies the start of a workflow. + /// Indicates where the workflow should start running. start: NextOperation, + /// To simplify diagram definitions, the diagram workflow builder will + /// sometimes insert implicit operations into the workflow, such as implicit + /// serializing and deserializing. These implicit operations may be fallible. + /// + /// This field indicates how a failed implicit operation should be handled. + /// If left unspecified, an implicit error will cause the entire workflow to + /// be cancelled. + #[serde(default)] + on_implicit_error: Option, + + /// Operations that define the workflow ops: HashMap, } impl Diagram { - /// Implementation for [Self::spawn_io_workflow]. + /// Spawns a workflow from this diagram. + /// + /// # Examples + /// + /// ``` + /// use bevy_impulse::*; + /// + /// let mut app = bevy_app::App::new(); + /// let mut registry = DiagramElementRegistry::new(); + /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { + /// builder.create_map_block(|msg: String| msg) + /// }); + /// + /// let json_str = r#" + /// { + /// "version": "0.1.0", + /// "start": "echo", + /// "ops": { + /// "echo": { + /// "type": "node", + /// "builder": "echo", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// "#; + /// + /// let diagram = Diagram::from_json_str(json_str)?; + /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; + /// # Ok::<_, Box>(()) + /// ``` // TODO(koonpeng): Support streams other than `()` #43. 
/* pub */ - fn spawn_workflow( + fn spawn_workflow( &self, cmds: &mut Commands, registry: &DiagramElementRegistry, - ) -> Result, DiagramError> + ) -> Result, DiagramError> where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, Streams: StreamPack, { let mut err: Option = None; @@ -400,15 +821,17 @@ impl Diagram { }; } - let w = cmds.spawn_workflow(|scope: DiagramScope, builder: &mut Builder| { - debug!( - "spawn workflow, scope input: {:?}, terminate: {:?}", - scope.input.id(), - scope.terminate.id() - ); + let w = cmds.spawn_workflow( + |scope: Scope, builder: &mut Builder| { + debug!( + "spawn workflow, scope input: {:?}, terminate: {:?}", + scope.input.id(), + scope.terminate.id() + ); - unwrap_or_return!(create_workflow(scope, builder, registry, self)); - }); + unwrap_or_return!(create_workflow(scope, builder, registry, self)); + }, + ); if let Some(err) = err { return Err(err); @@ -422,7 +845,7 @@ impl Diagram { /// # Examples /// /// ``` - /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; + /// use bevy_impulse::*; /// /// let mut app = bevy_app::App::new(); /// let mut registry = DiagramElementRegistry::new(); @@ -445,15 +868,19 @@ impl Diagram { /// "#; /// /// let diagram = Diagram::from_json_str(json_str)?; - /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; - /// # Ok::<_, DiagramError>(()) + /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; + /// # Ok::<_, Box>(()) /// ``` - pub fn spawn_io_workflow( + pub fn spawn_io_workflow( &self, cmds: &mut Commands, registry: &DiagramElementRegistry, - ) -> Result, DiagramError> { - self.spawn_workflow::<()>(cmds, registry) + ) -> Result, DiagramError> + where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, + { + self.spawn_workflow::(cmds, registry) } pub fn from_json(value: serde_json::Value) -> Result { @@ -470,53 +897,128 @@ impl 
Diagram { { serde_json::from_reader(r) } + + fn get_op(&self, op_id: &OperationId) -> Result<&DiagramOperation, DiagramErrorCode> { + self.ops + .get(op_id) + .ok_or_else(|| DiagramErrorCode::OperationNotFound(op_id.clone())) + } } #[derive(thiserror::Error, Debug)] -pub enum DiagramError { +#[error("{context} {code}")] +pub struct DiagramError { + pub context: DiagramErrorContext, + + #[source] + pub code: DiagramErrorCode, +} + +impl DiagramError { + pub fn in_operation(op_id: OperationId, code: DiagramErrorCode) -> Self { + Self { + context: DiagramErrorContext { op_id: Some(op_id) }, + code, + } + } +} + +#[derive(Debug)] +pub struct DiagramErrorContext { + op_id: Option, +} + +impl Display for DiagramErrorContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(op_id) = &self.op_id { + write!(f, "in operation [{}],", op_id)?; + } + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum DiagramErrorCode { #[error("node builder [{0}] is not registered")] BuilderNotFound(BuilderId), #[error("operation [{0}] not found")] OperationNotFound(OperationId), - #[error("output type does not match input type")] - TypeMismatch, + #[error("type mismatch, source {source_type}, target {target_type}")] + TypeMismatch { + source_type: TypeInfo, + target_type: TypeInfo, + }, + + #[error("Operation [{0}] attempted to instantiate multiple inputs.")] + MultipleInputsCreated(OperationId), - #[error("missing start or terminate")] + #[error("Operation [{0}] attempted to instantiate multiple buffers.")] + MultipleBuffersCreated(OperationId), + + #[error("Missing a connection to start or terminate. 
A workflow cannot run with a valid connection to each.")] MissingStartOrTerminate, - #[error("cannot connect to start")] - CannotConnectStart, + #[error("Serialization was disabled for the target message type.")] + NotSerializable(TypeInfo), - #[error("request or response cannot be serialized or deserialized")] - NotSerializable, + #[error("Deserialization was disabled for the target message type.")] + NotDeserializable(TypeInfo), - #[error("response cannot be cloned")] + #[error("Cloning was disabled for the target message type.")] NotCloneable, - #[error("the number of unzip slots in response does not match the number of inputs")] + #[error("The number of unzip slots in response does not match the number of inputs.")] NotUnzippable, - #[error( - "node must be registered with \"with_fork_result()\" to be able to perform fork result" - )] + #[error("The number of elements in the unzip expected by the diagram [{expected}] is different from the real number [{actual}]")] + UnzipMismatch { + expected: usize, + actual: usize, + elements: Vec, + }, + + #[error("Call .with_fork_result() on your node to be able to fork its Result-type output.")] CannotForkResult, - #[error("response cannot be split")] + #[error("Response cannot be split. Make sure to use .with_split() when building the node.")] NotSplittable, - #[error("responses cannot be joined")] + #[error( + "Message cannot be joined. Make sure to use .with_join() when building the target node." 
+ )] NotJoinable, - #[error("empty join is not allowed")] + #[error("Empty join is not allowed.")] EmptyJoin, + #[error("Target type cannot be determined from [next] and [target_node] is not provided.")] + UnknownTarget, + + #[error("There was an attempt to access an unknown operation: [{0}]")] + UnknownOperation(NextOperation), + #[error(transparent)] CannotTransform(#[from] TransformError), - #[error("an interconnect like fork_clone cannot connect to another interconnect")] - BadInterconnectChain, + #[error("box/unbox operation for the message is not registered")] + CannotBoxOrUnbox, + + #[error("Buffer access was not enabled for a node connected to a buffer access operation. Make sure to use .with_buffer_access() when building the node.")] + CannotBufferAccess, + + #[error("cannot listen on these buffers to produce a request of [{0}]")] + CannotListen(TypeInfo), + + #[error(transparent)] + IncompatibleBuffers(#[from] IncompatibleLayout), + + #[error("one or more operation is missing inputs")] + IncompleteDiagram, + + #[error("operation type only accept single input")] + OnlySingleInput, #[error(transparent)] JsonError(#[from] serde_json::Error), @@ -524,18 +1026,26 @@ pub enum DiagramError { #[error(transparent)] ConnectionError(#[from] SplitConnectionError), - /// Use this only for errors that *should* never happen because of some preconditions. - /// If this error ever comes up, then it likely means that there is some logical flaws - /// in the algorithm. - #[error("an unknown error occurred while building the diagram, {0}")] - UnknownError(String), + #[error("a type being used in the diagram was not registered {0}")] + UnregisteredType(TypeInfo), + + #[error("The build of the workflow came to a halt, reasons:\n{reasons:?}")] + BuildHalted { + /// Reasons that operations were unable to make progress building + reasons: HashMap>, + }, + + #[error("The workflow building process has had an excessive number of iterations. 
This may indicate an implementation bug or an extraordinarily complex diagram.")] + ExcessiveIterations, } -#[macro_export] -macro_rules! unknown_diagram_error { - () => { - DiagramError::UnknownError(format!("{}:{}", file!(), line!())) - }; +impl From for DiagramError { + fn from(code: DiagramErrorCode) -> Self { + DiagramError { + context: DiagramErrorContext { op_id: None }, + code, + } + } } #[cfg(test)] @@ -570,6 +1080,7 @@ mod tests { let err = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap_err(); + assert!(fixture.context.no_unhandled_errors()); assert!(matches!( *err.downcast_ref::().unwrap().cause, CancellationCause::Unreachable(_) @@ -593,8 +1104,12 @@ mod tests { })) .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::NotSerializable), "{:?}", err); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!(err.code, DiagramErrorCode::TypeMismatch { .. }), + "{:?}", + err + ); } #[test] @@ -614,8 +1129,12 @@ mod tests { })) .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::NotSerializable), "{:?}", err); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!(err.code, DiagramErrorCode::NotSerializable(_)), + "{:?}", + err + ); } #[test] @@ -640,8 +1159,18 @@ mod tests { })) .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::TypeMismatch), "{:?}", err); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!( + err.code, + DiagramErrorCode::TypeMismatch { + target_type: _, + source_type: _ + } + ), + "{:?}", + err + ); } #[test] @@ -669,6 +1198,7 @@ mod tests { let err = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap_err(); + assert!(fixture.context.no_unhandled_errors()); assert!(matches!( *err.downcast_ref::().unwrap().cause, 
CancellationCause::Unreachable(_) @@ -704,6 +1234,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 36); } @@ -721,6 +1252,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 4); } @@ -749,6 +1281,7 @@ mod tests { serde_json::Value::from(4), ) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 28); } @@ -769,7 +1302,10 @@ mod tests { }, "unzip": { "type": "unzip", - "next": ["transform"], + "next": [ + "transform", + { "builtin": "dispose" }, + ], }, "transform": { "type": "transform", @@ -783,6 +1319,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); } } diff --git a/src/diagram/buffer_schema.rs b/src/diagram/buffer_schema.rs new file mode 100644 index 00000000..79e7ec74 --- /dev/null +++ b/src/diagram/buffer_schema.rs @@ -0,0 +1,1008 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{Accessor, BufferSettings, Builder, JsonMessage}; + +use super::{ + type_info::TypeInfo, BufferInputs, BuildDiagramOperation, BuildStatus, DiagramContext, + DiagramErrorCode, NextOperation, OperationId, +}; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct BufferSchema { + #[serde(default)] + pub(super) settings: BufferSettings, + + /// If true, messages will be serialized before sending into the buffer. + pub(super) serialize: Option, +} + +impl BuildDiagramOperation for BufferSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let message_info = if self.serialize.is_some_and(|v| v) { + TypeInfo::of::() + } else { + let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + // There are no outputs ready for this target, so we can't do + // anything yet. The builder should try again later. + + // TODO(@mxgrey): We should allow users to explicitly specify the + // message type for the buffer. When they do, we won't need to wait + // for an input. + return Ok(BuildStatus::defer("waiting for an input")); + }; + + *inferred_type + }; + + let buffer = + ctx.registry + .messages + .create_buffer(&message_info, self.settings.clone(), builder)?; + ctx.set_buffer_for_operation(id, buffer)?; + Ok(BuildStatus::Finished) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct BufferAccessSchema { + pub(super) next: NextOperation, + + /// Map of buffer keys and buffers. + pub(super) buffers: BufferInputs, + + /// The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation. 
+ pub(super) target_node: Option, +} + +impl BuildDiagramOperation for BufferAccessSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let buffer_map = match ctx.create_buffer_map(&self.buffers) { + Ok(buffer_map) => buffer_map, + Err(reason) => return Ok(BuildStatus::defer(reason)), + }; + + let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; + let node = ctx + .registry + .messages + .with_buffer_access(&target_type, &buffer_map, builder)?; + ctx.set_input_for_target(id, node.input)?; + ctx.add_output_into_target(self.next.clone(), node.output); + Ok(BuildStatus::Finished) + } +} + +pub trait BufferAccessRequest { + type Message: Send + Sync + 'static; + type BufferKeys: Accessor; +} + +impl BufferAccessRequest for (T, B) +where + T: Send + Sync + 'static, + B: Accessor, +{ + type Message = T; + type BufferKeys = B; +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct ListenSchema { + pub(super) next: NextOperation, + + /// Map of buffer keys and buffers. + pub(super) buffers: BufferInputs, + + /// The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation. 
+ pub(super) target_node: Option, +} + +impl BuildDiagramOperation for ListenSchema { + fn build_diagram_operation( + &self, + _: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let buffer_map = match ctx.create_buffer_map(&self.buffers) { + Ok(buffer_map) => buffer_map, + Err(reason) => return Ok(BuildStatus::defer(reason)), + }; + + let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; + let output = ctx + .registry + .messages + .listen(&target_type, &buffer_map, builder)?; + ctx.add_output_into_target(self.next.clone(), output); + Ok(BuildStatus::Finished) + } +} + +#[cfg(test)] +mod tests { + use bevy_ecs::{prelude::World, system::In}; + use serde_json::json; + + use crate::{ + diagram::testing::DiagramTestFixture, Accessor, AnyBufferKey, AnyBufferWorldAccess, + BufferAccess, BufferAccessMut, BufferKey, BufferWorldAccess, Diagram, DiagramErrorCode, + IntoBlockingCallback, JsonBufferKey, JsonBufferWorldAccess, JsonMessage, JsonPosition, + Node, NodeBuilderOptions, + }; + + /// create a new [`DiagramTestFixture`] with some extra builders. 
+ fn new_fixture() -> DiagramTestFixture { + let mut fixture = DiagramTestFixture::new(); + + fn num_output(_: serde_json::Value) -> i64 { + 1 + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("num_output".to_string()), + |builder, _config: ()| builder.create_map_block(num_output), + ); + + fn string_output(_: serde_json::Value) -> String { + "hello".to_string() + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("string_output".to_string()), + |builder, _config: ()| builder.create_map_block(string_output), + ); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("insert_json_buffer_entries".to_owned()), + |builder, config: usize| { + builder.create_node( + (move |In((value, key)): In<(JsonMessage, JsonBufferKey)>, + world: &mut World| { + world + .json_buffer_mut(&key, |mut buffer| { + for _ in 0..config { + buffer.push(value.clone()).unwrap(); + } + }) + .unwrap(); + }) + .into_blocking_callback(), + ) + }, + ) + .with_buffer_access() + .with_common_response(); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("count_json_buffer_entries".to_owned()), + |builder, _config: ()| { + builder.create_node(count_json_buffer_entries.into_blocking_callback()) + }, + ) + .with_buffer_access() + .with_common_response(); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("listen_count_json_buffer_entries".to_owned()), + |builder, _config: ()| { + builder.create_node(listen_count_json_buffer_entries.into_blocking_callback()) + }, + ) + .with_listen() + .with_common_response(); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("count_any_buffer_entries".to_owned()), + |builder, _config: ()| { + 
builder.create_node(count_any_buffer_entries.into_blocking_callback()) + }, + ) + .with_buffer_access() + .with_common_response(); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("listen_count_any_buffer_entries".to_owned()), + |builder, _config: ()| { + builder.create_node(listen_count_any_buffer_entries.into_blocking_callback()) + }, + ) + .with_listen() + .with_common_response(); + + // TODO(@mxgrey): Replace these with a general deserializing operation + fixture.registry.register_node_builder( + NodeBuilderOptions::new("deserialize_i64"), + |builder, _config: ()| { + builder + .create_map_block(|msg: JsonMessage| msg.as_number().unwrap().as_i64().unwrap()) + }, + ); + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("json_split_i64"), + |builder, _config: ()| { + builder.create_map_block(|(_, msg): (JsonPosition, JsonMessage)| { + msg.as_number().unwrap().as_i64().unwrap() + }) + }, + ); + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("json_split_string"), + |builder, _config: ()| { + builder.create_map_block(|(_, msg): (JsonPosition, JsonMessage)| { + msg.as_str().unwrap().to_owned() + }) + }, + ); + + fixture + } + + fn count_json_buffer_entries( + In(((), key)): In<((), JsonBufferKey)>, + world: &mut World, + ) -> usize { + world.json_buffer_view(&key).unwrap().len() + } + + fn listen_count_json_buffer_entries(In(key): In, world: &mut World) -> usize { + world.json_buffer_view(&key).unwrap().len() + } + + fn count_any_buffer_entries(In(((), key)): In<((), AnyBufferKey)>, world: &mut World) -> usize { + world.any_buffer_view(&key).unwrap().len() + } + + fn listen_count_any_buffer_entries(In(key): In, world: &mut World) -> usize { + world.any_buffer_view(&key).unwrap().len() + } + + #[test] + fn test_buffer_mismatch_type() { + let mut fixture = new_fixture(); + fixture + .registry + .register_node_builder( + 
NodeBuilderOptions::new("join_i64"), + |builder, _config: ()| builder.create_map_block(|i: Vec| i[0]), + ) + .with_join(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "string_output", + "ops": { + "string_output": { + "type": "node", + "builder": "string_output", + "next": "buffer", + }, + "buffer": { + "type": "buffer", + }, + "join": { + "type": "join", + "buffers": ["buffer"], + "target_node": "op1", + "next": "op1", + }, + "op1": { + "type": "node", + "builder": "join_i64", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!(err.code, DiagramErrorCode::IncompatibleBuffers(_)), + "{:#?}", + err + ); + } + + #[test] + fn test_buffer_multiple_inputs() { + let mut fixture = new_fixture(); + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("wait_2_strings"), + |builder, _config: ()| { + let n = builder.create_node( + (|In(req): In>>, access: BufferAccess| { + if access.get(&req[0]).unwrap().len() < 2 { + None + } else { + Some("hello world".to_string()) + } + }) + .into_blocking_callback(), + ); + let output = n.output.chain(builder).dispose_on_none().output(); + Node::>, String> { + input: n.input, + output, + streams: n.streams, + } + }, + ) + .with_listen(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["string_output", "string_output2"], + }, + "string_output": { + "type": "node", + "builder": "string_output", + "next": "buffer", + }, + "string_output2": { + "type": "node", + "builder": "string_output", + "next": "buffer", + }, + "buffer": { + "type": "buffer", + "settings": { + "retention": "keep_all", + }, + }, + "listen": { + "type": "listen", + "buffers": ["buffer"], + "target_node": "wait_2_strings", + "next": "wait_2_strings", + }, + 
"wait_2_strings": { + "type": "node", + "builder": "wait_2_strings", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, "hello world"); + } + + #[test] + fn test_buffer_access() { + let mut fixture = new_fixture(); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("with_buffer_access"), + |builder, _config: ()| { + builder.create_map_block(|req: (i64, Vec>)| req.0) + }, + ) + .with_buffer_access(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["num_output", "string_output"], + }, + "num_output": { + "type": "node", + "builder": "num_output", + "next": "buffer_access", + }, + "string_output": { + "type": "node", + "builder": "string_output", + "next": "string_buffer", + }, + "string_buffer": { + "type": "buffer", + }, + "buffer_access": { + "type": "buffer_access", + "buffers": ["string_buffer"], + "target_node": "with_buffer_access", + "next": "with_buffer_access", + }, + "with_buffer_access": { + "type": "node", + "builder": "with_buffer_access", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 1); + } + + #[test] + fn test_json_buffer_access() { + let mut fixture = new_fixture(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_input", + "ops": { + "fork_input": { + "type": "fork_clone", + "next": [ + "json_buffer", + "insert_access", + ], + }, + "json_buffer": { + "type": "buffer", + "settings": { "retention": "keep_all" }, + }, + "insert_access": { + "type": "buffer_access", + "buffers": 
["json_buffer"], + "next": "insert", + }, + "insert": { + "type": "node", + "builder": "insert_json_buffer_entries", + "config": 10, + "next": "count_access", + }, + "count_access": { + "type": "buffer_access", + "buffers": "json_buffer", + "next": "count", + }, + "count": { + "type": "node", + "builder": "count_json_buffer_entries", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::String("hello".to_owned())) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 11); + } + + #[test] + fn test_any_buffer_access() { + let mut fixture = new_fixture(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_input", + "ops": { + "fork_input": { + "type": "fork_clone", + "next": [ + "json_buffer", + "insert_access", + ], + }, + "json_buffer": { + "type": "buffer", + "settings": { "retention": "keep_all" }, + }, + "insert_access": { + "type": "buffer_access", + "buffers": ["json_buffer"], + "next": "insert", + }, + "insert": { + "type": "node", + "builder": "insert_json_buffer_entries", + "config": 10, + "next": "count_access", + }, + "count_access": { + "type": "buffer_access", + "buffers": "json_buffer", + "next": "count", + }, + "count": { + "type": "node", + "builder": "count_any_buffer_entries", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::String("hello".to_owned())) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 11); + } + + #[test] + fn test_generic_listen() { + let mut fixture = new_fixture(); + + fn count_generic_buffer( + In(key): In>, + mut access: BufferAccessMut, + ) -> i64 { + access.get_mut(&key).unwrap().pull().unwrap() + } + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("pull_generic_buffer"), + 
|builder, _config: ()| { + builder.create_node(count_generic_buffer.into_blocking_callback()) + }, + ) + .with_listen() + .with_common_response(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "deserialize", + "ops": { + "deserialize": { + "type": "node", + "builder": "deserialize_i64", + "next": "buffer", + }, + "buffer": { "type": "buffer" }, + "listen": { + "type": "listen", + "buffers": "buffer", + "next": "count", + }, + "count": { + "type": "node", + "builder": "pull_generic_buffer", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, JsonMessage::Number(5_i64.into())) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 5_i64); + } + + #[test] + fn test_vec_listen() { + let mut fixture = new_fixture(); + + fn listen_buffer(In(request): In>>, access: BufferAccess) -> usize { + access.get(&request[0]).unwrap().len() + } + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("listen_buffer"), + |builder, _config: ()| -> Node>, usize, ()> { + builder.create_node(listen_buffer.into_blocking_callback()) + }, + ) + .with_listen() + .with_common_response(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "num_output", + "ops": { + "buffer": { + "type": "buffer", + }, + "num_output": { + "type": "node", + "builder": "num_output", + "next": "buffer", + }, + "listen": { + "type": "listen", + "buffers": ["buffer"], + "target_node": "listen_buffer", + "next": "listen_buffer", + }, + "listen_buffer": { + "type": "node", + "builder": "listen_buffer", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 1); + } + + #[test] + fn test_json_buffer_listen() { + let mut fixture = 
new_fixture(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "buffer", + "ops": { + "buffer": { "type": "buffer" }, + "listen": { + "type": "listen", + "buffers": "buffer", + "next": "count", + }, + "count": { + "type": "node", + "builder": "listen_count_json_buffer_entries", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 1); + } + + #[test] + fn test_any_buffer_listen() { + let mut fixture = new_fixture(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "buffer", + "ops": { + "buffer": { "type": "buffer" }, + "listen": { + "type": "listen", + "buffers": "buffer", + "next": "count", + }, + "count": { + "type": "node", + "builder": "listen_count_any_buffer_entries", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 1); + } + + #[derive(Accessor, Clone)] + struct TestAccessor { + integer: BufferKey, + string: BufferKey, + json: JsonBufferKey, + any: AnyBufferKey, + } + + #[test] + fn test_struct_accessor_access() { + let mut fixture = new_fixture(); + + let input = JsonMessage::Object(serde_json::Map::from_iter([ + ("integer".to_owned(), JsonMessage::Number(5_i64.into())), + ("string".to_owned(), JsonMessage::String("hello".to_owned())), + ])); + + let expected = input.clone(); + + // TODO(@mxgrey): Replace this with a builtin trigger operation + fixture + .registry + .register_node_builder(NodeBuilderOptions::new("trigger"), |builder, _: ()| { + builder.create_map_block(|_: JsonMessage| ()) + }); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("check_for_all"), + move |builder, _config: ()| { + 
let expected = expected.clone(); + builder.create_node( + (move |In((_, keys)): In<((), TestAccessor)>, world: &mut World| { + wait_for_all(keys, world, &expected) + }) + .into_blocking_callback(), + ) + }, + ) + .with_buffer_access() + .with_fork_result() + .with_common_response(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork", + "ops": { + "fork": { + "type": "fork_clone", + "next": [ + "split", + "json_buffer", + "any_buffer", + "trigger", + ], + }, + "split": { + "type": "split", + "keyed": { + "integer": "push_integer", + "string": "push_string", + }, + }, + "push_integer": { + "type": "node", + "builder": "json_split_i64", + "next": "integer_buffer", + }, + "push_string": { + "type": "node", + "builder": "json_split_string", + "next": "string_buffer", + }, + "integer_buffer": { "type": "buffer" }, + "string_buffer": { "type": "buffer" }, + "json_buffer": { "type": "buffer" }, + "any_buffer": { "type": "buffer" }, + "trigger": { + "type": "node", + "builder": "trigger", + "next": "access", + }, + "access": { + "type": "buffer_access", + "buffers": { + "integer": "integer_buffer", + "string": "string_buffer", + "json": "json_buffer", + "any": "any_buffer", + }, + "next": "check_for_all" + }, + "check_for_all": { + "type": "node", + "builder": "check_for_all", + "next": "filter", + }, + "filter": { + "type": "fork_result", + "ok": { "builtin": "terminate" }, + "err": "access", + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_and_run(&diagram, input).unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, JsonMessage::Null); + } + + #[test] + fn test_struct_accessor_listen() { + let mut fixture = new_fixture(); + + let input = JsonMessage::Object(serde_json::Map::from_iter([ + ("integer".to_owned(), JsonMessage::Number(5_i64.into())), + ("string".to_owned(), JsonMessage::String("hello".to_owned())), + ])); + + let expected = input.clone(); + + fixture + .registry + .opt_out() + 
.no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("listen_for_all"), + move |builder, _config: ()| { + let expected = expected.clone(); + builder.create_node( + (move |In(keys): In, world: &mut World| { + wait_for_all(keys, world, &expected) + }) + .into_blocking_callback(), + ) + }, + ) + .with_listen() + .with_fork_result() + .with_common_response(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork", + "ops": { + "fork": { + "type": "fork_clone", + "next": [ + "split", + "json_buffer", + "any_buffer", + ], + }, + "split": { + "type": "split", + "keyed": { + "integer": "push_integer", + "string": "push_string", + }, + }, + "push_integer": { + "type": "node", + "builder": "json_split_i64", + "next": "integer_buffer", + }, + "push_string": { + "type": "node", + "builder": "json_split_string", + "next": "string_buffer", + }, + "integer_buffer": { "type": "buffer" }, + "string_buffer": { "type": "buffer" }, + "json_buffer": { "type": "buffer" }, + "any_buffer": { "type": "buffer" }, + "listen": { + "type": "listen", + "buffers": { + "integer": "integer_buffer", + "string": "string_buffer", + "json": "json_buffer", + "any": "any_buffer", + }, + "next": "listen_for_all" + }, + "listen_for_all": { + "type": "node", + "builder": "listen_for_all", + "next": "filter", + }, + "filter": { + "type": "fork_result", + "ok": { "builtin": "terminate" }, + "err": { "builtin": "dispose" }, + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_and_run(&diagram, input).unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, JsonMessage::Null); + } + + fn wait_for_all( + keys: TestAccessor, + world: &mut World, + expected: &JsonMessage, + ) -> Result<(), ()> { + if let Some(integer) = world.buffer_view(&keys.integer).unwrap().newest() { + assert_eq!(*integer, 5); + } else { + return Err(()); + } + + if let Some(string) = world.buffer_view(&keys.string).unwrap().newest() { + 
assert_eq!(string, "hello"); + } else { + return Err(()); + } + + if let Ok(Some(json)) = world.json_buffer_view(&keys.json).unwrap().newest() { + assert_eq!(&json, expected); + } else { + return Err(()); + } + + if let Some(any) = world.any_buffer_view(&keys.any).unwrap().newest() { + assert_eq!(any.downcast_ref::().unwrap(), expected); + } else { + return Err(()); + } + + Ok(()) + } +} diff --git a/src/diagram/fork_clone.rs b/src/diagram/fork_clone.rs deleted file mode 100644 index d6b8dc34..00000000 --- a/src/diagram/fork_clone.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::any::TypeId; - -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use tracing::debug; - -use crate::Builder; - -use super::{ - impls::{DefaultImpl, NotSupported}, - DiagramError, DynOutput, NextOperation, -}; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub struct ForkCloneOp { - pub(super) next: Vec, -} - -pub trait DynForkClone { - const CLONEABLE: bool; - - fn dyn_fork_clone( - builder: &mut Builder, - output: DynOutput, - amount: usize, - ) -> Result, DiagramError>; -} - -impl DynForkClone for NotSupported { - const CLONEABLE: bool = false; - - fn dyn_fork_clone( - _builder: &mut Builder, - _output: DynOutput, - _amount: usize, - ) -> Result, DiagramError> { - Err(DiagramError::NotCloneable) - } -} - -impl DynForkClone for DefaultImpl -where - T: Send + Sync + 'static + Clone, -{ - const CLONEABLE: bool = true; - - fn dyn_fork_clone( - builder: &mut Builder, - output: DynOutput, - amount: usize, - ) -> Result, DiagramError> { - debug!("fork clone: {:?}", output); - assert_eq!(output.type_id, TypeId::of::()); - - let fork_clone = output.into_output::()?.fork_clone(builder); - let outputs = (0..amount) - .map(|_| fork_clone.clone_output(builder).into()) - .collect(); - debug!("forked outputs: {:?}", outputs); - Ok(outputs) - } -} - -#[cfg(test)] -mod tests { - use serde_json::json; - use test_log::test; - - use 
crate::{diagram::testing::DiagramTestFixture, Diagram}; - - use super::*; - - #[test] - fn test_fork_clone_uncloneable() { - let mut fixture = DiagramTestFixture::new(); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": "op1", - "ops": { - "op1": { - "type": "node", - "builder": "multiply3_uncloneable", - "next": "fork_clone" - }, - "fork_clone": { - "type": "fork_clone", - "next": ["op2"] - }, - "op2": { - "type": "node", - "builder": "multiply3_uncloneable", - "next": { "builtin": "terminate" }, - }, - }, - })) - .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::NotCloneable), "{:?}", err); - } - - #[test] - fn test_fork_clone() { - let mut fixture = DiagramTestFixture::new(); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": "op1", - "ops": { - "op1": { - "type": "node", - "builder": "multiply3", - "next": "fork_clone" - }, - "fork_clone": { - "type": "fork_clone", - "next": ["op2"] - }, - "op2": { - "type": "node", - "builder": "multiply3", - "next": { "builtin": "terminate" }, - }, - }, - })) - .unwrap(); - - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) - .unwrap(); - assert_eq!(result, 36); - } -} diff --git a/src/diagram/fork_clone_schema.rs b/src/diagram/fork_clone_schema.rs new file mode 100644 index 00000000..3dda14d2 --- /dev/null +++ b/src/diagram/fork_clone_schema.rs @@ -0,0 +1,192 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{Builder, ForkCloneOutput}; + +use super::{ + supported::*, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, + DynInputSlot, DynOutput, NextOperation, OperationId, +}; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct ForkCloneSchema { + pub(super) next: Vec, +} + +impl BuildDiagramOperation for ForkCloneSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + // There are no outputs ready for this target, so we can't do + // anything yet. The builder should try again later. + return Ok(BuildStatus::defer("waiting for an input")); + }; + + let fork = ctx.registry.messages.fork_clone(inferred_type, builder)?; + ctx.set_input_for_target(id, fork.input)?; + for target in &self.next { + ctx.add_output_into_target(target.clone(), fork.outputs.clone_output(builder)); + } + + Ok(BuildStatus::Finished) + } +} + +pub trait PerformForkClone { + const CLONEABLE: bool; + + fn perform_fork_clone(builder: &mut Builder) -> Result; +} + +impl PerformForkClone for NotSupported { + const CLONEABLE: bool = false; + + fn perform_fork_clone(_builder: &mut Builder) -> Result { + Err(DiagramErrorCode::NotCloneable) + } +} + +impl PerformForkClone for Supported +where + T: Send + Sync + 'static + Clone, +{ + const CLONEABLE: bool = true; + + fn perform_fork_clone(builder: &mut Builder) -> Result { + let (input, outputs) = builder.create_fork_clone::(); + + Ok(DynForkClone { + input: input.into(), + outputs: DynForkCloneOutput::new(outputs), + }) + } +} + +pub struct DynForkClone { + pub input: DynInputSlot, + pub outputs: DynForkCloneOutput, +} + +pub struct 
DynForkCloneOutput { + inner: Box, +} + +impl DynForkCloneOutput { + pub fn new(inner: impl DynamicClone + 'static) -> Self { + Self { + inner: Box::new(inner), + } + } + + pub fn clone_output(&self, builder: &mut Builder) -> DynOutput { + self.inner.dyn_clone_output(builder) + } +} + +pub trait DynamicClone { + fn dyn_clone_output(&self, builder: &mut Builder) -> DynOutput; +} + +impl DynamicClone for ForkCloneOutput { + fn dyn_clone_output(&self, builder: &mut Builder) -> DynOutput { + self.clone_output(builder).into() + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Diagram}; + + use super::*; + + #[test] + fn test_fork_clone_uncloneable() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": "fork_clone" + }, + "fork_clone": { + "type": "fork_clone", + "next": ["op2"] + }, + "op2": { + "type": "node", + "builder": "multiply3_uncloneable", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!(err.code, DiagramErrorCode::NotCloneable), + "{:?}", + err + ); + } + + #[test] + fn test_fork_clone() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3", + "next": "fork_clone" + }, + "fork_clone": { + "type": "fork_clone", + "next": ["op2"] + }, + "op2": { + "type": "node", + "builder": "multiply3", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, 36); + } +} diff --git a/src/diagram/fork_result.rs 
b/src/diagram/fork_result.rs deleted file mode 100644 index 6decaddf..00000000 --- a/src/diagram/fork_result.rs +++ /dev/null @@ -1,133 +0,0 @@ -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use tracing::debug; - -use crate::Builder; - -use super::{ - impls::{DefaultImpl, NotSupported}, - DiagramError, DynOutput, NextOperation, -}; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub struct ForkResultOp { - pub(super) ok: NextOperation, - pub(super) err: NextOperation, -} - -pub trait DynForkResult { - const SUPPORTED: bool; - - fn dyn_fork_result( - builder: &mut Builder, - output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError>; -} - -impl DynForkResult for NotSupported { - const SUPPORTED: bool = false; - - fn dyn_fork_result( - _builder: &mut Builder, - _output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError> { - Err(DiagramError::CannotForkResult) - } -} - -impl DynForkResult> for DefaultImpl -where - T: Send + Sync + 'static, - E: Send + Sync + 'static, -{ - const SUPPORTED: bool = true; - - fn dyn_fork_result( - builder: &mut Builder, - output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError> { - debug!("fork result: {:?}", output); - - let chain = output.into_output::>()?.chain(builder); - let outputs = chain.fork_result(|c| c.output().into(), |c| c.output().into()); - debug!("forked outputs: {:?}", outputs); - Ok(outputs) - } -} - -#[cfg(test)] -mod tests { - use serde_json::json; - use test_log::test; - - use crate::{diagram::testing::DiagramTestFixture, Builder, Diagram, NodeBuilderOptions}; - - #[test] - fn test_fork_result() { - let mut fixture = DiagramTestFixture::new(); - - fn check_even(v: i64) -> Result { - if v % 2 == 0 { - Ok("even".to_string()) - } else { - Err("odd".to_string()) - } - } - - fixture - .registry - .register_node_builder( - NodeBuilderOptions::new("check_even".to_string()), - |builder: &mut Builder, _config: ()| 
builder.create_map_block(&check_even), - ) - .with_fork_result(); - - fn echo(s: String) -> String { - s - } - - fixture.registry.register_node_builder( - NodeBuilderOptions::new("echo".to_string()), - |builder: &mut Builder, _config: ()| builder.create_map_block(&echo), - ); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": "op1", - "ops": { - "op1": { - "type": "node", - "builder": "check_even", - "next": "fork_result", - }, - "fork_result": { - "type": "fork_result", - "ok": "op2", - "err": "op3", - }, - "op2": { - "type": "node", - "builder": "echo", - "next": { "builtin": "terminate" }, - }, - "op3": { - "type": "node", - "builder": "echo", - "next": { "builtin": "terminate" }, - }, - }, - })) - .unwrap(); - - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) - .unwrap(); - assert_eq!(result, "even"); - - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(3)) - .unwrap(); - assert_eq!(result, "odd"); - } -} diff --git a/src/diagram/fork_result_schema.rs b/src/diagram/fork_result_schema.rs new file mode 100644 index 00000000..3d6c975d --- /dev/null +++ b/src/diagram/fork_result_schema.rs @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::Builder; + +use super::{ + supported::*, type_info::TypeInfo, BuildDiagramOperation, BuildStatus, DiagramContext, + DiagramErrorCode, DynInputSlot, DynOutput, MessageRegistration, MessageRegistry, NextOperation, + OperationId, PerformForkClone, SerializeMessage, +}; + +pub struct DynForkResult { + pub input: DynInputSlot, + pub ok: DynOutput, + pub err: DynOutput, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct ForkResultSchema { + pub(super) ok: NextOperation, + pub(super) err: NextOperation, +} + +impl BuildDiagramOperation for ForkResultSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + // There are no outputs ready for this target, so we can't do + // anything yet. The builder should try again later. 
+ return Ok(BuildStatus::defer("waiting for an input")); + }; + + let fork = ctx.registry.messages.fork_result(inferred_type, builder)?; + ctx.set_input_for_target(id, fork.input)?; + ctx.add_output_into_target(self.ok.clone(), fork.ok); + ctx.add_output_into_target(self.err.clone(), fork.err); + Ok(BuildStatus::Finished) + } +} + +pub trait RegisterForkResult { + fn on_register(registry: &mut MessageRegistry) -> bool; +} + +impl RegisterForkResult for Supported<(Result, S, C)> +where + T: Send + Sync + 'static, + E: Send + Sync + 'static, + S: SerializeMessage + SerializeMessage, + C: PerformForkClone + PerformForkClone, +{ + fn on_register(registry: &mut MessageRegistry) -> bool { + let ops = &mut registry + .messages + .entry(TypeInfo::of::>()) + .or_insert(MessageRegistration::new::()) + .operations; + if ops.fork_result_impl.is_some() { + return false; + } + + ops.fork_result_impl = Some(|builder| { + let (input, outputs) = builder.create_fork_result::(); + Ok(DynForkResult { + input: input.into(), + ok: outputs.ok.into(), + err: outputs.err.into(), + }) + }); + + registry.register_serialize::(); + registry.register_fork_clone::(); + + registry.register_serialize::(); + registry.register_fork_clone::(); + + true + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use test_log::test; + + use crate::{diagram::testing::DiagramTestFixture, Builder, Diagram, NodeBuilderOptions}; + + #[test] + fn test_fork_result() { + let mut fixture = DiagramTestFixture::new(); + + fn check_even(v: i64) -> Result { + if v % 2 == 0 { + Ok("even".to_string()) + } else { + Err("odd".to_string()) + } + } + + fixture + .registry + .register_node_builder( + NodeBuilderOptions::new("check_even".to_string()), + |builder: &mut Builder, _config: ()| builder.create_map_block(&check_even), + ) + .with_fork_result(); + + fn echo(s: String) -> String { + s + } + + fixture.registry.register_node_builder( + NodeBuilderOptions::new("echo".to_string()), + |builder: &mut Builder, _config: 
()| builder.create_map_block(&echo), + ); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "check_even", + "next": "fork_result", + }, + "fork_result": { + "type": "fork_result", + "ok": "op2", + "err": "op3", + }, + "op2": { + "type": "node", + "builder": "echo", + "next": { "builtin": "terminate" }, + }, + "op3": { + "type": "node", + "builder": "echo", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(4)) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, "even"); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::from(3)) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, "odd"); + } +} diff --git a/src/diagram/impls.rs b/src/diagram/impls.rs deleted file mode 100644 index 9412fbb5..00000000 --- a/src/diagram/impls.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::marker::PhantomData; - -/// A struct to provide the default implementation for various operations. -pub struct DefaultImpl; - -/// A struct to provide the default implementation for various operations. -pub struct DefaultImplMarker { - _unused: PhantomData, -} - -impl DefaultImplMarker { - pub(super) fn new() -> Self { - Self { - _unused: Default::default(), - } - } -} - -/// A struct to provide "not supported" implementations for various operations. -pub struct NotSupported; - -/// A struct to provide "not supported" implementations for various operations. 
-pub struct NotSupportedMarker { - _unused: PhantomData, -} diff --git a/src/diagram/join.rs b/src/diagram/join.rs deleted file mode 100644 index cdba36be..00000000 --- a/src/diagram/join.rs +++ /dev/null @@ -1,306 +0,0 @@ -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use smallvec::SmallVec; -use tracing::debug; - -use crate::{Builder, IterBufferable, Output}; - -use super::{ - DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, SourceOperation, -}; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub struct JoinOp { - pub(super) next: NextOperation, - - /// Controls the order of the resulting join. Each item must be an operation id of one of the - /// incoming outputs. - pub(super) inputs: Vec, - - /// Do not serialize before performing the join. If true, joins can only be done - /// on outputs of the same type. - pub(super) no_serialize: Option, -} - -pub(super) fn register_join_impl(registry: &mut MessageRegistry) -where - T: Send + Sync + 'static, - Serializer: SerializeMessage>, -{ - registry.register_join::(Box::new(join_impl::)); -} - -/// Serialize the outputs before joining them, and convert the resulting joined output into a -/// [`serde_json::Value`]. -pub(super) fn serialize_and_join( - builder: &mut Builder, - registry: &MessageRegistry, - outputs: Vec, -) -> Result, DiagramError> { - debug!("serialize and join outputs {:?}", outputs); - - if outputs.is_empty() { - // do not allow empty joins - return Err(DiagramError::EmptyJoin); - } - - let outputs = outputs - .into_iter() - .map(|o| registry.serialize(builder, o)) - .collect::, DiagramError>>()?; - - // we need to convert the joined output to [`serde_json::Value`] in order for it to be - // serializable. 
- let joined_output = outputs.join_vec::<4>(builder).output(); - let json_output = joined_output - .chain(builder) - .map_block(|o| serde_json::to_value(o)) - .cancel_on_err() - .output(); - Ok(json_output) -} - -fn join_impl(builder: &mut Builder, outputs: Vec) -> Result -where - T: Send + Sync + 'static, -{ - debug!("join outputs {:?}", outputs); - - if outputs.is_empty() { - // do a empty join, in practice, this branch is never ran because [`WorkflowBuilder`] - // should error out if there is an empty join. - return Err(DiagramError::EmptyJoin); - } - - let first_type = outputs[0].type_id; - - let outputs = outputs - .into_iter() - .map(|o| { - if o.type_id != first_type { - Err(DiagramError::TypeMismatch) - } else { - Ok(o.into_output::()?) - } - }) - .collect::, _>>()?; - - // we don't know the number of items at compile time, so we just use a sensible number. - // NOTE: Be sure to update `JoinOutput` if this changes. - Ok(outputs.join_vec::<4>(builder).output().into()) -} - -/// The resulting type of a `join` operation. Nodes receiving a join output must have request -/// of this type. Note that the join output is NOT serializable. If you would like to serialize it, -/// convert it to a `Vec` first. 
-pub type JoinOutput = SmallVec<[T; 4]>; - -#[cfg(test)] -mod tests { - use serde_json::json; - use test_log::test; - - use super::*; - use crate::{ - diagram::testing::DiagramTestFixture, Diagram, DiagramError, JsonPosition, - NodeBuilderOptions, - }; - - #[test] - fn test_join() { - let mut fixture = DiagramTestFixture::new(); - - fn get_split_value(pair: (JsonPosition, serde_json::Value)) -> serde_json::Value { - pair.1 - } - - fixture.registry.register_node_builder( - NodeBuilderOptions::new("get_split_value".to_string()), - |builder, _config: ()| builder.create_map_block(get_split_value), - ); - - fn serialize_join_output(join_output: JoinOutput) -> serde_json::Value { - serde_json::to_value(join_output).unwrap() - } - - fixture - .registry - .opt_out() - .no_request_deserializing() - .register_node_builder( - NodeBuilderOptions::new("serialize_join_output".to_string()), - |builder, _config: ()| builder.create_map_block(serialize_join_output), - ); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": "split", - "ops": { - "split": { - "type": "split", - "sequential": ["get_split_value1", "get_split_value2"] - }, - "get_split_value1": { - "type": "node", - "builder": "get_split_value", - "next": "op1", - }, - "op1": { - "type": "node", - "builder": "multiply3", - "next": "join", - }, - "get_split_value2": { - "type": "node", - "builder": "get_split_value", - "next": "op2", - }, - "op2": { - "type": "node", - "builder": "multiply3", - "next": "join", - }, - "join": { - "type": "join", - "inputs": ["op1", "op2"], - "next": "serialize_join_output", - "no_serialize": true, - }, - "serialize_join_output": { - "type": "node", - "builder": "serialize_join_output", - "next": { "builtin": "terminate" }, - }, - } - })) - .unwrap(); - - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from([1, 2])) - .unwrap(); - assert_eq!(result.as_array().unwrap().len(), 2); - assert_eq!(result[0], 3); - assert_eq!(result[1], 6); - } - - /// 
This test is to ensure that the order of split and join operations are stable. - #[test] - fn test_join_stress() { - for _ in 1..20 { - test_join(); - } - } - - #[test] - fn test_empty_join() { - let mut fixture = DiagramTestFixture::new(); - - fn get_split_value(pair: (JsonPosition, serde_json::Value)) -> serde_json::Value { - pair.1 - } - - fixture.registry.register_node_builder( - NodeBuilderOptions::new("get_split_value".to_string()), - |builder, _config: ()| builder.create_map_block(get_split_value), - ); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": "split", - "ops": { - "split": { - "type": "split", - "sequential": ["get_split_value1", "get_split_value2"] - }, - "get_split_value1": { - "type": "node", - "builder": "get_split_value", - "next": "op1", - }, - "op1": { - "type": "node", - "builder": "multiply3", - "next": { "builtin": "terminate" }, - }, - "get_split_value2": { - "type": "node", - "builder": "get_split_value", - "next": "op2", - }, - "op2": { - "type": "node", - "builder": "multiply3", - "next": { "builtin": "terminate" }, - }, - "join": { - "type": "join", - "inputs": [], - "next": { "builtin": "terminate" }, - "no_serialize": true, - }, - } - })) - .unwrap(); - - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::EmptyJoin)); - } - - #[test] - fn test_serialize_and_join() { - let mut fixture = DiagramTestFixture::new(); - - fn num_output(_: serde_json::Value) -> i64 { - 1 - } - - fixture.registry.register_node_builder( - NodeBuilderOptions::new("num_output".to_string()), - |builder, _config: ()| builder.create_map_block(num_output), - ); - - fn string_output(_: serde_json::Value) -> String { - "hello".to_string() - } - - fixture.registry.register_node_builder( - NodeBuilderOptions::new("string_output".to_string()), - |builder, _config: ()| builder.create_map_block(string_output), - ); - - let diagram = Diagram::from_json(json!({ - "version": "0.1.0", - "start": 
"fork_clone", - "ops": { - "fork_clone": { - "type": "fork_clone", - "next": ["op1", "op2"] - }, - "op1": { - "type": "node", - "builder": "num_output", - "next": "join", - }, - "op2": { - "type": "node", - "builder": "string_output", - "next": "join", - }, - "join": { - "type": "join", - "inputs": ["op1", "op2"], - "next": { "builtin": "terminate" }, - }, - } - })) - .unwrap(); - - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); - assert_eq!(result.as_array().unwrap().len(), 2); - assert_eq!(result[0], 1); - assert_eq!(result[1], "hello"); - } -} diff --git a/src/diagram/join_schema.rs b/src/diagram/join_schema.rs new file mode 100644 index 00000000..d312cd70 --- /dev/null +++ b/src/diagram/join_schema.rs @@ -0,0 +1,535 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; + +use crate::{Builder, JsonMessage}; + +use super::{ + BufferInputs, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, + NextOperation, OperationId, +}; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct JoinSchema { + pub(super) next: NextOperation, + + /// Map of buffer keys and buffers. + pub(super) buffers: BufferInputs, + + /// The id of an operation that this operation is for. The id must be a `node` operation. 
Optional if `next` is a node operation. + pub(super) target_node: Option, +} + +impl BuildDiagramOperation for JoinSchema { + fn build_diagram_operation( + &self, + _: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + if self.buffers.is_empty() { + return Err(DiagramErrorCode::EmptyJoin); + } + + let buffer_map = match ctx.create_buffer_map(&self.buffers) { + Ok(buffer_map) => buffer_map, + Err(reason) => return Ok(BuildStatus::defer(reason)), + }; + + let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; + + let output = ctx + .registry + .messages + .join(&target_type, &buffer_map, builder)?; + ctx.add_output_into_target(self.next.clone(), output); + Ok(BuildStatus::Finished) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct SerializedJoinSchema { + pub(super) next: NextOperation, + + /// Map of buffer keys and buffers. + pub(super) buffers: BufferInputs, +} + +impl BuildDiagramOperation for SerializedJoinSchema { + fn build_diagram_operation( + &self, + _: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + if self.buffers.is_empty() { + return Err(DiagramErrorCode::EmptyJoin); + } + + let buffer_map = match ctx.create_buffer_map(&self.buffers) { + Ok(buffer_map) => buffer_map, + Err(reason) => return Ok(BuildStatus::defer(reason)), + }; + + let output = builder.try_join::(&buffer_map)?.output(); + ctx.add_output_into_target(self.next.clone(), output.into()); + + Ok(BuildStatus::Finished) + } +} + +/// The resulting type of a `join` operation. Nodes receiving a join output must have request +/// of this type. Note that the join output is NOT serializable. If you would like to serialize it, +/// convert it to a `Vec` first. 
+pub type JoinOutput = SmallVec<[T; 4]>; + +#[cfg(test)] +mod tests { + use bevy_impulse_derive::Joined; + use serde_json::json; + use test_log::test; + + use super::*; + use crate::{ + diagram::testing::DiagramTestFixture, Diagram, DiagramElementRegistry, DiagramError, + DiagramErrorCode, NodeBuilderOptions, + }; + + fn foo(_: serde_json::Value) -> String { + "foo".to_string() + } + + fn bar(_: serde_json::Value) -> String { + "bar".to_string() + } + + #[derive(Serialize, Deserialize, JsonSchema, Joined)] + struct FooBar { + foo: String, + bar: String, + } + + impl Default for FooBar { + fn default() -> Self { + FooBar { + foo: "foo".to_owned(), + bar: "bar".to_owned(), + } + } + } + + fn foobar(foobar: FooBar) -> String { + format!("{}{}", foobar.foo, foobar.bar) + } + + fn foobar_array(foobar: Vec) -> String { + format!("{}{}", foobar[0], foobar[1]) + } + + fn register_join_nodes(registry: &mut DiagramElementRegistry) { + registry.register_node_builder(NodeBuilderOptions::new("foo"), |builder, _config: ()| { + builder.create_map_block(foo) + }); + registry.register_node_builder(NodeBuilderOptions::new("bar"), |builder, _config: ()| { + builder.create_map_block(bar) + }); + registry + .opt_out() + .no_cloning() + .register_node_builder(NodeBuilderOptions::new("foobar"), |builder, _config: ()| { + builder.create_map_block(foobar) + }) + .with_join(); + registry + .register_node_builder( + NodeBuilderOptions::new("foobar_array"), + |builder, _config: ()| builder.create_map_block(foobar_array), + ) + .with_join(); + registry.opt_out().no_cloning().register_node_builder( + NodeBuilderOptions::new("create_foobar"), + |builder, config: FooBar| { + builder.create_map_block(move |_: JsonMessage| FooBar { + foo: config.foo.clone(), + bar: config.bar.clone(), + }) + }, + ); + } + + #[test] + fn test_join() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + 
"start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["foo", "bar"], + }, + "foo": { + "type": "node", + "builder": "foo", + "next": "foo_buffer", + }, + "foo_buffer": { + "type": "buffer", + }, + "bar": { + "type": "node", + "builder": "bar", + "next": "bar_buffer", + }, + "bar_buffer": { + "type": "buffer", + }, + "join": { + "type": "join", + "buffers": { + "foo": "foo_buffer", + "bar": "bar_buffer", + }, + "target_node": "foobar", + "next": "foobar", + }, + "foobar": { + "type": "node", + "builder": "foobar", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, "foobar"); + } + + #[test] + /// similar to `test_join`, except the `target_node` field is not provided and the target type is inferred from `next`. + fn test_join_infer_type() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["foo", "bar"], + }, + "foo": { + "type": "node", + "builder": "foo", + "next": "foo_buffer", + }, + "foo_buffer": { + "type": "buffer", + }, + "bar": { + "type": "node", + "builder": "bar", + "next": "bar_buffer", + }, + "bar_buffer": { + "type": "buffer", + }, + "join": { + "type": "join", + "buffers": { + "foo": "foo_buffer", + "bar": "bar_buffer", + }, + "next": "foobar", + }, + "foobar": { + "type": "node", + "builder": "foobar", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result, "foobar"); + } + + #[test] + /// when `target_node` is not given and next is not a node + fn test_join_infer_type_fail() { + 
let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["foo", "bar"], + }, + "foo": { + "type": "node", + "builder": "foo", + "next": "foo_buffer", + }, + "foo_buffer": { + "type": "buffer", + }, + "bar": { + "type": "node", + "builder": "bar", + "next": "bar_buffer", + }, + "bar_buffer": { + "type": "buffer", + }, + "join": { + "type": "join", + "buffers": { + "foo": "foo_buffer", + "bar": "bar_buffer", + }, + "next": "fork_clone2", + }, + "fork_clone2": { + "type": "fork_clone", + "next": [{ "builtin": "terminate" }], + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap_err(); + assert!(fixture.context.no_unhandled_errors()); + let err_code = &result.downcast_ref::().unwrap().code; + assert!(matches!(err_code, DiagramErrorCode::UnknownTarget,)); + } + + #[test] + fn test_join_buffer_array() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["foo", "bar"], + }, + "foo": { + "type": "node", + "builder": "foo", + "next": "foo_buffer", + }, + "foo_buffer": { + "type": "buffer", + }, + "bar": { + "type": "node", + "builder": "bar", + "next": "bar_buffer", + }, + "bar_buffer": { + "type": "buffer", + }, + "join": { + "type": "join", + "buffers": ["foo_buffer", "bar_buffer"], + "target_node": "foobar_array", + "next": "foobar_array", + }, + "foobar_array": { + "type": "node", + "builder": "foobar_array", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + 
assert_eq!(result, "foobar"); + } + + #[test] + fn test_empty_join() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "foo", + "ops": { + "foo": { + "type": "node", + "builder": "foo", + "next": { "builtin": "terminate" }, + }, + "join": { + "type": "join", + "buffers": [], + "target_node": "foobar", + "next": "foobar", + }, + "foobar": { + "type": "node", + "builder": "foobar", + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!(matches!(err.code, DiagramErrorCode::EmptyJoin)); + } + + #[test] + fn test_serialized_join() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": ["foo", "bar"], + }, + "foo": { + "type": "node", + "builder": "foo", + "next": "foo_buffer", + }, + "foo_buffer": { + "type": "buffer", + "serialize": true, + }, + "bar": { + "type": "node", + "builder": "bar", + "next": "bar_buffer", + }, + "bar_buffer": { + "type": "buffer", + "serialize": true, + }, + "serialized_join": { + "type": "serialized_join", + "buffers": { + "foo": "foo_buffer", + "bar": "bar_buffer", + }, + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + assert_eq!(result["foo"], "foo"); + assert_eq!(result["bar"], "bar"); + } + + #[test] + fn test_serialized_join_with_unserialized_buffers() { + let mut fixture = DiagramTestFixture::new(); + register_join_nodes(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": 
"fork_clone", + "next": ["create_foobar_1", "create_foobar_2"], + }, + "create_foobar_1": { + "type": "node", + "builder": "create_foobar", + "config": { + "foo": "foo_1", + "bar": "bar_1", + }, + "next": "foobar_buffer_1", + }, + "create_foobar_2": { + "type": "node", + "builder": "create_foobar", + "config": { + "foo": "foo_2", + "bar": "bar_2", + }, + "next": "foobar_buffer_2", + }, + "foobar_buffer_1": { + "type": "buffer", + }, + "foobar_buffer_2": { + "type": "buffer", + }, + "serialized_join": { + "type": "serialized_join", + "buffers": { + "foobar_1": "foobar_buffer_1", + "foobar_2": "foobar_buffer_2", + }, + "next": { "builtin": "terminate" }, + }, + } + })) + .unwrap(); + + let result = fixture + .spawn_and_run(&diagram, serde_json::Value::Null) + .unwrap(); + assert!(fixture.context.no_unhandled_errors()); + let object = result.as_object().unwrap(); + assert_eq!(object["foobar_1"].as_object().unwrap()["foo"], "foo_1"); + assert_eq!(object["foobar_1"].as_object().unwrap()["bar"], "bar_1"); + assert_eq!(object["foobar_2"].as_object().unwrap()["foo"], "foo_2"); + assert_eq!(object["foobar_2"].as_object().unwrap()["bar"], "bar_2"); + } +} diff --git a/src/diagram/node_schema.rs b/src/diagram/node_schema.rs new file mode 100644 index 00000000..c45d6598 --- /dev/null +++ b/src/diagram/node_schema.rs @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::Builder; + +use super::{ + BuildDiagramOperation, BuildStatus, BuilderId, DiagramContext, DiagramErrorCode, NextOperation, + OperationId, +}; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct NodeSchema { + pub(super) builder: BuilderId, + #[serde(default)] + pub(super) config: serde_json::Value, + pub(super) next: NextOperation, +} + +impl BuildDiagramOperation for NodeSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let node_registration = ctx.registry.get_node_registration(&self.builder)?; + let node = node_registration.create_node(builder, self.config.clone())?; + + ctx.set_input_for_target(id, node.input.into())?; + ctx.add_output_into_target(self.next.clone(), node.output); + Ok(BuildStatus::Finished) + } +} diff --git a/src/diagram/registration.rs b/src/diagram/registration.rs index d0ce66e9..7859d92f 100644 --- a/src/diagram/registration.rs +++ b/src/diagram/registration.rs @@ -1,5 +1,22 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + use std::{ - any::{type_name, Any, TypeId}, + any::{type_name, Any}, borrow::Borrow, cell::RefCell, collections::HashMap, @@ -7,7 +24,10 @@ use std::{ marker::PhantomData, }; -use crate::{Builder, InputSlot, Node, Output, StreamPack}; +use crate::{ + Accessor, AnyBuffer, AsAnyBuffer, BufferMap, BufferSettings, Builder, Connect, InputSlot, + Joined, JsonBuffer, JsonMessage, Node, Output, StreamPack, +}; use bevy_ecs::entity::Entity; use schemars::{ gen::{SchemaGenerator, SchemaSettings}, @@ -22,16 +42,12 @@ use serde::{ use serde_json::json; use tracing::debug; -use crate::SerializeMessage; - use super::{ - fork_clone::DynForkClone, - fork_result::DynForkResult, - impls::{DefaultImpl, DefaultImplMarker, NotSupported}, - join::register_join_impl, - unzip::DynUnzip, - BuilderId, DefaultDeserializer, DefaultSerializer, DeserializeMessage, DiagramError, DynSplit, - DynSplitOutputs, DynType, OpaqueMessageDeserializer, OpaqueMessageSerializer, SplitOp, + buffer_schema::BufferAccessRequest, fork_clone_schema::PerformForkClone, + fork_result_schema::RegisterForkResult, register_json, supported::*, type_info::TypeInfo, + unzip_schema::PerformUnzip, BuilderId, DeserializeMessage, DiagramErrorCode, DynForkClone, + DynForkResult, DynSplit, DynType, JsonRegistration, RegisterJson, RegisterSplit, + SerializeMessage, SplitSchema, TransformError, }; /// A type erased [`crate::InputSlot`] @@ -39,17 +55,21 @@ use super::{ pub struct DynInputSlot { scope: Entity, source: Entity, - pub(super) type_id: TypeId, + type_info: TypeInfo, } impl DynInputSlot { - pub(super) fn scope(&self) -> Entity { + pub fn scope(&self) -> Entity { self.scope } - pub(super) fn id(&self) -> Entity { + pub fn id(&self) -> Entity { self.source } + + pub fn message_info(&self) -> &TypeInfo { + &self.type_info + } } impl From> for DynInputSlot { @@ -57,57 +77,115 @@ impl From> for DynInputSlot { Self { scope: input.scope(), source: input.id(), - type_id: TypeId::of::(), + type_info: TypeInfo::of::(), + 
} + } +} + +impl From for DynInputSlot { + fn from(buffer: AnyBuffer) -> Self { + let any_interface = buffer.get_interface(); + Self { + scope: buffer.scope(), + source: buffer.id(), + type_info: TypeInfo { + type_id: any_interface.message_type_id(), + type_name: any_interface.message_type_name(), + }, } } } -#[derive(Debug)] /// A type erased [`crate::Output`] pub struct DynOutput { scope: Entity, target: Entity, - pub(super) type_id: TypeId, + message_info: TypeInfo, } impl DynOutput { - pub(super) fn into_output(self) -> Result, DiagramError> + pub fn new(scope: Entity, target: Entity, message_info: TypeInfo) -> Self { + Self { + scope, + target, + message_info, + } + } + + pub fn message_info(&self) -> &TypeInfo { + &self.message_info + } + + pub fn into_output(self) -> Result, DiagramErrorCode> where T: Send + Sync + 'static + Any, { - if self.type_id != TypeId::of::() { - Err(DiagramError::TypeMismatch) + if self.message_info != TypeInfo::of::() { + Err(DiagramErrorCode::TypeMismatch { + source_type: self.message_info, + target_type: TypeInfo::of::(), + }) } else { Ok(Output::::new(self.scope, self.target)) } } - pub(super) fn scope(&self) -> Entity { + pub fn scope(&self) -> Entity { self.scope } - pub(super) fn id(&self) -> Entity { + pub fn id(&self) -> Entity { self.target } + + /// Connect a [`DynOutput`] to a [`DynInputSlot`]. 
+ pub fn connect_to( + self, + input: &DynInputSlot, + builder: &mut Builder, + ) -> Result<(), DiagramErrorCode> { + if self.message_info() != input.message_info() { + return Err(DiagramErrorCode::TypeMismatch { + source_type: *self.message_info(), + target_type: *input.message_info(), + }); + } + + builder.commands().add(Connect { + original_target: self.id(), + new_target: input.id(), + }); + + Ok(()) + } +} + +impl Debug for DynOutput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DynOutput") + .field("scope", &self.scope) + .field("target", &self.target) + .field("type_info", &self.message_info) + .finish() + } } impl From> for DynOutput where - T: Send + Sync + 'static, + T: Send + Sync + 'static + Any, { fn from(output: Output) -> Self { Self { scope: output.scope(), target: output.id(), - type_id: TypeId::of::(), + message_info: TypeInfo::of::(), } } } - /// A type erased [`bevy_impulse::Node`] -pub(super) struct DynNode { - pub(super) input: DynInputSlot, - pub(super) output: DynOutput, +pub struct DynNode { + pub input: DynInputSlot, + pub output: DynOutput, } impl DynNode { @@ -141,10 +219,8 @@ where pub struct NodeRegistration { pub(super) id: BuilderId, pub(super) name: String, - /// type name of the request - pub(super) request: &'static str, - /// type name of the response - pub(super) response: &'static str, + pub(super) request: TypeInfo, + pub(super) response: TypeInfo, pub(super) config_schema: Schema, /// Creates an instance of the registered node. 
@@ -157,7 +233,7 @@ impl NodeRegistration { &self, builder: &mut Builder, config: serde_json::Value, - ) -> Result { + ) -> Result { let n = (self.create_node_impl.borrow_mut())(builder, config)?; debug!( "created node of {}, output: {:?}, input: {:?}", @@ -168,23 +244,18 @@ impl NodeRegistration { } type CreateNodeFn = - RefCell Result>>; -type DeserializeFn = - Box) -> Result>; -type SerializeFn = - Box Result, DiagramError>>; -type ForkCloneFn = - Box Result, DiagramError>>; -type ForkResultFn = - Box Result<(DynOutput, DynOutput), DiagramError>>; -type SplitFn = Box< - dyn for<'a> Fn( - &mut Builder, - DynOutput, - &'a SplitOp, - ) -> Result, DiagramError>, ->; -type JoinFn = Box) -> Result>; + RefCell Result>>; +type DeserializeFn = fn(&mut Builder) -> Result; +type SerializeFn = fn(&mut Builder) -> Result; +type ForkCloneFn = fn(&mut Builder) -> Result; +type ForkResultFn = fn(&mut Builder) -> Result; +type SplitFn = fn(&SplitSchema, &mut Builder) -> Result; +type JoinFn = fn(&BufferMap, &mut Builder) -> Result; +type BufferAccessFn = fn(&BufferMap, &mut Builder) -> Result; +type ListenFn = fn(&BufferMap, &mut Builder) -> Result; +type CreateBufferFn = fn(BufferSettings, &mut Builder) -> AnyBuffer; +type CreateTriggerFn = fn(&mut Builder) -> DynNode; +type ToStringFn = fn(&mut Builder) -> DynNode; #[must_use] pub struct CommonOperations<'a, Deserialize, Serialize, Cloneable> { @@ -203,7 +274,7 @@ impl<'a, DeserializeImpl, SerializeImpl, Cloneable> /// * `name` - Friendly name for the builder, this is only used for display purposes. /// * `f` - The node builder to register. 
pub fn register_node_builder( - self, + mut self, options: NodeBuilderOptions, mut f: impl FnMut(&mut Builder, Config) -> Node + 'static, ) -> NodeRegistrationBuilder<'a, Request, Response, Streams> @@ -213,24 +284,22 @@ impl<'a, DeserializeImpl, SerializeImpl, Cloneable> Response: Send + Sync + 'static, Streams: StreamPack, DeserializeImpl: DeserializeMessage, + DeserializeImpl: DeserializeMessage, + SerializeImpl: SerializeMessage, SerializeImpl: SerializeMessage, - Cloneable: DynForkClone, + Cloneable: PerformForkClone, + Cloneable: PerformForkClone, + JsonRegistration: RegisterJson, + JsonRegistration: RegisterJson, { - self.registry - .messages - .register_deserialize::(); - self.registry - .messages - .register_serialize::(); - self.registry - .messages - .register_fork_clone::(); + self.impl_register_message::(); + self.impl_register_message::(); let registration = NodeRegistration { id: options.id.clone(), name: options.name.unwrap_or(options.id.clone()), - request: type_name::(), - response: type_name::(), + request: TypeInfo::of::(), + response: TypeInfo::of::(), config_schema: self .registry .messages @@ -244,16 +313,29 @@ impl<'a, DeserializeImpl, SerializeImpl, Cloneable> }; self.registry.nodes.insert(options.id.clone(), registration); - NodeRegistrationBuilder::::new(&mut self.registry.messages) + NodeRegistrationBuilder::::new(self.registry) } /// Register a message with the specified common operations. 
- pub fn register_message(self) -> MessageRegistrationBuilder<'a, Message> + pub fn register_message(mut self) -> MessageRegistrationBuilder<'a, Message> + where + Message: Send + Sync + 'static, + DeserializeImpl: DeserializeMessage, + SerializeImpl: SerializeMessage, + Cloneable: PerformForkClone, + JsonRegistration: RegisterJson, + { + self.impl_register_message(); + MessageRegistrationBuilder::::new(&mut self.registry.messages) + } + + fn impl_register_message(&mut self) where Message: Send + Sync + 'static, DeserializeImpl: DeserializeMessage, - SerializeImpl: SerializeMessage + SerializeMessage>, - Cloneable: DynForkClone, + SerializeImpl: SerializeMessage, + Cloneable: PerformForkClone, + JsonRegistration: RegisterJson, { self.registry .messages @@ -264,38 +346,46 @@ impl<'a, DeserializeImpl, SerializeImpl, Cloneable> self.registry .messages .register_fork_clone::(); - register_join_impl::(&mut self.registry.messages); - MessageRegistrationBuilder::::new(&mut self.registry.messages) + register_json::(); } - /// Opt out of deserializing the request of the node. Use this to build a - /// node whose request type is not deserializable. - pub fn no_request_deserializing( - self, - ) -> CommonOperations<'a, OpaqueMessageDeserializer, SerializeImpl, Cloneable> { + /// Opt out of deserializing the input and output messages of the node. + /// + /// If you want to enable deserializing for only the input or only the output + /// then use [`DiagramElementRegistry::register_message`] on the message type + /// directly. + /// + /// Note that [`JsonBuffer`] is only enabled for message types that enable + /// both serializing AND deserializing. + pub fn no_deserializing(self) -> CommonOperations<'a, NotSupported, SerializeImpl, Cloneable> { CommonOperations { registry: self.registry, _ignore: Default::default(), } } - /// Opt out of serializing the response of the node. Use this to build a - /// node whose response type is not serializable. 
- pub fn no_response_serializing( - self, - ) -> CommonOperations<'a, DeserializeImpl, OpaqueMessageSerializer, Cloneable> { + /// Opt out of serializing the input and output messages of the node. + /// + /// If you want to enable serialization for only the input or only the output + /// then use [`DiagramElementRegistry::register_message`] on the message type + /// directly. + /// + /// Note that [`JsonBuffer`] is only enabled for message types that enable + /// both serializing AND deserializing. + pub fn no_serializing(self) -> CommonOperations<'a, DeserializeImpl, NotSupported, Cloneable> { CommonOperations { registry: self.registry, _ignore: Default::default(), } } - /// Opt out of cloning the response of the node. Use this to build a node - /// whose response type is not cloneable. - pub fn no_response_cloning( - self, - ) -> CommonOperations<'a, DeserializeImpl, SerializeImpl, NotSupported> { + /// Opt out of cloning the input and output messages of the node. + /// + /// If you want to enable cloning for only the input or only the output + /// then use [`DiagramElementRegistry::register_message`] on the message type + /// directly. + pub fn no_cloning(self) -> CommonOperations<'a, DeserializeImpl, SerializeImpl, NotSupported> { CommonOperations { registry: self.registry, _ignore: Default::default(), @@ -310,7 +400,7 @@ pub struct MessageRegistrationBuilder<'a, Message> { impl<'a, Message> MessageRegistrationBuilder<'a, Message> where - Message: Any, + Message: Send + Sync + 'static + Any, { fn new(registry: &'a mut MessageRegistry) -> Self { Self { @@ -319,91 +409,208 @@ where } } - /// Mark the node as having a unzippable response. This is required in order for the node + /// Mark the message as having a unzippable response. This is required in order for the node /// to be able to be connected to a "Unzip" operation. 
pub fn with_unzip(&mut self) -> &mut Self where - DefaultImplMarker<(Message, DefaultSerializer)>: DynUnzip, + Supported<(Message, Supported, Supported)>: PerformUnzip, { - self.data.register_unzip::(); + self.data.register_unzip::(); self } - /// Mark the node as having an unzippable response whose elements are not serializable. + /// Mark the message as having an unzippable response whose elements are not serializable. pub fn with_unzip_minimal(&mut self) -> &mut Self where - DefaultImplMarker<(Message, NotSupported)>: DynUnzip, + Supported<(Message, NotSupported, NotSupported)>: PerformUnzip, { - self.data.register_unzip::(); + self.data + .register_unzip::(); self } - /// Mark the node as having a [`Result<_, _>`] response. This is required in order for the node + /// Mark the message as having a [`Result<_, _>`] response. This is required in order for the node /// to be able to be connected to a "Fork Result" operation. pub fn with_fork_result(&mut self) -> &mut Self where - DefaultImpl: DynForkResult, + Supported<(Message, Supported, Supported)>: RegisterForkResult, { - self.data.register_fork_result::(); + self.data + .register_fork_result::>(); self } - /// Mark the node as having a splittable response. This is required in order + /// Same as `Self::with_fork_result` but it will not register serialization + /// or cloning for the [`Ok`] or [`Err`] variants of the message. + pub fn with_fork_result_minimal(&mut self) -> &mut Self + where + Supported<(Message, NotSupported, NotSupported)>: RegisterForkResult, + { + self.data + .register_fork_result::>(); + self + } + + /// Mark the message as having a splittable response. This is required in order /// for the node to be able to be connected to a "Split" operation. 
pub fn with_split(&mut self) -> &mut Self where - DefaultImpl: DynSplit, + Supported<(Message, Supported, Supported)>: RegisterSplit, { - self.data - .register_split::(); + self.data.register_split::(); self } - /// Mark the node as having a splittable response but the items from the split + /// Mark the message as having a splittable response but the items from the split /// are unserializable. pub fn with_split_minimal(&mut self) -> &mut Self where - DefaultImpl: DynSplit, + Supported<(Message, NotSupported, NotSupported)>: RegisterSplit, { self.data - .register_split::(); + .register_split::(); + self + } + + /// Mark the message as being joinable. + pub fn with_join(&mut self) -> &mut Self + where + Message: Joined, + { + self.data.register_join::(); + self + } + + /// Mark the message as being a buffer access. + pub fn with_buffer_access(&mut self) -> &mut Self + where + Message: BufferAccessRequest, + { + self.data.register_buffer_access::(); + self + } + + /// Mark the message as being listenable. 
+ pub fn with_listen(&mut self) -> &mut Self + where + Message: Accessor, + { + self.data.register_listen::(); + self + } + + pub fn with_to_string(&mut self) -> &mut Self + where + Message: ToString, + { + self.data.register_to_string::(); self } } pub struct NodeRegistrationBuilder<'a, Request, Response, Streams> { - registry: &'a mut MessageRegistry, + registry: &'a mut DiagramElementRegistry, _ignore: PhantomData<(Request, Response, Streams)>, } impl<'a, Request, Response, Streams> NodeRegistrationBuilder<'a, Request, Response, Streams> where - Request: Any, - Response: Any, + Request: Send + Sync + 'static + Any, + Response: Send + Sync + 'static + Any, { - fn new(registry: &'a mut MessageRegistry) -> Self { + fn new(registry: &'a mut DiagramElementRegistry) -> Self { Self { registry, _ignore: Default::default(), } } + /// If you opted out of any common operations in order to accommodate your + /// response type, you can enable all common operations for your response + /// type using this. + pub fn with_common_request(&mut self) -> &mut Self + where + Request: DynType + DeserializeOwned + Serialize + Clone, + { + self.registry.register_message::(); + self + } + + /// If you opted out of cloning, you can enable it specifically for the + /// input message with this. + pub fn with_clone_request(&mut self) -> &mut Self + where + Request: Clone, + { + self.registry + .messages + .register_fork_clone::(); + self + } + + /// If you opted out of deserialization, you can enable it specifically for + /// the input message with this. + pub fn with_deserialize_request(&mut self) -> &mut Self + where + Request: DeserializeOwned + DynType, + { + self.registry + .messages + .register_deserialize::(); + self + } + + /// If you opted out of any common operations in order to accommodate your + /// request type, you can enable all common operations for your response + /// type using this. 
+ pub fn with_common_response(&mut self) -> &mut Self + where + Response: DynType + DeserializeOwned + Serialize + Clone, + { + self.registry.register_message::(); + self + } + + /// If you opted out of cloning, you can enable it specifically for the + /// output message with this. + pub fn with_clone_response(&mut self) -> &mut Self + where + Response: Clone, + { + self.registry + .messages + .register_fork_clone::(); + self + } + + /// If you opted out of serialization, you can enable it specifically for + /// the output message with this. + pub fn with_serialize_response(&mut self) -> &mut Self + where + Response: Serialize + DynType, + { + self.registry + .messages + .register_serialize::(); + self + } + /// Mark the node as having a unzippable response. This is required in order for the node /// to be able to be connected to a "Unzip" operation. pub fn with_unzip(&mut self) -> &mut Self where - DefaultImplMarker<(Response, DefaultSerializer)>: DynUnzip, + Supported<(Response, Supported, Supported)>: PerformUnzip, { - MessageRegistrationBuilder::new(self.registry).with_unzip(); + MessageRegistrationBuilder::new(&mut self.registry.messages).with_unzip(); self } /// Mark the node as having an unzippable response whose elements are not serializable. pub fn with_unzip_unserializable(&mut self) -> &mut Self where - DefaultImplMarker<(Response, NotSupported)>: DynUnzip, + Supported<(Response, NotSupported, NotSupported)>: PerformUnzip, { - MessageRegistrationBuilder::new(self.registry).with_unzip_minimal(); + MessageRegistrationBuilder::new(&mut self.registry.messages).with_unzip_minimal(); self } @@ -411,9 +618,19 @@ where /// to be able to be connected to a "Fork Result" operation. 
pub fn with_fork_result(&mut self) -> &mut Self where - DefaultImpl: DynForkResult, + Supported<(Response, Supported, Supported)>: RegisterForkResult, { - MessageRegistrationBuilder::new(self.registry).with_fork_result(); + MessageRegistrationBuilder::new(&mut self.registry.messages).with_fork_result(); + self + } + + /// Same as `Self::with_fork_result` but it will not register serialization + /// or cloning for the [`Ok`] or [`Err`] variants of the message. + pub fn with_fork_result_minimal(&mut self) -> &mut Self + where + Supported<(Response, NotSupported, NotSupported)>: RegisterForkResult, + { + MessageRegistrationBuilder::new(&mut self.registry.messages).with_fork_result_minimal(); self } @@ -421,9 +638,9 @@ where /// for the node to be able to be connected to a "Split" operation. pub fn with_split(&mut self) -> &mut Self where - DefaultImpl: DynSplit, + Supported<(Response, Supported, Supported)>: RegisterSplit, { - MessageRegistrationBuilder::new(self.registry).with_split(); + MessageRegistrationBuilder::new(&mut self.registry.messages).with_split(); self } @@ -431,9 +648,52 @@ where /// are unserializable. pub fn with_split_unserializable(&mut self) -> &mut Self where - DefaultImpl: DynSplit, + Supported<(Response, NotSupported, NotSupported)>: RegisterSplit, + { + MessageRegistrationBuilder::new(&mut self.registry.messages).with_split_minimal(); + self + } + + /// Mark the node as having a joinable request. + pub fn with_join(&mut self) -> &mut Self + where + Request: Joined, + { + self.registry.messages.register_join::(); + self + } + + /// Mark the node as having a buffer access request. + pub fn with_buffer_access(&mut self) -> &mut Self + where + Request: BufferAccessRequest, + { + self.registry.messages.register_buffer_access::(); + self + } + + /// Mark the node as having a listen request. 
+ pub fn with_listen(&mut self) -> &mut Self + where + Request: Accessor, + { + self.registry.messages.register_listen::(); + self + } + + pub fn with_request_to_string(&mut self) -> &mut Self + where + Request: ToString, + { + self.registry.messages.register_to_string::(); + self + } + + pub fn with_response_to_string(&mut self) -> &mut Self + where + Response: ToString, { - MessageRegistrationBuilder::new(self.registry).with_split_minimal(); + self.registry.messages.register_to_string::(); self } } @@ -454,107 +714,42 @@ pub struct DiagramElementRegistry { pub(super) messages: MessageRegistry, } -#[derive(Default)] pub(super) struct MessageOperation { - deserialize_impl: Option, - serialize_impl: Option, - fork_clone_impl: Option, - unzip_impl: Option>, - fork_result_impl: Option, - split_impl: Option, - join_impl: Option, + pub(super) deserialize_impl: Option, + pub(super) serialize_impl: Option, + pub(super) fork_clone_impl: Option, + pub(super) unzip_impl: Option>, + pub(super) fork_result_impl: Option, + pub(super) split_impl: Option, + pub(super) join_impl: Option, + pub(super) buffer_access_impl: Option, + pub(super) listen_impl: Option, + pub(super) to_string_impl: Option, + pub(super) create_buffer_impl: CreateBufferFn, + pub(super) create_trigger_impl: CreateTriggerFn, } impl MessageOperation { - fn new() -> Self { - Self::default() - } - - /// Try to deserialize `output` into `input_type`. If `output` is not `serde_json::Value`, this does nothing. - pub(super) fn deserialize( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result { - let f = self - .deserialize_impl - .as_ref() - .ok_or(DiagramError::NotSerializable)?; - f(builder, output.into_output()?) 
- } - - pub(super) fn serialize( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError> { - let f = self - .serialize_impl - .as_ref() - .ok_or(DiagramError::NotSerializable)?; - f(builder, output) - } - - pub(super) fn fork_clone( - &self, - builder: &mut Builder, - output: DynOutput, - amount: usize, - ) -> Result, DiagramError> { - let f = self - .fork_clone_impl - .as_ref() - .ok_or(DiagramError::NotCloneable)?; - f(builder, output, amount) - } - - pub(super) fn unzip( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError> { - let unzip_impl = &self - .unzip_impl - .as_ref() - .ok_or(DiagramError::NotUnzippable)?; - unzip_impl.dyn_unzip(builder, output) - } - - pub(super) fn fork_result( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError> { - let f = self - .fork_result_impl - .as_ref() - .ok_or(DiagramError::CannotForkResult)?; - f(builder, output) - } - - pub(super) fn split<'a>( - &self, - builder: &mut Builder, - output: DynOutput, - split_op: &'a SplitOp, - ) -> Result, DiagramError> { - let f = self - .split_impl - .as_ref() - .ok_or(DiagramError::NotSplittable)?; - f(builder, output, split_op) - } - - pub(super) fn join( - &self, - builder: &mut Builder, - outputs: OutputIter, - ) -> Result + fn new() -> Self where - OutputIter: IntoIterator, + T: Send + Sync + 'static + Any, { - let f = self.join_impl.as_ref().ok_or(DiagramError::NotJoinable)?; - f(builder, outputs.into_iter().collect()) + Self { + deserialize_impl: None, + serialize_impl: None, + fork_clone_impl: None, + unzip_impl: None, + fork_result_impl: None, + split_impl: None, + join_impl: None, + buffer_access_impl: None, + listen_impl: None, + to_string_impl: None, + create_buffer_impl: |settings, builder| { + builder.create_buffer::(settings).as_any_buffer() + }, + create_trigger_impl: |builder| builder.create_map_block(|_: T| ()).into(), + } } } @@ -590,17 +785,20 @@ impl 
Serialize for MessageOperation { } pub struct MessageRegistration { - type_name: &'static str, - schema: Option, - operations: MessageOperation, + pub(super) type_name: &'static str, + pub(super) schema: Option, + pub(super) operations: MessageOperation, } impl MessageRegistration { - fn new() -> Self { + pub(super) fn new() -> Self + where + T: Send + Sync + 'static + Any, + { Self { type_name: type_name::(), schema: None, - operations: MessageOperation::new(), + operations: MessageOperation::new::(), } } } @@ -620,190 +818,171 @@ impl Serialize for MessageRegistration { #[derive(Serialize)] pub struct MessageRegistry { #[serde(serialize_with = "MessageRegistry::serialize_messages")] - messages: HashMap, + pub messages: HashMap, #[serde( rename = "schemas", serialize_with = "MessageRegistry::serialize_schemas" )] - schema_generator: SchemaGenerator, + pub schema_generator: SchemaGenerator, } impl MessageRegistry { + fn new() -> Self { + let mut settings = SchemaSettings::default(); + settings.definitions_path = "#/schemas/".to_string(); + + Self { + schema_generator: SchemaGenerator::new(settings), + messages: HashMap::from([( + TypeInfo::of::(), + MessageRegistration::new::(), + )]), + } + } + fn get(&self) -> Option<&MessageRegistration> where T: Any, { - self.messages.get(&TypeId::of::()) + self.messages.get(&TypeInfo::of::()) } - pub(super) fn deserialize( + pub fn deserialize( &self, - target_type: &TypeId, + target_type: &TypeInfo, builder: &mut Builder, - output: DynOutput, - ) -> Result { - if output.type_id != TypeId::of::() || &output.type_id == target_type { - Ok(output) - } else if let Some(reg) = self.messages.get(target_type) { - reg.operations.deserialize(builder, output) - } else { - Err(DiagramError::NotSerializable) - } + ) -> Result { + self.try_deserialize(target_type, builder)? 
+ .ok_or_else(|| DiagramErrorCode::NotDeserializable(*target_type)) + } + + pub fn try_deserialize( + &self, + target_type: &TypeInfo, + builder: &mut Builder, + ) -> Result, DiagramErrorCode> { + self.messages + .get(target_type) + .and_then(|reg| reg.operations.deserialize_impl.as_ref()) + .map(|deserialize| deserialize(builder)) + .transpose() } /// Register a deserialize function if not already registered, returns true if the new /// function is registered. - pub(super) fn register_deserialize(&mut self) -> bool + pub fn register_deserialize(&mut self) where T: Send + Sync + 'static + Any, Deserializer: DeserializeMessage, { - let reg = self - .messages - .entry(TypeId::of::()) - .or_insert(MessageRegistration::new::()); - let ops = &mut reg.operations; - if !Deserializer::deserializable() || ops.deserialize_impl.is_some() { - return false; - } + Deserializer::register_deserialize(&mut self.messages, &mut self.schema_generator); + } - debug!( - "register deserialize for type: {}, with deserializer: {}", - std::any::type_name::(), - std::any::type_name::() - ); - ops.deserialize_impl = Some(Box::new(|builder, output| { - debug!("deserialize output: {:?}", output); - let receiver = - builder.create_map_block(|json: serde_json::Value| Deserializer::from_json(json)); - builder.connect(output, receiver.input); - let deserialized_output = receiver - .output - .chain(builder) - .cancel_on_err() - .output() - .into(); - debug!("deserialized output: {:?}", deserialized_output); - Ok(deserialized_output) - })); - - reg.schema = Deserializer::json_schema(&mut self.schema_generator); + pub fn serialize( + &self, + incoming_type: &TypeInfo, + builder: &mut Builder, + ) -> Result { + self.try_serialize(incoming_type, builder)? 
+ .ok_or_else(|| DiagramErrorCode::NotSerializable(*incoming_type)) + } - true + pub fn try_serialize( + &self, + incoming_type: &TypeInfo, + builder: &mut Builder, + ) -> Result, DiagramErrorCode> { + self.messages + .get(incoming_type) + .and_then(|reg| reg.operations.serialize_impl.as_ref()) + .map(|serialize| serialize(builder)) + .transpose() } - pub(super) fn serialize( + pub fn try_to_string( &self, + incoming_type: &TypeInfo, builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError> { - if output.type_id == TypeId::of::() { - output.into_output() - } else if let Some(reg) = self.messages.get(&output.type_id) { - reg.operations.serialize(builder, output) - } else { - Err(DiagramError::NotSerializable) - } + ) -> Result, DiagramErrorCode> { + let ops = &self + .messages + .get(incoming_type) + .ok_or_else(|| DiagramErrorCode::UnregisteredType(*incoming_type))? + .operations; + + Ok(ops.to_string_impl.map(|f| f(builder))) } /// Register a serialize function if not already registered, returns true if the new /// function is registered. 
- pub(super) fn register_serialize(&mut self) -> bool + pub fn register_serialize(&mut self) where T: Send + Sync + 'static + Any, Serializer: SerializeMessage, { - let reg = &mut self - .messages - .entry(TypeId::of::()) - .or_insert(MessageRegistration::new::()); - let ops = &mut reg.operations; - if !Serializer::serializable() || ops.serialize_impl.is_some() { - return false; - } - - debug!( - "register serialize for type: {}, with serializer: {}", - std::any::type_name::(), - std::any::type_name::() - ); - ops.serialize_impl = Some(Box::new(|builder, output| { - debug!("serialize output: {:?}", output); - let n = builder.create_map_block(|resp: T| Serializer::to_json(&resp)); - builder.connect(output.into_output()?, n.input); - let serialized_output = n.output.chain(builder).cancel_on_err().output(); - debug!("serialized output: {:?}", serialized_output); - Ok(serialized_output) - })); - - reg.schema = Serializer::json_schema(&mut self.schema_generator); - - true + Serializer::register_serialize(&mut self.messages, &mut self.schema_generator) } - pub(super) fn fork_clone( + pub fn fork_clone( &self, + message_info: &TypeInfo, builder: &mut Builder, - output: DynOutput, - amount: usize, - ) -> Result, DiagramError> { - if let Some(reg) = self.messages.get(&output.type_id) { - reg.operations.fork_clone(builder, output, amount) - } else { - Err(DiagramError::NotCloneable) - } + ) -> Result { + self.messages + .get(message_info) + .and_then(|reg| reg.operations.fork_clone_impl.as_ref()) + .ok_or(DiagramErrorCode::NotCloneable) + .and_then(|f| f(builder)) } /// Register a fork_clone function if not already registered, returns true if the new /// function is registered. 
- pub(super) fn register_fork_clone(&mut self) -> bool + pub fn register_fork_clone(&mut self) -> bool where - T: Any, - F: DynForkClone, + T: Send + Sync + 'static + Any, + F: PerformForkClone, { let ops = &mut self .messages - .entry(TypeId::of::()) + .entry(TypeInfo::of::()) .or_insert(MessageRegistration::new::()) .operations; if !F::CLONEABLE || ops.fork_clone_impl.is_some() { return false; } - ops.fork_clone_impl = Some(Box::new(|builder, output, amount| { - F::dyn_fork_clone(builder, output, amount) - })); + ops.fork_clone_impl = Some(|builder| F::perform_fork_clone(builder)); true } - pub(super) fn unzip( - &self, - builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError> { - if let Some(reg) = self.messages.get(&output.type_id) { - reg.operations.unzip(builder, output) - } else { - Err(DiagramError::NotUnzippable) - } + pub fn unzip<'a>( + &'a self, + message_info: &TypeInfo, + ) -> Result<&'a dyn PerformUnzip, DiagramErrorCode> { + self.messages + .get(message_info) + .and_then(|reg| reg.operations.unzip_impl.as_ref()) + .map(|unzip| -> &'a (dyn PerformUnzip) { unzip.as_ref() }) + .ok_or(DiagramErrorCode::NotUnzippable) } /// Register a unzip function if not already registered, returns true if the new /// function is registered. 
- pub(super) fn register_unzip(&mut self) -> bool + pub(super) fn register_unzip(&mut self) -> bool where - T: Any, + T: Send + Sync + 'static + Any, Serializer: 'static, - DefaultImplMarker<(T, Serializer)>: DynUnzip, + Cloneable: 'static, + Supported<(T, Serializer, Cloneable)>: PerformUnzip, { - let unzip_impl = DefaultImplMarker::<(T, Serializer)>::new(); + let unzip_impl = Supported::<(T, Serializer, Cloneable)>::new(); unzip_impl.on_register(self); let ops = &mut self .messages - .entry(TypeId::of::()) + .entry(TypeInfo::of::()) .or_insert(MessageRegistration::new::()) .operations; if ops.unzip_impl.is_some() { @@ -814,143 +993,215 @@ impl MessageRegistry { true } - pub(super) fn fork_result( + pub fn fork_result( &self, + message_info: &TypeInfo, builder: &mut Builder, - output: DynOutput, - ) -> Result<(DynOutput, DynOutput), DiagramError> { - if let Some(reg) = self.messages.get(&output.type_id) { - reg.operations.fork_result(builder, output) - } else { - Err(DiagramError::CannotForkResult) - } + ) -> Result { + self.messages + .get(message_info) + .and_then(|reg| reg.operations.fork_result_impl.as_ref()) + .ok_or(DiagramErrorCode::CannotForkResult) + .and_then(|f| f(builder)) } /// Register a fork_result function if not already registered, returns true if the new /// function is registered. - pub(super) fn register_fork_result(&mut self) -> bool + pub(super) fn register_fork_result(&mut self) -> bool where - T: Any, - F: DynForkResult, + R: RegisterForkResult, + { + R::on_register(self) + } + + pub fn split( + &self, + message_info: &TypeInfo, + split_op: &SplitSchema, + builder: &mut Builder, + ) -> Result { + self.messages + .get(message_info) + .and_then(|reg| reg.operations.split_impl.as_ref()) + .ok_or(DiagramErrorCode::NotSplittable) + .and_then(|f| f(split_op, builder)) + } + + /// Register a split function if not already registered. 
+ pub(super) fn register_split(&mut self) + where + T: Send + Sync + 'static + Any, + Supported<(T, S, C)>: RegisterSplit, + { + Supported::<(T, S, C)>::on_register(self); + } + + pub fn create_buffer( + &self, + message_info: &TypeInfo, + settings: BufferSettings, + builder: &mut Builder, + ) -> Result { + let f = self + .messages + .get(message_info) + .ok_or_else(|| DiagramErrorCode::UnregisteredType(*message_info))? + .operations + .create_buffer_impl; + + Ok(f(settings, builder)) + } + + pub fn trigger( + &self, + message_info: &TypeInfo, + builder: &mut Builder, + ) -> Result { + self.messages + .get(message_info) + .map(|reg| (reg.operations.create_trigger_impl)(builder)) + .ok_or_else(|| DiagramErrorCode::UnregisteredType(*message_info)) + } + + pub fn join( + &self, + joinable: &TypeInfo, + buffers: &BufferMap, + builder: &mut Builder, + ) -> Result { + self.messages + .get(joinable) + .and_then(|reg| reg.operations.join_impl.as_ref()) + .ok_or_else(|| DiagramErrorCode::NotJoinable) + .and_then(|f| f(buffers, builder)) + } + + /// Register a join function if not already registered, returns true if the new + /// function is registered. 
+ pub(super) fn register_join(&mut self) -> bool + where + T: Send + Sync + 'static + Any + Joined, { let ops = &mut self .messages - .entry(TypeId::of::()) + .entry(TypeInfo::of::()) .or_insert(MessageRegistration::new::()) .operations; - if ops.fork_result_impl.is_some() { + if ops.join_impl.is_some() { return false; } - ops.fork_result_impl = Some(Box::new(|builder, output| { - F::dyn_fork_result(builder, output) - })); + ops.join_impl = + Some(|buffers, builder| Ok(builder.try_join::(buffers)?.output().into())); true } - pub(super) fn split<'b>( + pub fn with_buffer_access( &self, + target_type: &TypeInfo, + buffers: &BufferMap, builder: &mut Builder, - output: DynOutput, - split_op: &'b SplitOp, - ) -> Result, DiagramError> { - if let Some(reg) = self.messages.get(&output.type_id) { - reg.operations.split(builder, output, split_op) - } else { - Err(DiagramError::NotSplittable) - } + ) -> Result { + self.messages + .get(target_type) + .and_then(|reg| reg.operations.buffer_access_impl.as_ref()) + .ok_or(DiagramErrorCode::CannotBufferAccess) + .and_then(|f| f(buffers, builder)) } - /// Register a split function if not already registered, returns true if the new - /// function is registered. 
- pub(super) fn register_split(&mut self) -> bool + pub(super) fn register_buffer_access(&mut self) -> bool where - T: Any, - F: DynSplit, + T: Send + Sync + 'static + BufferAccessRequest, { let ops = &mut self .messages - .entry(TypeId::of::()) + .entry(TypeInfo::of::()) .or_insert(MessageRegistration::new::()) .operations; - if ops.split_impl.is_some() { + if ops.buffer_access_impl.is_some() { return false; } - ops.split_impl = Some(Box::new(|builder, output, split_op| { - F::dyn_split(builder, output, split_op) - })); - F::on_register(self); + ops.buffer_access_impl = Some(|buffers, builder| { + let buffer_access = + builder.try_create_buffer_access::(buffers)?; + Ok(buffer_access.into()) + }); true } - pub(super) fn join( + pub fn listen( &self, + target_type: &TypeInfo, + buffers: &BufferMap, builder: &mut Builder, - outputs: OutputIter, - ) -> Result - where - OutputIter: IntoIterator, - { - let mut i = outputs.into_iter().peekable(); - let output_type_id = if let Some(o) = i.peek() { - Some(o.type_id.clone()) - } else { - None - }; - - if let Some(output_type_id) = output_type_id { - if let Some(reg) = self.messages.get(&output_type_id) { - reg.operations.join(builder, i) - } else { - Err(DiagramError::NotJoinable) - } - } else { - Err(DiagramError::NotJoinable) - } + ) -> Result { + self.messages + .get(target_type) + .and_then(|reg| reg.operations.listen_impl.as_ref()) + .ok_or_else(|| DiagramErrorCode::CannotListen(*target_type)) + .and_then(|f| f(buffers, builder)) } - /// Register a join function if not already registered, returns true if the new - /// function is registered. 
- pub(super) fn register_join(&mut self, f: JoinFn) -> bool + pub(super) fn register_listen(&mut self) -> bool where - T: Any, + T: Send + Sync + 'static + Any + Accessor, { let ops = &mut self .messages - .entry(TypeId::of::()) + .entry(TypeInfo::of::()) .or_insert(MessageRegistration::new::()) .operations; - if ops.join_impl.is_some() { + if ops.listen_impl.is_some() { return false; } - ops.join_impl = Some(f); + ops.listen_impl = + Some(|buffers, builder| Ok(builder.try_listen::(buffers)?.output().into())); true } + pub(super) fn register_to_string(&mut self) + where + T: 'static + Send + Sync + ToString, + { + let ops = &mut self + .messages + .entry(TypeInfo::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + + ops.to_string_impl = + Some(|builder| builder.create_map_block(|msg: T| msg.to_string()).into()); + } + fn serialize_messages( - v: &HashMap, + v: &HashMap, serializer: S, ) -> Result where S: serde::Serializer, { let mut s = serializer.serialize_map(Some(v.len()))?; - for msg in v.values() { + for (type_id, reg) in v { + // hide builtin registrations + if type_id == &TypeInfo::of::() { + continue; + } + // should we use short name? It makes the serialized json more readable at the cost // of greatly increased chance of key conflicts. // let short_name = { - // if let Some(start) = msg.type_name.rfind(":") { - // &msg.type_name[start + 1..] + // if let Some(start) = reg.type_name.rfind(":") { + // ®.type_name[start + 1..] // } else { - // msg.type_name + // reg.type_name // } // }; - s.serialize_entry(msg.type_name, msg)?; + s.serialize_entry(reg.type_name, reg)?; } s.end() } @@ -965,15 +1216,17 @@ impl MessageRegistry { impl Default for DiagramElementRegistry { fn default() -> Self { - let mut settings = SchemaSettings::default(); - settings.definitions_path = "#/schemas/".to_string(); - DiagramElementRegistry { + // Ensure buffer downcasting is automatically registered for all basic + // serializable types. 
+ JsonBuffer::register_for::<()>(); + + let mut registry = DiagramElementRegistry { nodes: Default::default(), - messages: MessageRegistry { - schema_generator: SchemaGenerator::new(settings), - messages: HashMap::new(), - }, - } + messages: MessageRegistry::new(), + }; + + registry.register_builtin_messages(); + registry } } @@ -982,6 +1235,17 @@ impl DiagramElementRegistry { Self::default() } + /// Create a new registry that does not automatically register any of the + /// builtin types. Only advanced users who know what they are doing should + /// use this. + pub fn blank() -> Self { + JsonBuffer::register_for::<()>(); + DiagramElementRegistry { + nodes: Default::default(), + messages: MessageRegistry::new(), + } + } + /// Register a node builder with all the common operations (deserialize the /// request, serialize the response, and clone the response) enabled. /// @@ -1012,35 +1276,53 @@ impl DiagramElementRegistry { ) -> NodeRegistrationBuilder where Config: JsonSchema + DeserializeOwned, - Request: Send + Sync + 'static + DynType + DeserializeOwned, - Response: Send + Sync + 'static + DynType + Serialize + Clone, + Request: Send + Sync + 'static + DynType + Serialize + DeserializeOwned + Clone, + Response: Send + Sync + 'static + DynType + Serialize + DeserializeOwned + Clone, { self.opt_out().register_node_builder(options, builder) } + /// Register a single message for general use between nodes. This will + /// include all common operations for the message (deserialize, serialize, + /// and clone). + /// + /// You will receive a [`MessageRegistrationBuilder`] which you can then use + /// to enable more operations for the message, such as forking, splitting, + /// unzipping, and joining. The message type needs to be suitable for each + /// operation that you register for it or else the compiler will not allow + /// you to enable them. + /// + /// Use [`Self::opt_out`] to opt out of specified common operations before + /// beginning to register the message. 
This allows you to register message + /// types that do not support one or more of the common operations. + pub fn register_message(&mut self) -> MessageRegistrationBuilder + where + Message: Send + Sync + 'static + DynType + DeserializeOwned + Serialize + Clone, + { + self.opt_out().register_message() + } + /// In some cases the common operations of deserialization, serialization, - /// and cloning cannot be performed for the request or response of a node. + /// and cloning cannot be performed for the input or output message of a node. /// When that happens you can still register your node builder by calling /// this function and explicitly disabling the common operations that your /// node cannot support. /// - /// - /// In order for the request to be deserializable, it must implement [`schemars::JsonSchema`] and [`serde::de::DeserializeOwned`]. - /// In order for the response to be serializable, it must implement [`schemars::JsonSchema`] and [`serde::Serialize`]. + /// In order for a message type to support all the common operations, it + /// must implement [`schemars::JsonSchema`], [`serde::de::DeserializeOwned`], + /// [`serde::Serialize`], and [`Clone`]. /// /// ``` /// use schemars::JsonSchema; /// use serde::{Deserialize, Serialize}; /// - /// #[derive(JsonSchema, Deserialize)] - /// struct DeserializableRequest {} - /// - /// #[derive(JsonSchema, Serialize)] - /// struct SerializableResponse {} + /// #[derive(JsonSchema, Serialize, Deserialize, Clone)] + /// struct MyCommonMessage {} /// ``` /// - /// If your node have a request or response that is not serializable, there is still - /// a way to register it. 
+ /// If your node has an input or output message that is missing one of these + /// traits, you can still register it by opting out of the relevant common + /// operation(s): /// /// ``` /// use bevy_impulse::{NodeBuilderOptions, DiagramElementRegistry}; @@ -1052,9 +1334,9 @@ impl DiagramElementRegistry { /// let mut registry = DiagramElementRegistry::new(); /// registry /// .opt_out() - /// .no_request_deserializing() - /// .no_response_serializing() - /// .no_response_cloning() + /// .no_deserializing() + /// .no_serializing() + /// .no_cloning() /// .register_node_builder( /// NodeBuilderOptions::new("echo"), /// |builder, _config: ()| { @@ -1066,23 +1348,21 @@ impl DiagramElementRegistry { /// Note that nodes registered without deserialization cannot be connected /// to the workflow start, and nodes registered without serialization cannot /// be connected to the workflow termination. - pub fn opt_out( - &mut self, - ) -> CommonOperations { + pub fn opt_out(&mut self) -> CommonOperations { CommonOperations { registry: self, _ignore: Default::default(), } } - pub fn get_node_registration(&self, id: &Q) -> Result<&NodeRegistration, DiagramError> + pub fn get_node_registration(&self, id: &Q) -> Result<&NodeRegistration, DiagramErrorCode> where Q: Borrow + ?Sized, { let k = id.borrow(); self.nodes .get(k) - .ok_or(DiagramError::BuilderNotFound(k.to_string())) + .ok_or(DiagramErrorCode::BuilderNotFound(k.to_string())) } pub fn get_message_registration(&self) -> Option<&MessageRegistration> @@ -1091,6 +1371,39 @@ impl DiagramElementRegistry { { self.messages.get::() } + + /// Register useful messages that are known to the bevy impulse library. + /// This will be run automatically when you create using [`Self::default()`] + /// or [`Self::new()`]. 
+ pub fn register_builtin_messages(&mut self) { + self.register_message::() + .with_join() + .with_split(); + + self.opt_out() + .no_serializing() + .no_deserializing() + .no_cloning() + .register_message::() + .with_to_string(); + + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::(); + self.register_message::<()>(); + } } #[non_exhaustive] @@ -1194,7 +1507,7 @@ mod tests { let tuple_resp = |_: ()| -> (i64,) { (1,) }; registry .opt_out() - .no_response_cloning() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("multiply3_uncloneable").with_name("Test Name"), move |builder: &mut Builder, _config: ()| builder.create_map_block(tuple_resp), @@ -1280,12 +1593,14 @@ mod tests { let mut registry = DiagramElementRegistry::new(); registry .opt_out() - .no_request_deserializing() - .no_response_cloning() + .no_serializing() + .no_deserializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque_request_map").with_name("Test Name"), move |builder, _config: ()| builder.create_map_block(opaque_request_map), - ); + ) + .with_serialize_response(); assert!(registry.get_node_registration("opaque_request_map").is_ok()); let req_ops = ®istry .messages @@ -1299,14 +1614,16 @@ mod tests { let opaque_response_map = |_: ()| NonSerializableRequest {}; registry .opt_out() - .no_response_serializing() - .no_response_cloning() + .no_serializing() + .no_deserializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque_response_map").with_name("Test Name"), move |builder: &mut Builder, _config: ()| { builder.create_map_block(opaque_response_map) }, - ); + ) + 
.with_deserialize_request(); assert!(registry .get_node_registration("opaque_response_map") .is_ok()); @@ -1322,9 +1639,9 @@ mod tests { let opaque_req_resp_map = |_: NonSerializableRequest| NonSerializableRequest {}; registry .opt_out() - .no_request_deserializing() - .no_response_serializing() - .no_response_cloning() + .no_deserializing() + .no_serializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque_req_resp_map").with_name("Test Name"), move |builder: &mut Builder, _config: ()| { @@ -1367,7 +1684,7 @@ mod tests { assert!(!ops.unzippable()); assert!(!ops.can_fork_result()); assert!(!ops.splittable()); - assert!(ops.joinable()); + assert!(!ops.joinable()); } #[test] @@ -1387,7 +1704,9 @@ mod tests { struct Opaque; reg.opt_out() - .no_request_deserializing() + .no_serializing() + .no_deserializing() + .no_cloning() .register_node_builder(NodeBuilderOptions::new("test"), |builder, _config: ()| { builder.create_map_block(|_: Opaque| { ( diff --git a/src/diagram/serialization.rs b/src/diagram/serialization.rs index 4c3c2e96..746d634f 100644 --- a/src/diagram/serialization.rs +++ b/src/diagram/serialization.rs @@ -1,14 +1,30 @@ -use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema}; +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + +use std::collections::{hash_map::Entry, HashMap}; + +use schemars::{gen::SchemaGenerator, JsonSchema}; use serde::{de::DeserializeOwned, Serialize}; -#[derive(thiserror::Error, Debug)] -pub enum SerializationError { - #[error("not supported")] - NotSupported, - - #[error(transparent)] - JsonError(#[from] serde_json::Error), -} +use super::{ + supported::*, type_info::TypeInfo, DiagramContext, DiagramErrorCode, DynForkResult, + DynInputSlot, DynOutput, JsonMessage, MessageRegistration, MessageRegistry, +}; +use crate::{Builder, JsonBuffer}; pub trait DynType { /// Returns the type name of the request. Note that the type name must be unique. @@ -31,111 +47,346 @@ where } pub trait SerializeMessage { - fn type_name() -> String; - - fn json_schema(gen: &mut SchemaGenerator) -> Option; - - fn to_json(v: &T) -> Result; - - fn serializable() -> bool; + fn register_serialize( + messages: &mut HashMap, + schema_generator: &mut SchemaGenerator, + ); } -#[derive(Default)] -pub struct DefaultSerializer; - -impl SerializeMessage for DefaultSerializer +impl SerializeMessage for Supported where - T: Serialize + DynType, + T: Serialize + DynType + Send + Sync + 'static, { - fn type_name() -> String { - T::type_name() + fn register_serialize( + messages: &mut HashMap, + schema_generator: &mut SchemaGenerator, + ) { + let reg = &mut messages + .entry(TypeInfo::of::()) + .or_insert(MessageRegistration::new::()); + + reg.operations.serialize_impl = Some(|builder| { + let serialize = builder.create_map_block(|message: T| { + serde_json::to_value(message).map_err(|err| err.to_string()) + }); + + let (ok, err) = serialize + .output + .chain(builder) + .fork_result(|ok| ok.output(), |err| err.output()); + + Ok(DynForkResult { + input: serialize.input.into(), + ok: ok.into(), + err: err.into(), + }) + }); + + // Serialize and deserialize both generate the schema, so check before + // generating it. 
+ if reg.schema.is_none() { + reg.schema = Some(T::json_schema(schema_generator)); + } } +} - fn json_schema(gen: &mut SchemaGenerator) -> Option { - Some(T::json_schema(gen)) - } +pub trait DeserializeMessage { + fn register_deserialize( + messages: &mut HashMap, + schema_generator: &mut SchemaGenerator, + ); +} - fn to_json(v: &T) -> Result { - serde_json::to_value(v).map_err(|err| SerializationError::from(err)) +impl DeserializeMessage for Supported +where + T: 'static + Send + Sync + DeserializeOwned + DynType, +{ + fn register_deserialize( + messages: &mut HashMap, + schema_generator: &mut SchemaGenerator, + ) { + let reg = &mut messages + .entry(TypeInfo::of::()) + .or_insert(MessageRegistration::new::()); + + reg.operations.deserialize_impl = Some(|builder| { + let deserialize = builder.create_map_block(|message: JsonMessage| { + serde_json::from_value::(message).map_err(|err| err.to_string()) + }); + + let (ok, err) = deserialize + .output + .chain(builder) + .fork_result(|ok| ok.output(), |err| err.output()); + + Ok(DynForkResult { + input: deserialize.input.into(), + ok: ok.into(), + err: err.into(), + }) + }); + + // Serialize and deserialize both generate the schema, so check before + // generating it. 
+ if reg.schema.is_none() { + reg.schema = Some(T::json_schema(schema_generator)); + } } +} - fn serializable() -> bool { - true +impl SerializeMessage for NotSupported { + fn register_serialize(_: &mut HashMap, _: &mut SchemaGenerator) { + // Do nothing } } -pub trait DeserializeMessage { - fn type_name() -> String; - - fn json_schema(gen: &mut SchemaGenerator) -> Option; - - fn from_json(json: serde_json::Value) -> Result; +impl DeserializeMessage for NotSupported { + fn register_deserialize( + _: &mut HashMap, + _: &mut SchemaGenerator, + ) { + // Do nothing + } +} - fn deserializable() -> bool; +pub trait RegisterJson { + fn register_json(); } -#[derive(Default)] -pub struct DefaultDeserializer; +pub struct JsonRegistration { + _ignore: std::marker::PhantomData, +} -impl DeserializeMessage for DefaultDeserializer +impl RegisterJson for JsonRegistration where - T: DeserializeOwned + DynType, + T: 'static + Send + Sync + Serialize + DeserializeOwned, { - fn type_name() -> String { - T::type_name() + fn register_json() { + JsonBuffer::register_for::(); } +} - fn json_schema(gen: &mut SchemaGenerator) -> Option { - Some(T::json_schema(gen)) +impl RegisterJson for JsonRegistration { + fn register_json() { + // Do nothing } +} - fn from_json(json: serde_json::Value) -> Result { - serde_json::from_value::(json).map_err(|err| SerializationError::from(err)) +impl RegisterJson for JsonRegistration { + fn register_json() { + // Do nothing } +} - fn deserializable() -> bool { - true +impl RegisterJson for JsonRegistration { + fn register_json() { + // Do nothing } } -#[derive(Default)] -pub struct OpaqueMessageSerializer; +pub(super) fn register_json() +where + JsonRegistration: RegisterJson, +{ + JsonRegistration::::register_json(); +} -impl SerializeMessage for OpaqueMessageSerializer { - fn type_name() -> String { - std::any::type_name::().to_string() - } +pub struct ImplicitSerialization { + incoming_types: HashMap, + serialized_input: DynInputSlot, +} - fn 
json_schema(_gen: &mut SchemaGenerator) -> Option { - None +impl ImplicitSerialization { + pub fn new(serialized_input: DynInputSlot) -> Result { + if serialized_input.message_info() != &TypeInfo::of::() { + return Err(DiagramErrorCode::TypeMismatch { + source_type: TypeInfo::of::(), + target_type: *serialized_input.message_info(), + }); + } + + Ok(Self { + serialized_input, + incoming_types: Default::default(), + }) } - fn to_json(_v: &T) -> Result { - Err(SerializationError::NotSupported) + /// Attempt to implicitly serialize an output before passing it into the + /// input slot that this implicit serialization targets. + /// + /// If the incoming type cannot be serialized then it will be returned + /// unchanged as the inner [`Err`]. + pub fn try_implicit_serialize( + &mut self, + incoming: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result, DiagramErrorCode> { + if incoming.message_info() == &TypeInfo::of::() { + incoming.connect_to(&self.serialized_input, builder)?; + return Ok(Ok(())); + } + + let input = match self.incoming_types.entry(*incoming.message_info()) { + Entry::Occupied(input_slot) => input_slot.get().clone(), + Entry::Vacant(vacant) => { + let Some(serialize) = ctx + .registry + .messages + .try_serialize(incoming.message_info(), builder)? + else { + // We are unable to serialize this type. + return Ok(Err(incoming)); + }; + + serialize.ok.connect_to(&self.serialized_input, builder)?; + + let error_target = ctx.get_implicit_error_target(); + ctx.add_output_into_target(error_target, serialize.err); + + vacant.insert(serialize.input).clone() + } + }; + + incoming.connect_to(&input, builder)?; + + Ok(Ok(())) } - fn serializable() -> bool { - false + /// Implicitly serialize an output. If the incoming message cannot be + /// serialized then treat it is a diagram error. 
+ pub fn implicit_serialize( + &mut self, + incoming: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + self.try_implicit_serialize(incoming, builder, ctx)? + .map_err(|incoming| DiagramErrorCode::NotSerializable(*incoming.message_info())) } } -#[derive(Default)] -pub struct OpaqueMessageDeserializer; +pub struct ImplicitDeserialization { + deserialized_input: DynInputSlot, + // The serialized input will only be created if a JsonMessage output + // attempts to connect to this operation. Otherwise there is no need to + // create it. + serialized_input: Option, +} -impl DeserializeMessage for OpaqueMessageDeserializer { - fn type_name() -> String { - std::any::type_name::().to_string() +impl ImplicitDeserialization { + pub fn try_new( + deserialized_input: DynInputSlot, + registration: &MessageRegistry, + ) -> Result, DiagramErrorCode> { + if registration + .messages + .get(&deserialized_input.message_info()) + .and_then(|reg| reg.operations.deserialize_impl.as_ref()) + .is_some() + { + return Ok(Some(Self { + deserialized_input, + serialized_input: None, + })); + } + + return Ok(None); } - fn json_schema(_gen: &mut SchemaGenerator) -> Option { - None + pub fn implicit_deserialize( + &mut self, + incoming: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + if incoming.message_info() == self.deserialized_input.message_info() { + // Connect them directly because they match + return incoming.connect_to(&self.deserialized_input, builder); + } + + if incoming.message_info() == &TypeInfo::of::() { + // Connect to the input for serialized messages + let serialized_input = match self.serialized_input { + Some(serialized_input) => serialized_input, + None => { + let deserialize = ctx + .registry + .messages + .deserialize(self.deserialized_input.message_info(), builder)?; + + deserialize + .ok + .connect_to(&self.deserialized_input, builder)?; + + let error_target = 
ctx.get_implicit_error_target(); + ctx.add_output_into_target(error_target, deserialize.err); + + self.serialized_input = Some(deserialize.input); + deserialize.input + } + }; + + return incoming.connect_to(&serialized_input, builder); + } + + Err(DiagramErrorCode::TypeMismatch { + source_type: *incoming.message_info(), + target_type: *self.deserialized_input.message_info(), + }) } +} + +pub struct ImplicitStringify { + incoming_types: HashMap, + string_input: DynInputSlot, +} - fn from_json(_json: serde_json::Value) -> Result { - Err(SerializationError::NotSupported) +impl ImplicitStringify { + pub fn new(string_input: DynInputSlot) -> Result { + if string_input.message_info() != &TypeInfo::of::() { + return Err(DiagramErrorCode::TypeMismatch { + source_type: TypeInfo::of::(), + target_type: *string_input.message_info(), + }); + } + + Ok(Self { + string_input, + incoming_types: Default::default(), + }) } - fn deserializable() -> bool { - false + pub fn try_implicit_stringify( + &mut self, + incoming: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result, DiagramErrorCode> { + if incoming.message_info() == &TypeInfo::of::() { + incoming.connect_to(&self.string_input, builder)?; + return Ok(Ok(())); + } + + let input = match self.incoming_types.entry(*incoming.message_info()) { + Entry::Occupied(input_slot) => input_slot.get().clone(), + Entry::Vacant(vacant) => { + let Some(stringify) = ctx + .registry + .messages + .try_to_string(incoming.message_info(), builder)? + else { + // We are unable to stringify this type. 
+ return Ok(Err(incoming)); + }; + + stringify.output.connect_to(&self.string_input, builder)?; + vacant.insert(stringify.input).clone() + } + }; + + incoming.connect_to(&input, builder)?; + + Ok(Ok(())) } } diff --git a/src/diagram/split_serialized.rs b/src/diagram/split_schema.rs similarity index 84% rename from src/diagram/split_serialized.rs rename to src/diagram/split_schema.rs index 333c8be4..a841051d 100644 --- a/src/diagram/split_serialized.rs +++ b/src/diagram/split_schema.rs @@ -15,27 +15,26 @@ * */ -use std::{any::TypeId, collections::HashMap, usize}; +use std::{collections::HashMap, usize}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tracing::debug; use crate::{ - Builder, Chain, ForRemaining, FromSequential, FromSpecific, ListSplitKey, MapSplitKey, + Builder, ForRemaining, FromSequential, FromSpecific, ListSplitKey, MapSplitKey, OperationResult, SplitDispatcher, Splittable, }; use super::{ - impls::{DefaultImpl, NotSupported}, - join::register_join_impl, - DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, + supported::*, type_info::TypeInfo, BuildDiagramOperation, BuildStatus, DiagramContext, + DiagramErrorCode, DynInputSlot, DynOutput, MessageRegistration, MessageRegistry, NextOperation, + OperationId, PerformForkClone, SerializeMessage, }; -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] -pub struct SplitOp { +pub struct SplitSchema { #[serde(default)] pub(super) sequential: Vec, @@ -45,6 +44,28 @@ pub struct SplitOp { pub(super) remaining: Option, } +impl BuildDiagramOperation for SplitSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let Some(sample_input) = ctx.infer_input_type_into_target(id) else { + // There are no outputs ready for this target, so we can't do + // anything yet. 
The builder should try again later. + return Ok(BuildStatus::defer("waiting for an input")); + }; + + let split = ctx.registry.messages.split(sample_input, self, builder)?; + ctx.set_input_for_target(id, split.input)?; + for (target, output) in split.outputs { + ctx.add_output_into_target(target, output); + } + Ok(BuildStatus::Finished) + } +} + impl Splittable for Value { type Key = MapSplitKey; type Item = (JsonPosition, Value); @@ -158,106 +179,65 @@ impl FromSpecific for ListSplitKey { } #[derive(Debug)] -pub struct DynSplitOutputs<'a> { - pub(super) outputs: HashMap<&'a NextOperation, DynOutput>, - pub(super) remaining: DynOutput, +pub struct DynSplit { + pub(super) input: DynInputSlot, + pub(super) outputs: Vec<(NextOperation, DynOutput)>, } -pub(super) fn split_chain<'a, T>( - chain: Chain, - split_op: &'a SplitOp, -) -> Result, DiagramError> -where - T: Send + Sync + 'static + Splittable, - T::Key: FromSequential + FromSpecific + ForRemaining, -{ - debug!( - "split chain of type: {:?}, op: {:?}", - TypeId::of::(), - split_op - ); - - enum SeqOrKey<'inner> { - Seq(usize), - Key(&'inner String), - } - - chain.split(|mut sb| -> Result { - let outputs: HashMap<&NextOperation, DynOutput> = split_op - .sequential - .iter() - .enumerate() - .map(|(i, op_id)| (SeqOrKey::Seq(i), op_id)) - .chain( - split_op - .keyed - .iter() - .map(|(k, op_id)| (SeqOrKey::Key(k), op_id)), - ) - .map( - |(ki, op_id)| -> Result<(&NextOperation, DynOutput), DiagramError> { - match ki { - SeqOrKey::Seq(i) => Ok((op_id, sb.sequential_output(i)?.into())), - SeqOrKey::Key(k) => Ok((op_id, sb.specific_output(k.clone())?.into())), - } - }, - ) - .collect::, _>>()?; - let split_outputs = DynSplitOutputs { - outputs, - remaining: sb.remaining_output()?.into(), - }; - debug!("splitted outputs: {:?}", split_outputs); - Ok(split_outputs) - }) -} - -pub trait DynSplit { - const SUPPORTED: bool; - - fn dyn_split<'a>( +pub trait RegisterSplit { + fn perform_split( + split_op: &SplitSchema, builder: 
&mut Builder, - output: DynOutput, - split_op: &'a SplitOp, - ) -> Result, DiagramError>; + ) -> Result; fn on_register(registry: &mut MessageRegistry); } -impl DynSplit for NotSupported { - const SUPPORTED: bool = false; - - fn dyn_split<'a>( - _builder: &mut Builder, - _output: DynOutput, - _split_op: &'a SplitOp, - ) -> Result, DiagramError> { - Err(DiagramError::NotSplittable) - } - - fn on_register(_registry: &mut MessageRegistry) {} -} - -impl DynSplit for DefaultImpl +impl RegisterSplit for Supported<(T, Serializer, Cloneable)> where T: Send + Sync + 'static + Splittable, T::Key: FromSequential + FromSpecific + ForRemaining, Serializer: SerializeMessage + SerializeMessage>, + Cloneable: PerformForkClone + PerformForkClone>, { - const SUPPORTED: bool = true; - - fn dyn_split<'a>( + fn perform_split( + split_op: &SplitSchema, builder: &mut Builder, - output: DynOutput, - split_op: &'a SplitOp, - ) -> Result, DiagramError> { - let chain = output.into_output::()?.chain(builder); - split_chain(chain, split_op) + ) -> Result { + let (input, split) = builder.create_split::(); + let mut outputs = Vec::new(); + let mut split = split.build(builder); + for (key, target) in &split_op.keyed { + outputs.push((target.clone(), split.specific_output(key.clone())?.into())); + } + + for (i, target) in split_op.sequential.iter().enumerate() { + outputs.push((target.clone(), split.sequential_output(i)?.into())) + } + + if let Some(remaining_target) = &split_op.remaining { + outputs.push((remaining_target.clone(), split.remaining_output()?.into())); + } + + Ok(DynSplit { + input: input.into(), + outputs, + }) } fn on_register(registry: &mut MessageRegistry) { + let ops = &mut registry + .messages + .entry(TypeInfo::of::()) + .or_insert(MessageRegistration::new::()) + .operations; + + ops.split_impl = Some(Self::perform_split); + registry.register_serialize::(); - register_join_impl::(registry); + registry.register_fork_clone::(); + registry.register_serialize::, Serializer>(); + 
registry.register_fork_clone::, Cloneable>(); } } @@ -415,6 +395,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 1); } @@ -454,6 +435,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); } @@ -497,6 +479,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); } @@ -537,6 +520,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); // "a" is "eaten" up by the keyed path, so we should be the result of "b". assert_eq!(result[1], 2); } @@ -578,6 +562,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); } @@ -617,6 +602,7 @@ mod tests { serde_json::to_value(HashMap::from([("test".to_string(), 1)])).unwrap(), ) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 1); } } diff --git a/src/diagram/supported.rs b/src/diagram/supported.rs new file mode 100644 index 00000000..344d27d1 --- /dev/null +++ b/src/diagram/supported.rs @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +/// A struct to provide the default implementation for various operations. +pub struct Supported { + _ignore: std::marker::PhantomData, +} + +impl Supported { + pub fn new() -> Self { + Self { + _ignore: Default::default(), + } + } +} + +/// A struct to provide "not supported" implementations for various operations. +pub struct NotSupported; diff --git a/src/diagram/testing.rs b/src/diagram/testing.rs index 0e170d36..975f4f4b 100644 --- a/src/diagram/testing.rs +++ b/src/diagram/testing.rs @@ -4,13 +4,10 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use crate::{ - testing::TestingContext, Builder, RequestExt, RunCommandsOnWorldExt, Service, StreamPack, + testing::TestingContext, Builder, JsonMessage, RequestExt, RunCommandsOnWorldExt, Service, }; -use super::{ - Diagram, DiagramElementRegistry, DiagramError, DiagramStart, DiagramTerminate, - NodeBuilderOptions, -}; +use super::{Diagram, DiagramElementRegistry, DiagramError, NodeBuilderOptions}; pub(super) struct DiagramTestFixture { pub(super) context: TestingContext, @@ -25,24 +22,17 @@ impl DiagramTestFixture { } } - pub(super) fn spawn_workflow( + /// Equivalent to `self.spawn_workflow::(diagram)` + pub(super) fn spawn_json_io_workflow( &mut self, diagram: &Diagram, - ) -> Result, DiagramError> { + ) -> Result, DiagramError> { self.context .app .world .command(|cmds| diagram.spawn_workflow(cmds, &self.registry)) } - /// Equivalent to `self.spawn_workflow::<()>(diagram)` - pub(super) fn spawn_io_workflow( - &mut self, - diagram: &Diagram, - ) -> Result, DiagramError> { - self.spawn_workflow::<()>(diagram) - } - /// Spawns a workflow from a diagram then run the workflow until completion. /// Returns the result of the workflow. 
pub(super) fn spawn_and_run( @@ -50,7 +40,7 @@ impl DiagramTestFixture { diagram: &Diagram, request: serde_json::Value, ) -> Result> { - let workflow = self.spawn_workflow::<()>(diagram)?; + let workflow = self.spawn_json_io_workflow(diagram)?; let mut promise = self .context .command(|cmds| cmds.request(request, workflow).take_response()); @@ -97,13 +87,10 @@ fn opaque_response(_: i64) -> Unserializable { /// create a new node registry with some basic nodes registered fn new_registry_with_basic_nodes() -> DiagramElementRegistry { let mut registry = DiagramElementRegistry::new(); - registry - .opt_out() - .no_response_cloning() - .register_node_builder( - NodeBuilderOptions::new("multiply3_uncloneable"), - |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3_uncloneable), - ); + registry.opt_out().no_cloning().register_node_builder( + NodeBuilderOptions::new("multiply3_uncloneable"), + |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3_uncloneable), + ); registry.register_node_builder( NodeBuilderOptions::new("multiply3"), |builder: &mut Builder, _config: ()| builder.create_map_block(multiply3), @@ -122,24 +109,27 @@ fn new_registry_with_basic_nodes() -> DiagramElementRegistry { registry .opt_out() - .no_request_deserializing() - .no_response_serializing() - .no_response_cloning() + .no_deserializing() + .no_serializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque"), |builder: &mut Builder, _config: ()| builder.create_map_block(opaque), ); registry .opt_out() - .no_request_deserializing() + .no_deserializing() + .no_serializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque_request"), |builder: &mut Builder, _config: ()| builder.create_map_block(opaque_request), ); registry .opt_out() - .no_response_serializing() - .no_response_cloning() + .no_deserializing() + .no_serializing() + .no_cloning() .register_node_builder( NodeBuilderOptions::new("opaque_response"), |builder: 
&mut Builder, _config: ()| builder.create_map_block(opaque_response), diff --git a/src/diagram/transform.rs b/src/diagram/transform_schema.rs similarity index 55% rename from src/diagram/transform.rs rename to src/diagram/transform_schema.rs index c98d5971..32669d3b 100644 --- a/src/diagram/transform.rs +++ b/src/diagram/transform_schema.rs @@ -1,14 +1,33 @@ -use std::{any::TypeId, error::Error}; +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use std::error::Error; use cel_interpreter::{Context, ExecutionError, ParseError, Program}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tracing::debug; -use crate::{Builder, Output}; +use crate::{Builder, JsonMessage}; -use super::{DiagramElementRegistry, DiagramError, DynOutput, NextOperation}; +use super::{ + BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, NextOperation, + OperationId, +}; #[derive(Error, Debug)] pub enum TransformError { @@ -22,50 +41,62 @@ pub enum TransformError { Other(#[from] Box), } -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] -pub struct TransformOp { +pub struct TransformSchema { pub(super) cel: String, pub(super) next: NextOperation, + /// Specify what happens if an error occurs during the transformation. 
If + /// you specify a target for on_error, then an error message will be sent to + /// that target. You can set this to `{ "builtin": "dispose" }` to simply + /// ignore errors. + /// + /// If left unspecified, a failure will be treated like an implicit operation + /// failure and behave according to `on_implicit_error`. + #[serde(default)] + pub(super) on_error: Option, } -pub(super) fn transform_output( - builder: &mut Builder, - registry: &DiagramElementRegistry, - output: DynOutput, - transform_op: &TransformOp, -) -> Result, DiagramError> { - debug!("transform output: {:?}, op: {:?}", output, transform_op); - - let json_output = if output.type_id == TypeId::of::() { - output.into_output() - } else { - registry.messages.serialize(builder, output) - }?; - - let program = Program::compile(&transform_op.cel).map_err(|err| TransformError::Parse(err))?; - let transform_node = builder.create_map_block( - move |req: serde_json::Value| -> Result { - let mut context = Context::default(); - context - .add_variable("request", req) - // cannot keep the original error because it is not Send + Sync - .map_err(|err| TransformError::Other(err.to_string().into()))?; - program - .execute(&context)? 
- .json() - // cel_interpreter::json is private so we have to type erase ConvertToJsonError - .map_err(|err| TransformError::Other(err.to_string().into())) - }, - ); - builder.connect(json_output, transform_node.input); - let transformed_output = transform_node - .output - .chain(builder) - .cancel_on_err() - .output(); - debug!("transformed output: {:?}", transformed_output); - Ok(transformed_output) +impl BuildDiagramOperation for TransformSchema { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + let program = Program::compile(&self.cel).map_err(TransformError::Parse)?; + let node = builder.create_map_block( + move |req: JsonMessage| -> Result { + let mut context = Context::default(); + context + .add_variable("request", req) + // cannot keep the original error because it is not Send + Sync + .map_err(|err| TransformError::Other(err.to_string().into()))?; + program + .execute(&context)? + .json() + // cel_interpreter::json is private so we have to type erase ConvertToJsonError + .map_err(|err| TransformError::Other(err.to_string().into())) + }, + ); + + let error_target = self.on_error.clone().unwrap_or( + // If no error target was explicitly given then treat this as an + // implicit error. 
+ ctx.get_implicit_error_target(), + ); + + let (ok, _) = node.output.chain(builder).fork_result( + |ok| ok.output(), + |err| { + ctx.add_output_into_target(error_target.clone(), err.output().into()); + }, + ); + + ctx.set_input_for_target(id, node.input.into())?; + ctx.add_output_into_target(self.next.clone(), ok.into()); + Ok(BuildStatus::Finished) + } } #[cfg(test)] @@ -100,6 +131,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); } @@ -123,6 +155,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); } @@ -146,6 +179,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 12); } @@ -169,6 +203,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result["request"], 4); assert_eq!(result["seven"], 7); } @@ -196,6 +231,7 @@ mod tests { }); let result = fixture.spawn_and_run(&diagram, request).unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 40); } } diff --git a/src/diagram/type_info.rs b/src/diagram/type_info.rs new file mode 100644 index 00000000..611afb42 --- /dev/null +++ b/src/diagram/type_info.rs @@ -0,0 +1,52 @@ +use std::{ + any::{type_name, Any, TypeId}, + fmt::Display, + hash::Hash, +}; + +use serde::Serialize; + +#[derive(Copy, Clone, Debug, Eq)] +pub struct TypeInfo { + pub type_id: TypeId, + pub type_name: &'static str, +} + +impl TypeInfo { + pub(super) fn of() -> Self + where + T: Any, + { + Self { + type_id: TypeId::of::(), + type_name: type_name::(), + } + } +} + +impl Hash for TypeInfo { + fn hash(&self, state: &mut H) { + self.type_id.hash(state) + } 
+} + +impl PartialEq for TypeInfo { + fn eq(&self, other: &Self) -> bool { + self.type_id == other.type_id + } +} + +impl Display for TypeInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.type_name.fmt(f) + } +} + +impl Serialize for TypeInfo { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.type_name) + } +} diff --git a/src/diagram/unzip.rs b/src/diagram/unzip_schema.rs similarity index 60% rename from src/diagram/unzip.rs rename to src/diagram/unzip_schema.rs index 1afd11e7..b84828c0 100644 --- a/src/diagram/unzip.rs +++ b/src/diagram/unzip_schema.rs @@ -1,81 +1,115 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ + use bevy_utils::all_tuples_with_size; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use tracing::debug; use crate::Builder; use super::{ - impls::{DefaultImplMarker, NotSupportedMarker}, - join::register_join_impl, - DiagramError, DynOutput, MessageRegistry, NextOperation, SerializeMessage, + supported::*, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, + DynInputSlot, DynOutput, MessageRegistry, NextOperation, OperationId, PerformForkClone, + SerializeMessage, TypeInfo, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] -pub struct UnzipOp { +pub struct UnzipSchema { pub(super) next: Vec, } -pub trait DynUnzip { - /// Returns a list of type names that this message unzips to. - fn output_types(&self) -> Vec<&'static str>; - - fn dyn_unzip( +impl BuildDiagramOperation for UnzipSchema { + fn build_diagram_operation( &self, + id: &OperationId, builder: &mut Builder, - output: DynOutput, - ) -> Result, DiagramError>; + ctx: &mut DiagramContext, + ) -> Result { + let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + // There are no outputs ready for this target, so we can't do + // anything yet. The builder should try again later. + return Ok(BuildStatus::defer("waiting for an input")); + }; + + let unzip = ctx.registry.messages.unzip(inferred_type)?; + let actual_output = unzip.output_types(); + if actual_output.len() != self.next.len() { + return Err(DiagramErrorCode::UnzipMismatch { + expected: self.next.len(), + actual: unzip.output_types().len(), + elements: actual_output, + }); + } - /// Called when a node is registered. 
- fn on_register(&self, registry: &mut MessageRegistry); -} + let unzip = unzip.perform_unzip(builder)?; -impl DynUnzip for NotSupportedMarker { - fn output_types(&self) -> Vec<&'static str> { - Vec::new() + ctx.set_input_for_target(id, unzip.input)?; + for (target, output) in self.next.iter().zip(unzip.outputs) { + ctx.add_output_into_target(target.clone(), output); + } + Ok(BuildStatus::Finished) } +} - fn dyn_unzip( - &self, - _builder: &mut Builder, - _output: DynOutput, - ) -> Result, DiagramError> { - Err(DiagramError::NotUnzippable) - } +pub struct DynUnzip { + input: DynInputSlot, + outputs: Vec, +} - fn on_register(&self, _registry: &mut MessageRegistry) {} +pub trait PerformUnzip { + /// Returns a list of type names that this message unzips to. + fn output_types(&self) -> Vec; + + fn perform_unzip(&self, builder: &mut Builder) -> Result; + + /// Called when a node is registered. + fn on_register(&self, registry: &mut MessageRegistry); } macro_rules! dyn_unzip_impl { ($len:literal, $(($P:ident, $o:ident)),*) => { - impl<$($P),*, Serializer> DynUnzip for DefaultImplMarker<(($($P,)*), Serializer)> + impl<$($P),*, Serializer, Cloneable> PerformUnzip for Supported<(($($P,)*), Serializer, Cloneable)> where $($P: Send + Sync + 'static),*, Serializer: $(SerializeMessage<$P> +)* $(SerializeMessage> +)*, + Cloneable: $(PerformForkClone<$P> +)* $(PerformForkClone> +)*, { - fn output_types(&self) -> Vec<&'static str> { + fn output_types(&self) -> Vec { vec![$( - std::any::type_name::<$P>(), + TypeInfo::of::<$P>(), )*] } - fn dyn_unzip( + fn perform_unzip( &self, builder: &mut Builder, - output: DynOutput - ) -> Result, DiagramError> { - debug!("unzip output: {:?}", output); - let mut outputs: Vec = Vec::with_capacity($len); - let chain = output.into_output::<($($P,)*)>()?.chain(builder); - let ($($o,)*) = chain.unzip(); + ) -> Result { + let (input, ($($o,)*)) = builder.create_unzip::<($($P,)*)>(); + let mut outputs: Vec = Vec::with_capacity($len); $({ 
outputs.push($o.into()); })* - debug!("unzipped outputs: {:?}", outputs); - Ok(outputs) + Ok(DynUnzip { + input: input.into(), + outputs, + }) } fn on_register(&self, registry: &mut MessageRegistry) @@ -84,11 +118,7 @@ macro_rules! dyn_unzip_impl { // For a tuple of (T1, T2, T3), registers serialize for T1, T2 and T3. $( registry.register_serialize::<$P, Serializer>(); - )* - - // Register join impls for T1, T2, T3... - $( - register_join_impl::<$P, Serializer>(registry); + registry.register_fork_clone::<$P, Cloneable>(); )* } } @@ -102,7 +132,7 @@ mod tests { use serde_json::json; use test_log::test; - use crate::{diagram::testing::DiagramTestFixture, Diagram, DiagramError}; + use crate::{diagram::testing::DiagramTestFixture, Diagram, DiagramErrorCode}; #[test] fn test_unzip_not_unzippable() { @@ -125,8 +155,12 @@ mod tests { })) .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::NotUnzippable), "{}", err); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!( + matches!(err.code, DiagramErrorCode::NotUnzippable), + "{}", + err + ); } #[test] @@ -165,8 +199,15 @@ mod tests { })) .unwrap(); - let err = fixture.spawn_io_workflow(&diagram).unwrap_err(); - assert!(matches!(err, DiagramError::NotUnzippable)); + let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + assert!(matches!( + err.code, + DiagramErrorCode::UnzipMismatch { + expected: 3, + actual: 2, + .. 
+ } + )); } #[test] @@ -193,6 +234,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 20); } @@ -211,7 +253,10 @@ mod tests { }, "unzip": { "type": "unzip", - "next": ["op2"], + "next": [ + "op2", + { "builtin": "dispose" }, + ], }, "op2": { "type": "node", @@ -225,6 +270,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 36); } @@ -243,10 +289,7 @@ mod tests { }, "unzip": { "type": "unzip", - "next": ["dispose", "op2"], - }, - "dispose": { - "type": "dispose", + "next": [{ "builtin": "dispose" }, "op2"], }, "op2": { "type": "node", @@ -260,6 +303,7 @@ mod tests { let result = fixture .spawn_and_run(&diagram, serde_json::Value::from(4)) .unwrap(); + assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 60); } } diff --git a/src/diagram/workflow_builder.rs b/src/diagram/workflow_builder.rs index c1156be9..49c16b60 100644 --- a/src/diagram/workflow_builder.rs +++ b/src/diagram/workflow_builder.rs @@ -1,488 +1,740 @@ -use std::{any::TypeId, collections::HashMap}; +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * +*/ -use tracing::{debug, warn}; - -use crate::{ - diagram::join::serialize_and_join, unknown_diagram_error, Builder, InputSlot, Output, - StreamPack, +use std::{ + borrow::Cow, + collections::{hash_map::Entry, HashMap}, }; +use crate::{AnyBuffer, BufferIdentifier, BufferMap, Builder, JsonMessage, Scope, StreamPack}; + use super::{ - fork_clone::DynForkClone, impls::DefaultImpl, split_chain, transform::transform_output, - BuiltinTarget, Diagram, DiagramElementRegistry, DiagramError, DiagramOperation, DiagramScope, - DynInputSlot, DynOutput, NextOperation, NodeOp, OperationId, SourceOperation, + BufferInputs, BuiltinTarget, Diagram, DiagramElementRegistry, DiagramError, DiagramErrorCode, + DiagramOperation, DynInputSlot, DynOutput, ImplicitDeserialization, ImplicitSerialization, + ImplicitStringify, NextOperation, OperationId, TypeInfo, }; -struct Vertex<'a> { - op_id: &'a OperationId, - op: &'a DiagramOperation, - in_edges: Vec, - out_edges: Vec, -} - -struct Edge<'a> { - source: SourceOperation, - target: &'a NextOperation, - state: EdgeState<'a>, -} - -enum EdgeState<'a> { - Ready { - output: DynOutput, - /// The node that initially produces the output, may be `None` if there is no origin. - /// e.g. The entry point, or if the output passes through a `join` operation which - /// results in multiple origins. 
- origin: Option<&'a NodeOp>, - }, - Pending, +#[derive(Default)] +struct DiagramConstruction { + connect_into_target: HashMap>, + // We use a separate hashmap for OperationId vs BuiltinTarget so we can + // efficiently fetch with an &OperationId + outputs_to_operation_target: HashMap>, + outputs_to_builtin_target: HashMap>, + buffers: HashMap, } -pub(super) fn create_workflow<'a, Streams: StreamPack>( - scope: DiagramScope, - builder: &mut Builder, - registry: &DiagramElementRegistry, - diagram: &'a Diagram, -) -> Result<(), DiagramError> { - // first create all the vertices - let mut vertices: HashMap<&OperationId, Vertex> = diagram - .ops - .iter() - .map(|(op_id, op)| { - ( - op_id, - Vertex { - op_id, - op, - in_edges: Vec::new(), - out_edges: Vec::new(), - }, - ) - }) - .collect(); - - // init with some capacity to reduce resizing. HashMap for faster removal. - // NOTE: There are many `unknown_diagram_errors!()` used when accessing this. - // In theory these accesses should never fail because the keys come from - // `vertices` which are built using the same data as `edges`. But we do modify - // `edges` while we are building the workflow so if an unknown error occurs, it is - // likely due to some logic issue in the algorithm. - let mut edges: HashMap = HashMap::with_capacity(diagram.ops.len() * 2); - - // process start separately because we need to consume the scope input - match &diagram.start { - NextOperation::Builtin { builtin } => match builtin { - BuiltinTarget::Terminate => { - // such a workflow is equivalent to an no-op. - builder.connect(scope.input, scope.terminate); - return Ok(()); +impl DiagramConstruction { + fn is_finished(&self) -> bool { + for outputs in self.outputs_to_builtin_target.values() { + if !outputs.is_empty() { + return false; } - BuiltinTarget::Dispose => { - // bevy_impulse will immediate stop with an `CancellationCause::Unreachable` error - // if trying to run such a workflow. 
- return Ok(()); + } + + for outputs in self.outputs_to_operation_target.values() { + if !outputs.is_empty() { + return false; } - }, - NextOperation::Target(op_id) => { - edges.insert( - edges.len(), - Edge { - source: SourceOperation::Builtin { - builtin: super::BuiltinSource::Start, - }, - target: &diagram.start, - state: EdgeState::Ready { - output: scope.input.into(), - origin: None, - }, - }, - ); - vertices - .get_mut(&op_id) - .ok_or(DiagramError::OperationNotFound(op_id.clone()))? - .in_edges - .push(0); } - }; - let mut inputs: HashMap<&OperationId, DynInputSlot> = HashMap::with_capacity(diagram.ops.len()); + return true; + } +} - let mut terminate_edges: Vec = Vec::new(); +pub struct DiagramContext<'a> { + construction: &'a mut DiagramConstruction, + pub diagram: &'a Diagram, + pub registry: &'a DiagramElementRegistry, +} - let mut add_edge = |source: SourceOperation, - target: &'a NextOperation, - state: EdgeState<'a>| - -> Result<(), DiagramError> { - let source_id = if let SourceOperation::Source(source) = &source { - Some(source.clone()) - } else { - None - }; +impl<'a> DiagramContext<'a> { + /// Get all the currently known outputs that are aimed at this target operation. + /// + /// During the [`BuildDiagramOperation`] phase this will eventually contain + /// all outputs targeting this operation that are explicitly listed in the + /// diagram. It will never contain outputs that implicitly target this + /// operation. + /// + /// During the [`ConnectIntoTarget`] phase this will not contain any outputs + /// except new outputs added during the current call to [`ConnectIntoTarget`], + /// so this function is generally not useful during that phase. 
+ pub fn get_outputs_into_operation_target(&self, id: &OperationId) -> Option<&Vec> { + self.construction.outputs_to_operation_target.get(id) + } - edges.insert( - edges.len(), - Edge { - source, - target, - state, - }, - ); - let new_edge_id = edges.len() - 1; - - if let Some(source_id) = source_id { - let source_vertex = vertices - .get_mut(&source_id) - .ok_or_else(|| DiagramError::OperationNotFound(source_id.clone()))?; - source_vertex.out_edges.push(new_edge_id); - } + /// Infer the [`TypeInfo`] for the input messages into the specified operation. + /// + /// If this returns [`None`] then not enough of the diagram has been built + /// yet to infer the input type. In that case you can return something like + /// `Ok(BuildStatus::defer("waiting for an input"))`. + /// + /// The workflow builder will ensure that all outputs targeting this + /// operation are compatible with this message type. + /// + /// During the [`ConnectIntoTarget`] phase all information about outputs + /// going into this target will be drained, so this function is generally + /// not useful during that phase. If you need to retain this information + /// during the [`ConnectIntoTarget`] phase then you should capture the + /// [`TypeInfo`] that you receive from this function during the + /// [`BuildDiagramOperation`] phase. + pub fn infer_input_type_into_target(&self, id: &OperationId) -> Option<&TypeInfo> { + self.get_outputs_into_operation_target(id) + .and_then(|outputs| outputs.first()) + .map(|o| o.message_info()) + } + /// Add an output to connect into a target. + /// + /// This can be used during both the [`BuildDiagramOperation`] phase and the + /// [`ConnectIntoTarget`] phase. + /// + /// # Arguments + /// + /// * target - The operation that needs to receive the output + /// * output - The output channel that needs to be connected into the target. 
+ pub fn add_output_into_target(&mut self, target: NextOperation, output: DynOutput) { match target { - NextOperation::Target(target) => { - let target_vertex = vertices - .get_mut(target) - .ok_or_else(|| DiagramError::OperationNotFound(target.clone()))?; - target_vertex.in_edges.push(new_edge_id); + NextOperation::Target(id) => { + self.construction + .outputs_to_operation_target + .entry(id) + .or_default() + .push(output); + } + NextOperation::Builtin { builtin } => { + self.construction + .outputs_to_builtin_target + .entry(builtin) + .or_default() + .push(output); } - NextOperation::Builtin { builtin } => match builtin { - BuiltinTarget::Terminate => { - terminate_edges.push(new_edge_id); - } - BuiltinTarget::Dispose => {} - }, } + } + + /// Set the input slot of an operation. This should not be called more than + /// once per operation, because only one input slot can be used for any + /// operation. + /// + /// This should be used during the [`BuildDiagramOperation`] phase to set + /// the input slot of each operation with standard connection behavior. + /// Standard connection behavior means that implicit serialization or + /// deserialization will be applied to its inputs as needed to match the + /// [`DynInputSlot`] that you provide. + /// + /// If you need non-standard connection behavior then you can use + /// [`Self::set_connect_into_target`] to set the exact connection behavior. + /// No implicit behaviors will be provided for [`Self::set_connect_into_target`], + /// but you can enable those behaviors using [`ImplicitSerialization`] and + /// [`ImplicitDeserialization`]. + /// + /// This should never be used during the [`ConnectIntoTarget`] phase because + /// all connection behaviors must already be set by then. 
+ pub fn set_input_for_target( + &mut self, + operation: &OperationId, + input: DynInputSlot, + ) -> Result<(), DiagramErrorCode> { + match self + .construction + .connect_into_target + .entry(NextOperation::Target(operation.clone())) + { + Entry::Occupied(_) => { + return Err(DiagramErrorCode::MultipleInputsCreated(operation.clone())); + } + Entry::Vacant(vacant) => { + vacant.insert(standard_input_connection(input, &self.registry)?); + } + } + Ok(()) - }; - - // create all the edges - for (op_id, op) in &diagram.ops { - match op { - DiagramOperation::Node(node_op) => { - let reg = registry.get_node_registration(&node_op.builder)?; - let n = reg.create_node(builder, node_op.config.clone())?; - inputs.insert(op_id, n.input); - add_edge( - op_id.clone().into(), - &node_op.next, - EdgeState::Ready { - output: n.output.into(), - origin: Some(node_op), - }, - )?; + } + + /// Set the implementation for how outputs connect into this target. This is + /// a more general method than [`Self::set_input_for_target`]. + /// + /// There will be no additional connection behavior beyond what you specify + /// with the object passed in for `connect`. That means we will not + /// automatically add implicit serialization or deserialization when you use + /// this method. If you want your connection behavior to have those implicit + /// operations you can use [`ImplicitSerialization`] and + /// [`ImplicitDeserialization`] inside your `connect` implementation. + /// + /// This should never be necessary to use this during the [`ConnectIntoTarget`] + /// phase because all connection behaviors should already be set by then. 
+ pub fn set_connect_into_target( + &mut self, + operation: &OperationId, + connect: C, + ) -> Result<(), DiagramErrorCode> { + let connect = Box::new(connect); + match self + .construction + .connect_into_target + .entry(NextOperation::Target(operation.clone())) + { + Entry::Occupied(_) => { + return Err(DiagramErrorCode::MultipleInputsCreated(operation.clone())); } - DiagramOperation::ForkClone(fork_clone_op) => { - for next_op_id in fork_clone_op.next.iter() { - add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; - } + Entry::Vacant(vacant) => { + vacant.insert(connect); } - DiagramOperation::Unzip(unzip_op) => { - for next_op_id in unzip_op.next.iter() { - add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; - } + } + Ok(()) + } + + /// Same as [`Self::set_connect_into_target`] but you can pass in a closure. + /// + /// This is equivalent to doing + /// `set_connect_into_target(operation, ConnectionCallback(connect))`. + pub fn set_connect_into_target_callback( + &mut self, + operation: &OperationId, + connect: F, + ) -> Result<(), DiagramErrorCode> + where + F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode> + + 'static, + { + self.set_connect_into_target(operation, ConnectionCallback(connect)) + } + + /// Set the buffer that should be used for a certain operation. This will + /// also set its connection callback. 
+ pub fn set_buffer_for_operation( + &mut self, + operation: &OperationId, + buffer: AnyBuffer, + ) -> Result<(), DiagramErrorCode> { + match self.construction.buffers.entry(operation.clone()) { + Entry::Occupied(_) => { + return Err(DiagramErrorCode::MultipleBuffersCreated(operation.clone())); } - DiagramOperation::ForkResult(fork_result_op) => { - add_edge(op_id.clone().into(), &fork_result_op.ok, EdgeState::Pending)?; - add_edge( - op_id.clone().into(), - &fork_result_op.err, - EdgeState::Pending, - )?; + Entry::Vacant(vacant) => { + vacant.insert(buffer); } - DiagramOperation::Split(split_op) => { - let next_op_ids: Vec<&NextOperation> = split_op - .sequential - .iter() - .chain(split_op.keyed.values()) - .collect(); - for next_op_id in next_op_ids { - add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?; - } - if let Some(remaining) = &split_op.remaining { - add_edge(op_id.clone().into(), &remaining, EdgeState::Pending)?; - } + } + + let input: DynInputSlot = buffer.into(); + self.set_input_for_target(operation, input) + } + + /// Create a buffer map based on the buffer inputs provided. If one or more + /// of the buffers in BufferInputs is not available, get an error including + /// the name of the missing buffer. 
+ pub fn create_buffer_map(&self, inputs: &BufferInputs) -> Result { + let attempt_get_buffer = |name: &String| -> Result { + self.construction + .buffers + .get(name) + .copied() + .ok_or_else(|| format!("cannot find buffer named [{name}]")) + }; + + match inputs { + BufferInputs::Single(op_id) => { + let mut buffer_map = BufferMap::with_capacity(1); + buffer_map.insert(BufferIdentifier::Index(0), attempt_get_buffer(op_id)?); + Ok(buffer_map) } - DiagramOperation::Join(join_op) => { - add_edge(op_id.clone().into(), &join_op.next, EdgeState::Pending)?; + BufferInputs::Dict(mapping) => { + let mut buffer_map = BufferMap::with_capacity(mapping.len()); + for (k, op_id) in mapping { + buffer_map.insert( + BufferIdentifier::Name(k.clone().into()), + attempt_get_buffer(op_id)?, + ); + } + Ok(buffer_map) } - DiagramOperation::Transform(transform_op) => { - add_edge(op_id.clone().into(), &transform_op.next, EdgeState::Pending)?; + BufferInputs::Array(arr) => { + let mut buffer_map = BufferMap::with_capacity(arr.len()); + for (i, op_id) in arr.into_iter().enumerate() { + buffer_map.insert(BufferIdentifier::Index(i), attempt_get_buffer(op_id)?); + } + Ok(buffer_map) } - DiagramOperation::Dispose => {} } } - let mut unconnected_vertices: Vec<&Vertex> = vertices.values().collect(); - while unconnected_vertices.len() > 0 { - let ws = unconnected_vertices.clone(); - let ws_length = ws.len(); - unconnected_vertices.clear(); - - for v in ws { - let in_edges: Vec<&Edge> = v.in_edges.iter().map(|idx| &edges[idx]).collect(); - if in_edges - .iter() - .any(|e| matches!(e.state, EdgeState::Pending)) - { - // not all inputs are ready - debug!( - "defer connecting [{}] until all incoming edges are ready", - v.op_id - ); - unconnected_vertices.push(v); - continue; + /// Get the type information for the request message that goes into a node. 
+ /// + /// # Arguments + /// + /// * `target` - Optionally indicate a specific node in the diagram to treat + /// as the target node, even if it is not the actual target. Using [`Some`] + /// for this will override whatever is used for `next`. + /// * `next` - Indicate the next operation, i.e. the true target. + pub fn get_node_request_type( + &self, + target: Option<&OperationId>, + next: &NextOperation, + ) -> Result { + let target_node = if let Some(target) = target { + self.diagram.get_op(target)? + } else { + match next { + NextOperation::Target(op_id) => self.diagram.get_op(op_id)?, + NextOperation::Builtin { builtin } => match builtin { + BuiltinTarget::Terminate => return Ok(TypeInfo::of::()), + BuiltinTarget::Dispose => return Err(DiagramErrorCode::UnknownTarget), + BuiltinTarget::Cancel => return Ok(TypeInfo::of::()), + }, } + }; + let node_op = match target_node { + DiagramOperation::Node(op) => op, + _ => return Err(DiagramErrorCode::UnknownTarget), + }; + let target_type = self + .registry + .get_node_registration(&node_op.builder)? + .request; + Ok(target_type) + } - connect_vertex(builder, registry, &mut edges, &inputs, v)?; - } + pub fn get_implicit_error_target(&self) -> NextOperation { + self.diagram + .on_implicit_error + .clone() + .unwrap_or(NextOperation::Builtin { + builtin: BuiltinTarget::Cancel, + }) + } +} + +/// Indicate whether the operation has finished building. +#[derive(Debug, Clone)] +pub enum BuildStatus { + /// The operation has finished building. + Finished, + /// The operation needs to make another attempt at building after more + /// information becomes available. + Defer { + /// Progress was made during this run. This can mean the operation has + /// added some information into the diagram but might need to provide + /// more later. If no operations make any progress within an entire + /// iteration of the workflow build then we assume it is impossible to + /// build the diagram. 
+ progress: bool, + reason: Cow<'static, str>, + }, +} - // can't connect anything and there are still remaining vertices - if unconnected_vertices.len() > 0 && ws_length == unconnected_vertices.len() { - warn!( - "the following operations are not connected {:?}", - unconnected_vertices - .iter() - .map(|v| v.op_id) - .collect::>() - ); - return Err(DiagramError::BadInterconnectChain); +impl BuildStatus { + /// Indicate that the build of the operation needs to be deferred. + pub fn defer(reason: impl Into>) -> Self { + Self::Defer { + progress: false, + reason: reason.into(), } } - // connect terminate - for edge_id in terminate_edges { - let edge = edges.remove(&edge_id).ok_or(unknown_diagram_error!())?; - match edge.state { - EdgeState::Ready { output, origin: _ } => { - let serialized_output = registry.messages.serialize(builder, output)?; - builder.connect(serialized_output, scope.terminate); + /// Indicate that the operation made progress, even if it's deferred. + #[allow(unused)] + pub fn with_progress(mut self) -> Self { + match &mut self { + Self::Defer { progress, .. } => { + *progress = true; + } + Self::Finished => { + // Do nothing } - EdgeState::Pending => return Err(DiagramError::BadInterconnectChain), } + + self } - Ok(()) + /// Did this operation make progress on this round? + pub fn made_progress(&self) -> bool { + match self { + Self::Defer { progress, .. } => *progress, + Self::Finished => true, + } + } + + /// Change this build status into its reason for deferral, if it is a + /// deferral. + pub fn into_deferral_reason(self) -> Option> { + match self { + Self::Defer { reason, .. } => Some(reason), + Self::Finished => None, + } + } + + /// Check if the build finished. 
+ pub fn is_finished(&self) -> bool { + matches!(self, BuildStatus::Finished) + } } -fn connect_vertex<'a>( - builder: &mut Builder, - registry: &DiagramElementRegistry, - edges: &mut HashMap>, - inputs: &HashMap<&OperationId, DynInputSlot>, - target: &'a Vertex, -) -> Result<(), DiagramError> { - debug!("connecting [{}]", target.op_id); - match target.op { - // join needs all incoming edges to be connected at once so it is done at the vertex level - // instead of per edge level. - DiagramOperation::Join(join_op) => { - if target.in_edges.is_empty() { - return Err(DiagramError::EmptyJoin); - } - let mut outputs: HashMap = target - .in_edges - .iter() - .map(|e| { - let edge = edges.remove(e).ok_or(unknown_diagram_error!())?; - match edge.state { - EdgeState::Ready { output, origin: _ } => Ok((edge.source, output)), - // "expected all incoming edges to be ready" - _ => Err(unknown_diagram_error!()), - } - }) - .collect::, _>>()?; - - let mut ordered_outputs: Vec = Vec::with_capacity(target.in_edges.len()); - for source_id in join_op.inputs.iter() { - let o = outputs - .remove(source_id) - .ok_or(DiagramError::OperationNotFound(source_id.to_string()))?; - ordered_outputs.push(o); - } +/// This trait is used to instantiate operations in the workflow. This trait +/// will be called on each operation in the diagram until it finishes building. +/// Each operation should use this to provide a [`ConnectOutput`] handle for +/// itself (if relevant) and deposit [`DynOutput`]s into [`DiagramContext`]. +/// +/// After all operations are fully built, [`ConnectIntoTarget`] will be used to +/// connect outputs into their target operations. +pub trait BuildDiagramOperation { + fn build_diagram_operation( + &self, + id: &OperationId, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result; +} - let joined_output = if join_op.no_serialize.unwrap_or(false) { - registry.messages.join(builder, ordered_outputs)? 
- } else { - serialize_and_join(builder, ®istry.messages, ordered_outputs)?.into() - }; +/// This trait is used to connect outputs to their target operations. This trait +/// will be called for each output produced by [`BuildDiagramOperation`]. +/// +/// You are allowed to generate new outputs during the [`ConnectIntoTarget`] +/// phase by calling [`DiagramContext::add_outputs_into_target`]. +/// +/// However you cannot add new [`ConnectIntoTarget`] instances for operations. +/// Any use of [`DiagramContext::set_input_for_target`], +/// [`DiagramContext::set_connect_into_target`], or +/// [`DiagramContext::set_connect_into_target_callback`] will be discarded. +pub trait ConnectIntoTarget { + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode>; +} - let out_edge = edges - .get_mut(&target.out_edges[0]) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { - output: joined_output, - origin: None, - }; - Ok(()) - } - // for other operations, each edge is independent, so we can connect at the edge level. 
- _ => { - for edge_id in target.in_edges.iter() { - connect_edge(builder, registry, edges, inputs, *edge_id, target)?; - } - Ok(()) - } +pub struct ConnectionCallback(pub F) +where + F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode>; + +impl ConnectIntoTarget for ConnectionCallback +where + F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode>, +{ + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + (self.0)(output, builder, ctx) } } -fn connect_edge<'a>( +pub(super) fn create_workflow( + scope: Scope, builder: &mut Builder, registry: &DiagramElementRegistry, - edges: &mut HashMap>, - inputs: &HashMap<&OperationId, DynInputSlot>, - edge_id: usize, - target: &Vertex, -) -> Result<(), DiagramError> { - let edge = edges.remove(&edge_id).ok_or(unknown_diagram_error!())?; - debug!( - "connect edge {:?}, source: {:?}, target: {:?}", - edge_id, edge.source, edge.target - ); - let (output, origin) = match edge.state { - EdgeState::Ready { - output, - origin: origin_node, - } => { - if let Some(origin_node) = origin_node { - (output, Some(origin_node)) - } else { - (output, None) + diagram: &Diagram, +) -> Result<(), DiagramError> +where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, + Streams: StreamPack, +{ + let mut construction = DiagramConstruction::default(); + + initialize_builtin_operations( + scope, + builder, + &mut DiagramContext { + construction: &mut construction, + diagram, + registry, + }, + )?; + + let mut unfinished_operations: Vec<&OperationId> = diagram.ops.keys().collect(); + let mut deferred_operations: Vec<(&OperationId, BuildStatus)> = Vec::new(); + + let mut iterations = 0; + const MAX_ITERATIONS: usize = 10_000; + + // Iteratively build all the operations in the diagram + while !unfinished_operations.is_empty() { + let mut made_progress = false; + for op in 
unfinished_operations.drain(..) { + let mut ctx = DiagramContext { + construction: &mut construction, + diagram, + registry, + }; + + // Attempt to build this operation + let status = diagram + .ops + .get(op) + .ok_or_else(|| { + DiagramErrorCode::UnknownOperation(NextOperation::Target(op.clone())) + })? + .build_diagram_operation(op, builder, &mut ctx) + .map_err(|code| DiagramError::in_operation(op.clone(), code))?; + + made_progress |= status.made_progress(); + if !status.is_finished() { + // The operation did not finish, so pass it into the deferred + // operations list. + deferred_operations.push((op, status)); } } - EdgeState::Pending => panic!("can only connect ready edges"), - }; - - match target.op { - DiagramOperation::Node(_) => { - let input = inputs[target.op_id]; - let deserialized_output = - registry - .messages - .deserialize(&input.type_id, builder, output)?; - dyn_connect(builder, deserialized_output, input)?; - } - DiagramOperation::ForkClone(fork_clone_op) => { - let amount = fork_clone_op.next.len(); - let outputs = if output.type_id == TypeId::of::() { - >::dyn_fork_clone( - builder, output, amount, - ) - } else { - registry.messages.fork_clone(builder, output, amount) - }?; - for (o, e) in outputs.into_iter().zip(target.out_edges.iter()) { - let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { output: o, origin }; + + if made_progress { + // Try another iteration if needed since we made progress last time + unfinished_operations = deferred_operations.drain(..).map(|(op, _)| op).collect(); + } else { + // No progress can be made any longer so return an error + return Err(DiagramErrorCode::BuildHalted { + reasons: deferred_operations + .drain(..) 
+ .filter_map(|(op, status)| { + status + .into_deferral_reason() + .map(|reason| (op.clone(), reason)) + }) + .collect(), } + .into()); } - DiagramOperation::Unzip(unzip_op) => { - let outputs = if output.type_id == TypeId::of::() { - Err(DiagramError::NotUnzippable) - } else { - registry.messages.unzip(builder, output) - }?; - if outputs.len() < unzip_op.next.len() { - return Err(DiagramError::NotUnzippable); - } - for (o, e) in outputs.into_iter().zip(target.out_edges.iter()) { - let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { output: o, origin }; - } + + iterations += 1; + if iterations > MAX_ITERATIONS { + return Err(DiagramErrorCode::ExcessiveIterations.into()); } - DiagramOperation::ForkResult(_) => { - let (ok, err) = if output.type_id == TypeId::of::() { - Err(DiagramError::CannotForkResult) - } else { - registry.messages.fork_result(builder, output) - }?; - { - let out_edge = edges - .get_mut(&target.out_edges[0]) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { output: ok, origin }; - } - { - let out_edge = edges - .get_mut(&target.out_edges[1]) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { - output: err, - origin, - }; + } + + let mut new_construction = DiagramConstruction::default(); + new_construction.buffers = construction.buffers.clone(); + + iterations = 0; + while !construction.is_finished() { + let mut ctx = DiagramContext { + construction: &mut new_construction, + diagram, + registry, + }; + + // Attempt to connect to all regular operations + for (op, outputs) in construction.outputs_to_operation_target.drain() { + let op = NextOperation::Target(op); + let connect = construction + .connect_into_target + .get_mut(&op) + .ok_or_else(|| DiagramErrorCode::UnknownOperation(op.clone()))?; + + for output in outputs { + connect.connect_into_target(output, builder, &mut ctx)?; } } - DiagramOperation::Split(split_op) => { - let mut outputs = if 
output.type_id == TypeId::of::() { - let chain = output.into_output::()?.chain(builder); - split_chain(chain, split_op) - } else { - registry.messages.split(builder, output, split_op) - }?; - - // Because of how we build `out_edges`, if the split op uses the `remaining` slot, - // then the last item will always be the remaining edge. - let remaining_edge_id = if split_op.remaining.is_some() { - Some(target.out_edges.last().ok_or(unknown_diagram_error!())?) - } else { - None - }; - let other_edge_ids = if split_op.remaining.is_some() { - &target.out_edges[..(target.out_edges.len() - 1)] - } else { - &target.out_edges[..] - }; - for e in other_edge_ids { - let out_edge = edges.get_mut(e).ok_or(unknown_diagram_error!())?; - let output = outputs - .outputs - .remove(out_edge.target) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { output, origin }; - } - if let Some(remaining_edge_id) = remaining_edge_id { - let out_edge = edges - .get_mut(remaining_edge_id) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { - output: outputs.remaining, - origin, - }; + // Attempt to connect to all builtin operations + for (builtin, outputs) in construction.outputs_to_builtin_target.drain() { + let op = NextOperation::Builtin { builtin }; + let connect = construction + .connect_into_target + .get_mut(&op) + .ok_or_else(|| DiagramErrorCode::UnknownOperation(op.clone()))?; + + for output in outputs { + connect.connect_into_target(output, builder, &mut ctx)?; } } - DiagramOperation::Join(_) => { - // join is connected at the vertex level - } - DiagramOperation::Transform(transform_op) => { - let transformed_output = transform_output(builder, registry, output, transform_op)?; - let out_edge = edges - .get_mut(&target.out_edges[0]) - .ok_or(unknown_diagram_error!())?; - out_edge.state = EdgeState::Ready { - output: transformed_output.into(), - origin, - } + + construction + .outputs_to_builtin_target + 
.extend(new_construction.outputs_to_builtin_target.drain()); + + construction + .outputs_to_operation_target + .extend(new_construction.outputs_to_operation_target.drain()); + + iterations += 1; + if iterations > MAX_ITERATIONS { + return Err(DiagramErrorCode::ExcessiveIterations.into()); } - DiagramOperation::Dispose => {} } + Ok(()) } -/// Connect a [`DynOutput`] to a [`DynInputSlot`]. Use this only when both the output and input -/// are type erased. To connect an [`Output`] to a [`DynInputSlot`] or vice versa, prefer converting -/// the type erased output/input slot to the typed equivalent. -/// -/// ```text -/// builder.connect(output.into_output::()?, dyn_input)?; -/// ``` -fn dyn_connect( +fn initialize_builtin_operations( + scope: Scope, builder: &mut Builder, - output: DynOutput, - input: DynInputSlot, -) -> Result<(), DiagramError> { - if output.type_id != input.type_id { - return Err(DiagramError::TypeMismatch); - } - struct TypeErased {} - let typed_output = Output::::new(output.scope(), output.id()); - let typed_input = InputSlot::::new(input.scope(), input.id()); - builder.connect(typed_output, typed_input); + ctx: &mut DiagramContext, +) -> Result<(), DiagramError> +where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, + Streams: StreamPack, +{ + // Put the input message into the diagram + ctx.add_output_into_target(ctx.diagram.start.clone(), scope.input.into()); + + // Add the terminate operation + ctx.construction.connect_into_target.insert( + NextOperation::Builtin { + builtin: BuiltinTarget::Terminate, + }, + standard_input_connection(scope.terminate.into(), &ctx.registry)?, + ); + + // Add the dispose operation + ctx.construction.connect_into_target.insert( + NextOperation::Builtin { + builtin: BuiltinTarget::Dispose, + }, + Box::new(ConnectionCallback(move |_, _, _| { + // Do nothing since the output is being disposed + Ok(()) + })), + ); + + // Add the cancel operation + ctx.construction.connect_into_target.insert( + 
NextOperation::Builtin { + builtin: BuiltinTarget::Cancel, + }, + Box::new(ConnectToCancel::new(builder)?), + ); + Ok(()) } + +/// This returns an opaque [`ConnectIntoTarget`] implementation that provides +/// the standard behavior of an input slot that other operations are connecting +/// into. +pub fn standard_input_connection( + input_slot: DynInputSlot, + registry: &DiagramElementRegistry, +) -> Result, DiagramErrorCode> { + if input_slot.message_info() == &TypeInfo::of::() { + return Ok(Box::new(ImplicitSerialization::new(input_slot)?)); + } + + if let Some(deserialization) = ImplicitDeserialization::try_new(input_slot, ®istry.messages)? + { + // The target type is deserializable, so let's apply implicit deserialization + // to it. + return Ok(Box::new(deserialization)); + } + + Ok(Box::new(BasicConnect { input_slot })) +} + +impl ConnectIntoTarget for ImplicitSerialization { + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + self.implicit_serialize(output, builder, ctx) + } +} + +impl ConnectIntoTarget for ImplicitDeserialization { + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + self.implicit_deserialize(output, builder, ctx) + } +} + +struct BasicConnect { + input_slot: DynInputSlot, +} + +impl ConnectIntoTarget for BasicConnect { + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + _: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + output.connect_to(&self.input_slot, builder) + } +} + +struct ConnectToCancel { + quiet_cancel: DynInputSlot, + implicit_serialization: ImplicitSerialization, + implicit_stringify: ImplicitStringify, + triggers: HashMap, +} + +impl ConnectToCancel { + fn new(builder: &mut Builder) -> Result { + Ok(Self { + quiet_cancel: builder.create_quiet_cancel().into(), + 
implicit_serialization: ImplicitSerialization::new( + builder.create_cancel::().into(), + )?, + implicit_stringify: ImplicitStringify::new(builder.create_cancel::().into())?, + triggers: Default::default(), + }) + } +} + +impl ConnectIntoTarget for ConnectToCancel { + fn connect_into_target( + &mut self, + output: DynOutput, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + let Err(output) = self + .implicit_stringify + .try_implicit_stringify(output, builder, ctx)? + else { + // We successfully converted the output into a string, so we are done. + return Ok(()); + }; + + // Try to implicitly serialize the incoming message if the message + // type supports it. That way we can connect it to the regular + // cancel operation. + let Err(output) = self + .implicit_serialization + .try_implicit_serialize(output, builder, ctx)? + else { + // We successfully converted the output into a json, so we are done. + return Ok(()); + }; + + // In this case, the message type cannot be stringified or serialized so + // we'll change it into a trigger and then connect it to the quiet + // cancel instead. 
+ let input_slot = match self.triggers.entry(*output.message_info()) { + Entry::Occupied(occupied) => occupied.get().clone(), + Entry::Vacant(vacant) => { + let trigger = ctx + .registry + .messages + .trigger(output.message_info(), builder)?; + trigger.output.connect_to(&self.quiet_cancel, builder)?; + vacant.insert(trigger.input).clone() + } + }; + + output.connect_to(&input_slot, builder)?; + + Ok(()) + } +} diff --git a/src/disposal.rs b/src/disposal.rs index 1bc453a5..c27e7ebd 100644 --- a/src/disposal.rs +++ b/src/disposal.rs @@ -30,7 +30,7 @@ use thiserror::Error as ThisError; use crate::{ operation::ScopeStorage, Cancel, Cancellation, DisposalFailure, ImpulseMarker, OperationResult, - OperationRoster, OrBroken, UnhandledErrors, + OperationRoster, OrBroken, UnhandledErrors, UnusedTarget, }; #[derive(ThisError, Debug, Clone)] @@ -452,9 +452,13 @@ impl<'w> ManageDisposal for EntityWorldMut<'w> { // TODO(@mxgrey): Consider whether there is a more sound way to // decide whether a disposal should be converted into a // cancellation for impulses. - } else { - // If the emitting node does not have a scope as not part of an - // impulse chain, then something is broken. + } else if !self.contains::() { + // If the emitting node does not have a scope, is not part of + // an impulse chain, and is not an unused target, then something + // is broken. + // + // We can safely ignore disposals for unused targets because + // unused targets cannot affect the reachability of a workflow. let broken_node = self.id(); self.world_scope(|world| { world diff --git a/src/node.rs b/src/node.rs index b3c054f5..764aec23 100644 --- a/src/node.rs +++ b/src/node.rs @@ -201,3 +201,21 @@ impl ForkCloneOutput { } } } + +/// The output of a fork result operation. Each output can be connected to one +/// input slot. +pub struct ForkResultOutput { + /// This output will be sent if an [`Ok`] is sent into the fork. 
+ pub ok: Output, + /// This output will be sent if an [`Err`] is sent into the fork. + pub err: Output, +} + +/// The output of a fork option operation. Each output can be connected to one +/// input slot. +pub struct ForkOptionOutput { + /// This output will be sent if a [`Some`] is sent into the fork. + pub some: Output, + /// This output will be sent if a [`None`] is sent into the fork. + pub none: Output<()>, +} diff --git a/src/operation.rs b/src/operation.rs index 373d30ec..58953238 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -20,6 +20,7 @@ use crate::{ UnhandledErrors, }; +use bevy_derive::Deref; use bevy_ecs::{ prelude::{Component, Entity, World}, system::Command, @@ -71,6 +72,9 @@ pub use operate_buffer_access::*; mod operate_callback; pub(crate) use operate_callback::*; +mod operate_cancel; +pub(crate) use operate_cancel::*; + mod operate_gate; pub(crate) use operate_gate::*; @@ -610,6 +614,9 @@ impl Command for AddOperation { OperationExecuteStorage(perform_operation::), OperationCleanupStorage(Op::cleanup), OperationReachabilityStorage(Op::is_reachable), + OperationType { + name: std::any::type_name::(), + }, )); if let Some(scope) = self.scope { source_mut @@ -639,6 +646,11 @@ pub(crate) struct OperationExecuteStorage(pub(crate) fn(OperationRequest)); #[derive(Component)] pub(crate) struct OperationReachabilityStorage(fn(OperationReachability) -> ReachabilityResult); +#[derive(Component, Deref, Debug)] +pub(crate) struct OperationType { + name: &'static str, +} + pub fn execute_operation(request: OperationRequest) { let Some(operator) = request.world.get::(request.source) else { if request.world.get::(request.source).is_none() { diff --git a/src/operation/operate_cancel.rs b/src/operation/operate_cancel.rs new file mode 100644 index 00000000..4789c8d3 --- /dev/null +++ b/src/operation/operate_cancel.rs @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use crate::{ + Cancellation, Input, InputBundle, ManageCancellation, ManageInput, Operation, OperationCleanup, + OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, + ReachabilityResult, SingleInputStorage, +}; + +/// Create an operation that will cancel a scope. The incoming message will be +/// included in the cancellation data as a [`String`]. The incoming message type +/// must support the [`ToString`] trait. +/// +/// To trigger a cancellation for types that do not support [`ToString`], convert +/// the message to a trigger and send it to [`OperateQuietCancel`]. 
+pub struct OperateCancel { + _ignore: std::marker::PhantomData, +} + +impl OperateCancel +where + T: 'static + Send + Sync + ToString, +{ + pub fn new() -> Self { + Self { + _ignore: Default::default(), + } + } +} + +impl Operation for OperateCancel +where + T: 'static + Send + Sync + ToString, +{ + fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { + world.entity_mut(source).insert(InputBundle::::new()); + Ok(()) + } + + fn execute( + OperationRequest { + source, + world, + roster, + }: OperationRequest, + ) -> OperationResult { + let mut source_mut = world.get_entity_mut(source).or_broken()?; + let Input { session, data } = source_mut.take_input::().or_broken()?; + + let cancellation = Cancellation::triggered(source, Some(data.to_string())); + source_mut.emit_cancel(session, cancellation, roster); + Ok(()) + } + + fn cleanup(mut clean: OperationCleanup) -> OperationResult { + clean.cleanup_inputs::()?; + clean.notify_cleaned() + } + + fn is_reachable(mut reachability: OperationReachability) -> ReachabilityResult { + if reachability.has_input::()? { + return Ok(true); + } + + SingleInputStorage::is_reachable(&mut reachability) + } +} + +/// Create an operation that will cancel a scope. This operation only accepts +/// trigger `()` inputs. There will be no information included in the +/// cancellation message except that the cancellation was triggered at this node. +pub struct OperateQuietCancel; + +impl Operation for OperateQuietCancel { + fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { + world.entity_mut(source).insert(InputBundle::<()>::new()); + Ok(()) + } + + fn execute( + OperationRequest { + source, + world, + roster, + }: OperationRequest, + ) -> OperationResult { + let mut source_mut = world.get_entity_mut(source).or_broken()?; + let Input { session, .. 
} = source_mut.take_input::<()>().or_broken()?; + + let cancellation = Cancellation::triggered(source, None); + source_mut.emit_cancel(session, cancellation, roster); + Ok(()) + } + + fn cleanup(mut clean: OperationCleanup) -> OperationResult { + clean.cleanup_inputs::<()>()?; + clean.notify_cleaned() + } + + fn is_reachable(mut reachability: OperationReachability) -> ReachabilityResult { + if reachability.has_input::<()>()? { + return Ok(true); + } + + SingleInputStorage::is_reachable(&mut reachability) + } +} From 030510f9145047ff575dabaf9ba53ddbc83db9d2 Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Tue, 22 Apr 2025 20:45:18 +0800 Subject: [PATCH 15/20] Section builders and templates (#58) Signed-off-by: Michael X. Grey Signed-off-by: Teo Koon Peng Co-authored-by: Michael X. Grey --- Cargo.toml | 2 +- diagram.schema.json | 189 ++++- macros/src/lib.rs | 15 + macros/src/section.rs | 294 +++++++ src/diagram.rs | 569 +++++++++++-- src/diagram/buffer_schema.rs | 69 +- src/diagram/fork_clone_schema.rs | 39 +- src/diagram/fork_result_schema.rs | 28 +- src/diagram/join_schema.rs | 73 +- src/diagram/node.rs | 43 + src/diagram/node_schema.rs | 6 +- src/diagram/registration.rs | 165 +++- src/diagram/section_schema.rs | 1279 +++++++++++++++++++++++++++++ src/diagram/serialization.rs | 21 +- src/diagram/split_schema.rs | 32 +- src/diagram/testing.rs | 36 +- src/diagram/transform_schema.rs | 40 +- src/diagram/unzip_schema.rs | 26 +- src/diagram/workflow_builder.rs | 1052 ++++++++++++++++++------ 19 files changed, 3422 insertions(+), 556 deletions(-) create mode 100644 macros/src/section.rs create mode 100644 src/diagram/node.rs create mode 100644 src/diagram/section_schema.rs diff --git a/Cargo.toml b/Cargo.toml index 347cbdef..8a05bd05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ bevy_core = "0.12" bevy_time = "0.12" schemars = { version = "0.8.21", optional = true } -serde = { version = "1.0.210", features = ["derive"], optional = true } +serde = { version 
= "1.0.210", features = ["derive", "rc"], optional = true } serde_json = { version = "1.0.128", optional = true } cel-interpreter = { version = "0.9.0", features = ["json"], optional = true } tracing = "0.1.41" diff --git a/diagram.schema.json b/diagram.schema.json index ae9b7796..40e04d27 100644 --- a/diagram.schema.json +++ b/diagram.schema.json @@ -35,27 +35,34 @@ } ] }, + "templates": { + "default": {}, + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/SectionTemplate" + } + }, "version": { "description": "Version of the diagram, should always be `0.1.0`.", "type": "string" } }, "definitions": { - "BufferInputs": { + "BufferSelection": { "anyOf": [ { - "type": "string" + "$ref": "#/definitions/NextOperation" }, { "type": "object", "additionalProperties": { - "type": "string" + "$ref": "#/definitions/NextOperation" } }, { "type": "array", "items": { - "type": "string" + "$ref": "#/definitions/NextOperation" } } ] @@ -125,6 +132,57 @@ } } }, + { + "description": "Connect the request to a registered section.\n\n``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"section_op\", \"ops\": { \"section_op\": { \"type\": \"section\", \"builder\": \"my_section_builder\", \"connect\": { \"my_section_output\": { \"builtin\": \"terminate\" } } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```\n\nCustom sections can also be created via templates ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"templates\": { \"my_template\": { \"inputs\": [\"section_input\"], \"outputs\": [\"section_output\"], \"buffers\": [], \"ops\": { \"section_input\": { \"type\": \"node\", \"builder\": \"my_node\", \"next\": \"section_output\" } } } }, \"start\": \"section_op\", \"ops\": { \"section_op\": { \"type\": \"section\", \"template\": \"my_template\", \"connect\": { \"section_output\": { \"builtin\": \"terminate\" } } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "type": "object", + "oneOf": [ + { + 
"type": "object", + "required": [ + "builder" + ], + "properties": { + "builder": { + "type": "string" + } + }, + "additionalProperties": false + }, + { + "type": "object", + "required": [ + "template" + ], + "properties": { + "template": { + "type": "string" + } + }, + "additionalProperties": false + } + ], + "required": [ + "type" + ], + "properties": { + "config": { + "default": null + }, + "connect": { + "default": {}, + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/NextOperation" + } + }, + "type": { + "type": "string", + "enum": [ + "section" + ] + } + } + }, { "description": "If the request is cloneable, clone it into multiple responses that can each be sent to a different operation. The `next` property is an array.\n\nThis creates multiple simultaneous branches of execution within the workflow. Usually when you have multiple branches you will either * race - connect all branches to `terminate` and the first branch to finish \"wins\" the race and gets to the be output * join - connect each branch into a buffer and then use the `join` operation to reunite them * collect - TODO(@mxgrey): [add the collect operation](https://github.com/open-rmf/bevy_impulse/issues/59)\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"begin_race\", \"ops\": { \"begin_race\": { \"type\": \"fork_clone\", \"next\": [ \"ferrari\", \"mustang\" ] }, \"ferrari\": { \"type\": \"node\", \"builder\": \"drive\", \"config\": \"ferrari\", \"next\": { \"builtin\": \"terminate\" } }, \"mustang\": { \"type\": \"node\", \"builder\": \"drive\", \"config\": \"mustang\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", "type": "object", @@ -148,7 +206,7 @@ } }, { - "description": "If the input message is a tuple of (T1, T2, T3, ...), unzip it into multiple output messages of T1, T2, T3, ...\n\nEach output message may have a different type and can be sent to a different operation. 
This creates multiple simultaneous branches of execution within the workflow. See [`DiagramOperation::ForkClone`] for more information on parallel branches.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"name_phone_address\", \"ops\": { \"name_phone_address\": { \"type\": \"unzip\", \"next\": [ \"process_name\", \"process_phone_number\", \"process_address\" ] }, \"process_name\": { \"type\": \"node\", \"builder\": \"process_name\", \"next\": \"name_processed\" }, \"process_phone_number\": { \"type\": \"node\", \"builder\": \"process_phone_number\", \"next\": \"phone_number_processed\" }, \"process_address\": { \"type\": \"node\", \"builder\": \"process_address\", \"next\": \"address_processed\" }, \"name_processed\": { \"type\": \"buffer\" }, \"phone_number_processed\": { \"type\": \"buffer\" }, \"address_processed\": { \"type\": \"buffer\" }, \"finished\": { \"type\": \"join\", \"buffers\": [ \"name_processed\", \"phone_number_processed\", \"address_processed\" ], \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "If the input message is a tuple of (T1, T2, T3, ...), unzip it into multiple output messages of T1, T2, T3, ...\n\nEach output message may have a different type and can be sent to a different operation. This creates multiple simultaneous branches of execution within the workflow. 
See [`DiagramOperation::ForkClone`] for more information on parallel branches.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"name_phone_address\", \"ops\": { \"name_phone_address\": { \"type\": \"unzip\", \"next\": [ \"process_name\", \"process_phone_number\", \"process_address\" ] }, \"process_name\": { \"type\": \"node\", \"builder\": \"process_name\", \"next\": \"name_processed\" }, \"process_phone_number\": { \"type\": \"node\", \"builder\": \"process_phone_number\", \"next\": \"phone_number_processed\" }, \"process_address\": { \"type\": \"node\", \"builder\": \"process_address\", \"next\": \"address_processed\" }, \"name_processed\": { \"type\": \"buffer\" }, \"phone_number_processed\": { \"type\": \"buffer\" }, \"address_processed\": { \"type\": \"buffer\" }, \"finished\": { \"type\": \"join\", \"buffers\": [ \"name_processed\", \"phone_number_processed\", \"address_processed\" ], \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "next", @@ -170,7 +228,7 @@ } }, { - "description": "If the request is a [`Result`], send the output message down an `ok` branch or down an `err` branch depending on whether the result has an [`Ok`] or [`Err`] value. The `ok` branch will receive a `T` while the `err` branch will receive an `E`.\n\nOnly one branch will be activated by each input message that enters the operation.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_result\", \"ops\": { \"fork_result\": { \"type\": \"fork_result\", \"ok\": { \"builtin\": \"terminate\" }, \"err\": { \"builtin\": \"dispose\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "If the request is a [`Result`], send the output message down an `ok` branch or down an `err` branch depending on whether the result has an [`Ok`] or [`Err`] value. 
The `ok` branch will receive a `T` while the `err` branch will receive an `E`.\n\nOnly one branch will be activated by each input message that enters the operation.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_result\", \"ops\": { \"fork_result\": { \"type\": \"fork_result\", \"ok\": { \"builtin\": \"terminate\" }, \"err\": { \"builtin\": \"dispose\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "err", @@ -232,7 +290,7 @@ } }, { - "description": "Wait for exactly one item to be available in each buffer listed in `buffers`, then join each of those items into a single output message that gets sent to `next`.\n\nIf the `next` operation is not a `node` type (e.g. `fork_clone`) then you must specify a `target_node` so that the diagram knows what data structure to join the values into.\n\nThe output message type must be registered as joinable at compile time. If you want to join into a dynamic data structure then you should use [`DiagramOperation::SerializedJoin`] instead.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_measuring\", \"ops\": { \"fork_measuring\": { \"type\": \"fork_clone\", \"next\": [\"localize\", \"imu\"] }, \"localize\": { \"type\": \"node\", \"builder\": \"localize\", \"next\": \"estimated_position\" }, \"imu\": { \"type\": \"node\", \"builder\": \"imu\", \"config\": \"velocity\", \"next\": \"estimated_velocity\" }, \"estimated_position\": { \"type\": \"buffer\" }, \"estimated_velocity\": { \"type\": \"buffer\" }, \"gather_state\": { \"type\": \"join\", \"buffers\": { \"position\": \"estimate_position\", \"velocity\": \"estimate_velocity\" }, \"next\": \"report_state\" }, \"report_state\": { \"type\": \"node\", \"builder\": \"publish_state\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", + "description": "Wait for exactly one item to be 
available in each buffer listed in `buffers`, then join each of those items into a single output message that gets sent to `next`.\n\nIf the `next` operation is not a `node` type (e.g. `fork_clone`) then you must specify a `target_node` so that the diagram knows what data structure to join the values into.\n\nThe output message type must be registered as joinable at compile time. If you want to join into a dynamic data structure then you should use [`DiagramOperation::SerializedJoin`] instead.\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"begin_measuring\", \"ops\": { \"begin_measuring\": { \"type\": \"fork_clone\", \"next\": [\"localize\", \"imu\"] }, \"localize\": { \"type\": \"node\", \"builder\": \"localize\", \"next\": \"estimated_position\" }, \"imu\": { \"type\": \"node\", \"builder\": \"imu\", \"config\": \"velocity\", \"next\": \"estimated_velocity\" }, \"estimated_position\": { \"type\": \"buffer\" }, \"estimated_velocity\": { \"type\": \"buffer\" }, \"gather_state\": { \"type\": \"join\", \"buffers\": { \"position\": \"estimate_position\", \"velocity\": \"estimate_velocity\" }, \"next\": \"report_state\" }, \"report_state\": { \"type\": \"node\", \"builder\": \"publish_state\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "buffers", @@ -244,20 +302,13 @@ "description": "Map of buffer keys and buffers.", "allOf": [ { - "$ref": "#/definitions/BufferInputs" + "$ref": "#/definitions/BufferSelection" } ] }, "next": { "$ref": "#/definitions/NextOperation" }, - "target_node": { - "description": "The id of an operation that this operation is for. The id must be a `node` operation. 
Optional if `next` is a node operation.", - "type": [ - "string", - "null" - ] - }, "type": { "type": "string", "enum": [ @@ -279,7 +330,7 @@ "description": "Map of buffer keys and buffers.", "allOf": [ { - "$ref": "#/definitions/BufferInputs" + "$ref": "#/definitions/BufferSelection" } ] }, @@ -364,7 +415,7 @@ } }, { - "description": "Zip a message together with access to one or more buffers.\n\nThe receiving node must have an input type of `(Message, Keys)` where `Keys` implements the [`Accessor`][1] trait.\n\n[1]: crate::Accessor\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"num_output\", \"string_output\"] }, \"num_output\": { \"type\": \"node\", \"builder\": \"num_output\", \"next\": \"buffer_access\" }, \"string_output\": { \"type\": \"node\", \"builder\": \"string_output\", \"next\": \"string_buffer\" }, \"string_buffer\": { \"type\": \"buffer\" }, \"buffer_access\": { \"type\": \"buffer_access\", \"buffers\": [\"string_buffer\"], \"target_node\": \"with_buffer_access\", \"next\": \"with_buffer_access\" }, \"with_buffer_access\": { \"type\": \"node\", \"builder\": \"with_buffer_access\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(())", + "description": "Zip a message together with access to one or more buffers.\n\nThe receiving node must have an input type of `(Message, Keys)` where `Keys` implements the [`Accessor`][1] trait.\n\n[1]: crate::Accessor\n\n# Examples ``` # bevy_impulse::Diagram::from_json_str(r#\" { \"version\": \"0.1.0\", \"start\": \"fork_clone\", \"ops\": { \"fork_clone\": { \"type\": \"fork_clone\", \"next\": [\"num_output\", \"string_output\"] }, \"num_output\": { \"type\": \"node\", \"builder\": \"num_output\", \"next\": \"buffer_access\" }, \"string_output\": { \"type\": \"node\", \"builder\": \"string_output\", \"next\": \"string_buffer\" }, \"string_buffer\": 
{ \"type\": \"buffer\" }, \"buffer_access\": { \"type\": \"buffer_access\", \"buffers\": [\"string_buffer\"], \"target_node\": \"with_buffer_access\", \"next\": \"with_buffer_access\" }, \"with_buffer_access\": { \"type\": \"node\", \"builder\": \"with_buffer_access\", \"next\": { \"builtin\": \"terminate\" } } } } # \"#)?; # Ok::<_, serde_json::Error>(()) ```", "type": "object", "required": [ "buffers", @@ -376,20 +427,13 @@ "description": "Map of buffer keys and buffers.", "allOf": [ { - "$ref": "#/definitions/BufferInputs" + "$ref": "#/definitions/BufferSelection" } ] }, "next": { "$ref": "#/definitions/NextOperation" }, - "target_node": { - "description": "The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation.", - "type": [ - "string", - "null" - ] - }, "type": { "type": "string", "enum": [ @@ -411,7 +455,7 @@ "description": "Map of buffer keys and buffers.", "allOf": [ { - "$ref": "#/definitions/BufferInputs" + "$ref": "#/definitions/BufferSelection" } ] }, @@ -420,9 +464,13 @@ }, "target_node": { "description": "The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation.", - "type": [ - "string", - "null" + "anyOf": [ + { + "$ref": "#/definitions/NextOperation" + }, + { + "type": "null" + } ] }, "type": { @@ -435,6 +483,40 @@ } ] }, + "InputRemapping": { + "description": "This defines how sections remap their inner operations (inputs and buffers) to expose them to operations that are siblings to the section.", + "anyOf": [ + { + "description": "Do a simple 1:1 forwarding of the names listed in the array", + "type": "array", + "items": { + "type": "string" + } + }, + { + "description": "Rename an operation inside the section to expose it externally. 
The key of the map is what siblings of the section can connect to, and the value of the entry is the identifier of the input inside the section that is being exposed.\n\nThis allows a section to expose inputs and buffers that are provided by inner sections.", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/NextOperation" + } + } + ] + }, + "NamespacedOperation": { + "title": "NamespacedOperation", + "description": "Refer to an operation inside of a namespace, e.g. { \"\": \"\"", + "type": "object", + "allOf": [ + { + "$ref": "NamespacedOperation" + } + ], + "maxProperties": 1, + "minProperties": 1, + "additionalProperties": { + "type": "string" + } + }, "NextOperation": { "anyOf": [ { @@ -450,6 +532,14 @@ "$ref": "#/definitions/BuiltinTarget" } } + }, + { + "description": "Refer to an \"inner\" operation of one of the sibling operations in a diagram. This can be used to target section inputs.", + "allOf": [ + { + "$ref": "#/definitions/NamespacedOperation" + } + ] } ] }, @@ -494,6 +584,47 @@ ] } ] + }, + "SectionTemplate": { + "type": "object", + "required": [ + "ops" + ], + "properties": { + "buffers": { + "description": "These are the buffers that the section is exposing for you to read, write, join, or listen to.", + "default": [], + "allOf": [ + { + "$ref": "#/definitions/InputRemapping" + } + ] + }, + "inputs": { + "description": "These are the inputs that the section is exposing for its sibling operations to send outputs into.", + "default": [], + "allOf": [ + { + "$ref": "#/definitions/InputRemapping" + } + ] + }, + "ops": { + "description": "Operations that define the behavior of the section.", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/DiagramOperation" + } + }, + "outputs": { + "description": "These are the outputs that the section is exposing so you can connect them into siblings of the section.", + "default": [], + "type": "array", + "items": { + "type": "string" + } + } + } } } } \ No newline at 
end of file diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 58873049..3eb4f168 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -18,6 +18,9 @@ mod buffer; use buffer::{impl_buffer_key_map, impl_joined_value}; +mod section; +use section::impl_section; + use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput, ItemStruct}; @@ -88,3 +91,15 @@ pub fn derive_buffer_key_map(input: TokenStream) -> TokenStream { .into(), } } + +#[proc_macro_derive(Section, attributes(message))] +pub fn derive_section(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_section(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} diff --git a/macros/src/section.rs b/macros/src/section.rs new file mode 100644 index 00000000..f395d3a0 --- /dev/null +++ b/macros/src/section.rs @@ -0,0 +1,294 @@ +use std::iter::zip; + +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::{parse_quote_spanned, spanned::Spanned, Field, Ident, ItemStruct, Member, Type}; + +use crate::Result; + +pub(crate) fn impl_section(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let field_ident: Vec = input_struct + .fields + .members() + .filter_map(|m| match m { + Member::Named(m) => Some(m), + _ => None, + }) + .collect(); + let field_name_str: Vec = field_ident.iter().map(|ident| ident.to_string()).collect(); + let field_type: Vec<&Type> = input_struct.fields.iter().map(|f| &f.ty).collect(); + let field_configs: Vec<(FieldConfig, Span)> = input_struct + .fields + .iter() + .map(|f| (FieldConfig::from_field(f), f.ty.span())) + .collect(); + + let register_deserialize = gen_register_deserialize(&field_configs); + let register_serialize = gen_register_serialize(&field_configs); + let register_fork_clone = 
gen_register_fork_clone(&field_configs); + let register_unzip = gen_register_unzip(&field_configs); + let register_fork_result = gen_register_fork_result(&field_configs); + let register_split = gen_register_split(&field_configs); + let register_join = gen_register_join(&field_configs); + let register_buffer_access = gen_register_buffer_access(&field_configs); + let register_listen = gen_register_listen(&field_configs); + + let register_message: Vec = zip(&field_type, &field_configs) + .map(|(field_type, (_config, span))| { + quote_spanned! {*span=> + let mut _message = _opt_out.register_message::<<#field_type as ::bevy_impulse::SectionItem>::MessageType>(); + } + }) + .collect(); + + let gen = quote! { + impl #impl_generics ::bevy_impulse::Section for #struct_ident #ty_generics #where_clause { + fn into_slots( + self: Box, + ) -> SectionSlots { + let mut slots = SectionSlots::new(); + #( + self.#field_ident.insert_into_slots(&#field_name_str, &mut slots); + )* + slots + } + + fn on_register(registry: &mut DiagramElementRegistry) + where + Self: Sized, + { + #({ + let _opt_out = registry.opt_out(); + #register_deserialize + #register_serialize + #register_fork_clone + + #register_message + + #register_unzip + #register_fork_result + #register_split + #register_join + #register_buffer_access + #register_listen + })* + } + } + + impl #impl_generics ::bevy_impulse::SectionMetadataProvider for #struct_ident #ty_generics #where_clause { + fn metadata() -> &'static ::bevy_impulse::SectionMetadata { + static METADATA: ::std::sync::OnceLock<::bevy_impulse::SectionMetadata> = ::std::sync::OnceLock::new(); + METADATA.get_or_init(|| { + let mut metadata = ::bevy_impulse::SectionMetadata::new(); + #( + <#field_type as ::bevy_impulse::SectionItem>::build_metadata( + &mut metadata, + &#field_name_str, + ); + )* + metadata + }) + } + } + }; + + Ok(gen) +} + +struct FieldConfig { + no_deserialize: bool, + no_serialize: bool, + no_clone: bool, + unzip: bool, + fork_result: bool, + 
split: bool, + join: bool, + buffer_access: bool, + listen: bool, +} + +impl FieldConfig { + fn from_field(field: &Field) -> Self { + let mut config = Self { + no_deserialize: false, + no_serialize: false, + no_clone: false, + unzip: false, + fork_result: false, + split: false, + join: false, + buffer_access: false, + listen: false, + }; + + for attr in field + .attrs + .iter() + .filter(|attr| attr.path().is_ident("message")) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("no_deserialize") { + config.no_deserialize = true; + } else if meta.path.is_ident("no_serialize") { + config.no_serialize = true; + } else if meta.path.is_ident("no_clone") { + config.no_clone = true; + } else if meta.path.is_ident("unzip") { + config.unzip = true; + } else if meta.path.is_ident("result") { + config.fork_result = true; + } else if meta.path.is_ident("split") { + config.split = true; + } else if meta.path.is_ident("join") { + config.join = true; + } else if meta.path.is_ident("buffer_access") { + config.buffer_access = true; + } else if meta.path.is_ident("listen") { + config.listen = true; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. + .unwrap(); + } + + config + } +} + +fn gen_register_deserialize(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.no_deserialize { + quote_spanned! {*span=> + let _opt_out = _opt_out.no_deserializing(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_serialize(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.no_serialize { + parse_quote_spanned! {*span=> + let _opt_out = _opt_out.no_serializing(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_fork_clone(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.no_clone { + quote_spanned! 
{*span=> + let _opt_out = _opt_out.no_cloning(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_unzip(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.unzip { + quote_spanned! {*span=> + _message.with_unzip(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_fork_result(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.fork_result { + quote_spanned! {*span=> + _message.with_fork_result(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_split(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.split { + quote_spanned! {*span=> + _message.with_split(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_join(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.join { + quote_spanned! {*span=> + _message.with_join(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_buffer_access(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.buffer_access { + quote_spanned! {*span=> + _message.with_buffer_access(); + } + } else { + TokenStream::new() + } + }) + .collect() +} + +fn gen_register_listen(fields: &Vec<(FieldConfig, Span)>) -> Vec { + fields + .into_iter() + .map(|(config, span)| { + if config.listen { + quote_spanned! 
{*span=> + _message.with_listen(); + } + } else { + TokenStream::new() + } + }) + .collect() +} diff --git a/src/diagram.rs b/src/diagram.rs index 1c85aa21..493fe735 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -21,6 +21,7 @@ mod fork_result_schema; mod join_schema; mod node_schema; mod registration; +mod section_schema; mod serialization; mod split_schema; mod supported; @@ -29,6 +30,7 @@ mod type_info; mod unzip_schema; mod workflow_builder; +use bevy_derive::{Deref, DerefMut}; use bevy_ecs::system::Commands; use buffer_schema::{BufferAccessSchema, BufferSchema, ListenSchema}; use fork_clone_schema::{DynForkClone, ForkCloneSchema, PerformForkClone}; @@ -37,57 +39,182 @@ pub use join_schema::JoinOutput; use join_schema::{JoinSchema, SerializedJoinSchema}; pub use node_schema::NodeSchema; pub use registration::*; +pub use section_schema::*; pub use serialization::*; pub use split_schema::*; use tracing::debug; use transform_schema::{TransformError, TransformSchema}; -use type_info::TypeInfo; +pub use type_info::TypeInfo; use unzip_schema::UnzipSchema; -use workflow_builder::{create_workflow, BuildDiagramOperation, BuildStatus, DiagramContext}; +pub use workflow_builder::*; // ---------- -use std::{borrow::Cow, collections::HashMap, fmt::Display, io::Read}; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + fmt::Display, + io::Read, + sync::Arc, +}; use crate::{ Builder, IncompatibleLayout, JsonMessage, Scope, Service, SpawnWorkflowExt, SplitConnectionError, StreamPack, }; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use schemars::{ + r#gen::SchemaGenerator, + schema::{InstanceType, Metadata, ObjectValidation, Schema, SchemaObject, SingleOrVec}, + JsonSchema, +}; +use serde::{ + de::{Error, Visitor}, + ser::SerializeMap, + Deserialize, Deserializer, Serialize, Serializer, +}; +const CURRENT_DIAGRAM_VERSION: &str = "0.1.0"; const SUPPORTED_DIAGRAM_VERSION: &str = ">=0.1.0, <0.2.0"; +const RESERVED_OPERATION_NAMES: [&'static 
str; 2] = ["", "builtin"]; -pub type BuilderId = String; -pub type OperationId = String; +pub type BuilderId = Arc; +pub type OperationName = Arc; #[derive( Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord, )] #[serde(untagged, rename_all = "snake_case")] pub enum NextOperation { - Target(OperationId), - Builtin { builtin: BuiltinTarget }, + Name(OperationName), + Builtin { + builtin: BuiltinTarget, + }, + /// Refer to an "inner" operation of one of the sibling operations in a + /// diagram. This can be used to target section inputs. + Namespace(NamespacedOperation), +} + +impl NextOperation { + pub fn dispose() -> Self { + NextOperation::Builtin { + builtin: BuiltinTarget::Dispose, + } + } } impl Display for NextOperation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Target(operation_id) => f.write_str(operation_id), - Self::Builtin { builtin } => write!(f, "builtin:{}", builtin), + Self::Name(operation_id) => f.write_str(operation_id), + Self::Namespace(NamespacedOperation { + namespace, + operation, + }) => write!(f, "{namespace}:{operation}"), + Self::Builtin { builtin } => write!(f, "builtin:{builtin}"), + } + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +/// This describes an operation that exists inside of some namespace, such as a +/// [`Section`]. This will serialize as a map with a single entry of +/// `{ "": "" }`. 
+pub struct NamespacedOperation { + pub namespace: OperationName, + pub operation: OperationName, +} + +impl Serialize for NamespacedOperation { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry(&self.namespace, &self.operation)?; + map.end() + } +} + +struct NamespacedOperationVisitor; + +impl<'de> Visitor<'de> for NamespacedOperationVisitor { + type Value = NamespacedOperation; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "a map with exactly one entry of { \"\" : \"\" } \ + whose key is the namespace string and whose value is the operation string", + ) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let (key, value) = map.next_entry::()?.ok_or_else(|| { + A::Error::custom( + "namespaced operation must be a map from the namespace to the operation name", + ) + })?; + + if !map.next_key::()?.is_none() { + return Err(A::Error::custom( + "namespaced operation must contain exactly one entry", + )); } + + Ok(NamespacedOperation { + namespace: key.into(), + operation: value.into(), + }) + } +} + +impl<'de> Deserialize<'de> for NamespacedOperation { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_map(NamespacedOperationVisitor) + } +} + +impl JsonSchema for NamespacedOperation { + fn json_schema(generator: &mut SchemaGenerator) -> Schema { + let mut schema = SchemaObject::new_ref(Self::schema_name()); + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Object))); + schema.object = Some(Box::new(ObjectValidation { + max_properties: Some(1), + min_properties: Some(1), + required: Default::default(), + properties: Default::default(), + pattern_properties: Default::default(), + property_names: Default::default(), + additional_properties: Some(Box::new(generator.subschema_for::())), + })); + 
schema.metadata = Some(Box::new(Metadata { + title: Some("NamespacedOperation".to_string()), + description: Some("Refer to an operation inside of a namespace, e.g. { \"\": \"\"".to_string()), + ..Default::default() + })); + + Schema::Object(schema) + } + + fn schema_name() -> String { + "NamespacedOperation".to_string() } } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case", untagged)] -pub enum BufferInputs { - Single(OperationId), - Dict(HashMap), - Array(Vec), +pub enum BufferSelection { + Single(NextOperation), + Dict(HashMap), + Array(Vec), } -impl BufferInputs { +impl BufferSelection { pub fn is_empty(&self) -> bool { match self { Self::Single(_) => false, @@ -126,30 +253,6 @@ pub enum BuiltinTarget { Cancel, } -#[derive( - Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord, -)] -#[serde(untagged, rename_all = "snake_case")] -pub enum SourceOperation { - Source(OperationId), - Builtin { builtin: BuiltinSource }, -} - -impl From for SourceOperation { - fn from(value: OperationId) -> Self { - SourceOperation::Source(value) - } -} - -impl Display for SourceOperation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Source(operation_id) => f.write_str(operation_id), - Self::Builtin { builtin } => write!(f, "builtin:{}", builtin), - } - } -} - #[derive( Debug, Clone, @@ -221,6 +324,62 @@ pub enum DiagramOperation { /// # Ok::<_, serde_json::Error>(()) Node(NodeSchema), + /// Connect the request to a registered section. 
+ /// + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "start": "section_op", + /// "ops": { + /// "section_op": { + /// "type": "section", + /// "builder": "my_section_builder", + /// "connect": { + /// "my_section_output": { "builtin": "terminate" } + /// } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + /// + /// Custom sections can also be created via templates + /// ``` + /// # bevy_impulse::Diagram::from_json_str(r#" + /// { + /// "version": "0.1.0", + /// "templates": { + /// "my_template": { + /// "inputs": ["section_input"], + /// "outputs": ["section_output"], + /// "buffers": [], + /// "ops": { + /// "section_input": { + /// "type": "node", + /// "builder": "my_node", + /// "next": "section_output" + /// } + /// } + /// } + /// }, + /// "start": "section_op", + /// "ops": { + /// "section_op": { + /// "type": "section", + /// "template": "my_template", + /// "connect": { + /// "section_output": { "builtin": "terminate" } + /// } + /// } + /// } + /// } + /// # "#)?; + /// # Ok::<_, serde_json::Error>(()) + /// ``` + Section(SectionSchema), + /// If the request is cloneable, clone it into multiple responses that can /// each be sent to a different operation. The `next` property is an array. 
/// @@ -318,6 +477,7 @@ pub enum DiagramOperation { /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) + /// ``` Unzip(UnzipSchema), /// If the request is a [`Result`], send the output message down an @@ -344,6 +504,7 @@ pub enum DiagramOperation { /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) + /// ``` ForkResult(ForkResultSchema), /// If the input message is a list-like or map-like object, split it into @@ -430,9 +591,9 @@ pub enum DiagramOperation { /// # bevy_impulse::Diagram::from_json_str(r#" /// { /// "version": "0.1.0", - /// "start": "fork_measuring", + /// "start": "begin_measuring", /// "ops": { - /// "fork_measuring": { + /// "begin_measuring": { /// "type": "fork_clone", /// "next": ["localize", "imu"] /// }, @@ -642,6 +803,7 @@ pub enum DiagramOperation { /// } /// # "#)?; /// # Ok::<_, serde_json::Error>(()) + /// ``` BufferAccess(BufferAccessSchema), /// Listen on a buffer. @@ -682,7 +844,7 @@ pub enum DiagramOperation { impl BuildDiagramOperation for DiagramOperation { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { @@ -694,6 +856,7 @@ impl BuildDiagramOperation for DiagramOperation { Self::Join(op) => op.build_diagram_operation(id, builder, ctx), Self::Listen(op) => op.build_diagram_operation(id, builder, ctx), Self::Node(op) => op.build_diagram_operation(id, builder, ctx), + Self::Section(op) => op.build_diagram_operation(id, builder, ctx), Self::SerializedJoin(op) => op.build_diagram_operation(id, builder, ctx), Self::Split(op) => op.build_diagram_operation(id, builder, ctx), Self::Transform(op) => op.build_diagram_operation(id, builder, ctx), @@ -735,7 +898,7 @@ where o.to_string().serialize(ser) } -#[derive(JsonSchema, Serialize, Deserialize)] +#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Diagram { /// Version of the diagram, should always be `0.1.0`. 
@@ -746,8 +909,11 @@ pub struct Diagram { #[schemars(schema_with = "schema_with_string")] version: semver::Version, + #[serde(default)] + pub templates: Templates, + /// Indicates where the workflow should start running. - start: NextOperation, + pub start: NextOperation, /// To simplify diagram definitions, the diagram workflow builder will /// sometimes insert implicit operations into the workflow, such as implicit @@ -757,13 +923,24 @@ pub struct Diagram { /// If left unspecified, an implicit error will cause the entire workflow to /// be cancelled. #[serde(default)] - on_implicit_error: Option, + pub on_implicit_error: Option, /// Operations that define the workflow - ops: HashMap, + pub ops: Operations, } impl Diagram { + /// Begin creating a new diagram + pub fn new(start: NextOperation) -> Self { + Self { + version: semver::Version::parse(CURRENT_DIAGRAM_VERSION).unwrap(), + start, + templates: Default::default(), + on_implicit_error: Default::default(), + ops: Default::default(), + } + } + /// Spawns a workflow from this diagram. /// /// # Examples @@ -898,10 +1075,157 @@ impl Diagram { serde_json::from_reader(r) } - fn get_op(&self, op_id: &OperationId) -> Result<&DiagramOperation, DiagramErrorCode> { - self.ops - .get(op_id) - .ok_or_else(|| DiagramErrorCode::OperationNotFound(op_id.clone())) + /// Make sure all operation names are valid, e.g. no reserved words such as + /// `builtin` are being used. + pub fn validate_operation_names(&self) -> Result<(), DiagramErrorCode> { + self.ops.validate_operation_names()?; + self.templates.validate_operation_names()?; + Ok(()) + } + + /// Validate the templates that are being used within the `ops` section, or + /// recursively within any templates used by the `ops` section. Any unused + /// templates will not be validated. 
+ pub fn validate_template_usage(&self) -> Result<(), DiagramErrorCode> { + for op in self.ops.values() { + match op { + DiagramOperation::Section(section) => match §ion.provider { + SectionProvider::Template(template) => { + self.templates.validate_template(template)?; + } + _ => continue, + }, + _ => continue, + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone, Default, JsonSchema, Serialize, Deserialize, Deref, DerefMut)] +#[serde(transparent, rename_all = "snake_case")] +pub struct Operations(HashMap); + +impl Operations { + /// Get an operation from this map, or a diagram error code if the operation + /// is not available. + pub fn get_op(&self, op_id: &Arc) -> Result<&DiagramOperation, DiagramErrorCode> { + self.get(op_id) + .ok_or_else(|| DiagramErrorCode::operation_name_not_found(op_id.clone())) + } + + pub fn validate_operation_names(&self) -> Result<(), DiagramErrorCode> { + validate_operation_names(&self.0) + } +} + +#[derive(Debug, Clone, Default, JsonSchema, Serialize, Deserialize, Deref, DerefMut)] +#[serde(transparent, rename_all = "snake_case")] +pub struct Templates(HashMap); + +impl Templates { + /// Get a template from this map, or a diagram error code if the template is + /// not available. + pub fn get_template( + &self, + template_id: &OperationName, + ) -> Result<&SectionTemplate, DiagramErrorCode> { + self.get(template_id) + .ok_or_else(|| DiagramErrorCode::TemplateNotFound(template_id.clone())) + } + + pub fn validate_operation_names(&self) -> Result<(), DiagramErrorCode> { + for (name, template) in &self.0 { + validate_operation_name(name)?; + validate_operation_names(&template.ops)?; + // TODO(@mxgrey): Validate correctness of input, output, and buffer mapping + } + + Ok(()) + } + + /// Check for potential issues in one of the templates, e.g. a circular + /// dependency with other templates. 
+ pub fn validate_template(&self, template_id: &OperationName) -> Result<(), DiagramErrorCode> { + check_circular_template_dependency(template_id, &self.0)?; + Ok(()) + } +} + +fn validate_operation_names( + ops: &HashMap, +) -> Result<(), DiagramErrorCode> { + for name in ops.keys() { + validate_operation_name(name)?; + } + + Ok(()) +} + +fn validate_operation_name(name: &str) -> Result<(), DiagramErrorCode> { + for reserved in &RESERVED_OPERATION_NAMES { + if name == *reserved { + return Err(DiagramErrorCode::InvalidUseOfReservedName(*reserved)); + } + } + + Ok(()) +} + +fn check_circular_template_dependency( + start_from: &OperationName, + templates: &HashMap, +) -> Result<(), DiagramErrorCode> { + let mut queue = Vec::new(); + queue.push(TemplateStack::new(start_from)); + + while let Some(top) = queue.pop() { + let Some(template) = templates.get(&top.next) else { + return Err(DiagramErrorCode::UnknownTemplate(top.next)); + }; + + for op in template.ops.0.values() { + match op { + DiagramOperation::Section(section) => match §ion.provider { + SectionProvider::Template(template) => { + queue.push(top.child(template)?); + } + _ => continue, + }, + _ => continue, + } + } + } + + Ok(()) +} + +struct TemplateStack { + used: HashSet, + next: OperationName, +} + +impl TemplateStack { + fn new(op: &OperationName) -> Self { + TemplateStack { + used: HashSet::from_iter([Arc::clone(op)]), + next: Arc::clone(op), + } + } + + fn child(&self, next: &OperationName) -> Result { + let mut used = self.used.clone(); + if !used.insert(Arc::clone(next)) { + return Err(DiagramErrorCode::CircularTemplateDependency( + used.into_iter().collect(), + )); + } + + Ok(Self { + used, + next: Arc::clone(next), + }) } } @@ -915,9 +1239,11 @@ pub struct DiagramError { } impl DiagramError { - pub fn in_operation(op_id: OperationId, code: DiagramErrorCode) -> Self { + pub fn in_operation(op_id: impl Into, code: DiagramErrorCode) -> Self { Self { - context: DiagramErrorContext { op_id: Some(op_id) 
}, + context: DiagramErrorContext { + op_id: Some(op_id.into()), + }, code, } } @@ -925,7 +1251,7 @@ impl DiagramError { #[derive(Debug)] pub struct DiagramErrorContext { - op_id: Option, + op_id: Option, } impl Display for DiagramErrorContext { @@ -943,7 +1269,10 @@ pub enum DiagramErrorCode { BuilderNotFound(BuilderId), #[error("operation [{0}] not found")] - OperationNotFound(OperationId), + OperationNotFound(NextOperation), + + #[error("section template [{0}] does not exist")] + TemplateNotFound(OperationName), #[error("type mismatch, source {source_type}, target {target_type}")] TypeMismatch { @@ -951,11 +1280,11 @@ pub enum DiagramErrorCode { target_type: TypeInfo, }, - #[error("Operation [{0}] attempted to instantiate multiple inputs.")] - MultipleInputsCreated(OperationId), + #[error("Operation [{0}] attempted to instantiate a duplicate of itself.")] + DuplicateInputsCreated(OperationRef), - #[error("Operation [{0}] attempted to instantiate multiple buffers.")] - MultipleBuffersCreated(OperationId), + #[error("Operation [{0}] attempted to instantiate a duplicate buffer.")] + DuplicateBuffersCreated(OperationRef), #[error("Missing a connection to start or terminate. A workflow cannot run with a valid connection to each.")] MissingStartOrTerminate, @@ -966,11 +1295,11 @@ pub enum DiagramErrorCode { #[error("Deserialization was disabled for the target message type.")] NotDeserializable(TypeInfo), - #[error("Cloning was disabled for the target message type.")] - NotCloneable, + #[error("Cloning was disabled for the target message type. Type: {0}")] + NotCloneable(TypeInfo), - #[error("The number of unzip slots in response does not match the number of inputs.")] - NotUnzippable, + #[error("The target message type does not support unzipping. 
Type: {0}")] + NotUnzippable(TypeInfo), #[error("The number of elements in the unzip expected by the diagram [{expected}] is different from the real number [{actual}]")] UnzipMismatch { @@ -979,25 +1308,31 @@ pub enum DiagramErrorCode { elements: Vec, }, - #[error("Call .with_fork_result() on your node to be able to fork its Result-type output.")] - CannotForkResult, + #[error("Call .with_fork_result() on your node to be able to fork its Result-type output. Type: {0}")] + CannotForkResult(TypeInfo), - #[error("Response cannot be split. Make sure to use .with_split() when building the node.")] - NotSplittable, + #[error("Response cannot be split. Make sure to use .with_split() when building the node. Type: {0}")] + NotSplittable(TypeInfo), #[error( - "Message cannot be joined. Make sure to use .with_join() when building the target node." + "Message cannot be joined. Make sure to use .with_join() when building the target node. Type: {0}" )] - NotJoinable, + NotJoinable(TypeInfo), #[error("Empty join is not allowed.")] EmptyJoin, - #[error("Target type cannot be determined from [next] and [target_node] is not provided.")] + #[error("Target type cannot be determined from [next] and [target_node] is not provided or cannot be inferred from.")] UnknownTarget, - #[error("There was an attempt to access an unknown operation: [{0}]")] - UnknownOperation(NextOperation), + #[error("There was an attempt to connect to an unknown operation: [{0}]")] + UnknownOperation(OperationRef), + + #[error("There was an attempt to use an unknown section template: [{0}]")] + UnknownTemplate(OperationName), + + #[error("There was an attempt to use an operation in an invalid way: [{0}]")] + InvalidOperation(OperationRef), #[error(transparent)] CannotTransform(#[from] TransformError), @@ -1014,12 +1349,12 @@ pub enum DiagramErrorCode { #[error(transparent)] IncompatibleBuffers(#[from] IncompatibleLayout), + #[error(transparent)] + SectionError(#[from] SectionError), + #[error("one or more 
operation is missing inputs")] IncompleteDiagram, - #[error("operation type only accept single input")] - OnlySingleInput, - #[error(transparent)] JsonError(#[from] serde_json::Error), @@ -1032,11 +1367,32 @@ pub enum DiagramErrorCode { #[error("The build of the workflow came to a halt, reasons:\n{reasons:?}")] BuildHalted { /// Reasons that operations were unable to make progress building - reasons: HashMap>, + reasons: HashMap>, }, #[error("The workflow building process has had an excessive number of iterations. This may indicate an implementation bug or an extraordinarily complex diagram.")] ExcessiveIterations, + + #[error("An operation was given a reserved name [{0}]")] + InvalidUseOfReservedName(&'static str), + + #[error("an error happened while building a nested diagram: {0}")] + NestedError(Box), + + #[error("A circular redirection exists between operations: {}", format_list(&.0))] + CircularRedirect(Vec), + + #[error("A circular dependency exists between templates: {}", format_list(&.0))] + CircularTemplateDependency(Vec), +} + +fn format_list(list: &[T]) -> String { + let mut output = String::new(); + for op in list { + output += &format!("[{op}]"); + } + + output } impl From for DiagramError { @@ -1048,6 +1404,16 @@ impl From for DiagramError { } } +impl DiagramErrorCode { + pub fn operation_name_not_found(name: OperationName) -> Self { + DiagramErrorCode::OperationNotFound(NextOperation::Name(name)) + } + + pub fn in_operation(self, op_id: OperationRef) -> DiagramError { + DiagramError::in_operation(op_id, self) + } +} + #[cfg(test)] mod testing; @@ -1078,7 +1444,7 @@ mod tests { .unwrap(); let err = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + .spawn_and_run::<_, JsonMessage>(&diagram, JsonMessage::from(4)) .unwrap_err(); assert!(fixture.context.no_unhandled_errors()); assert!(matches!( @@ -1196,12 +1562,12 @@ mod tests { .unwrap(); let err = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + .spawn_and_run::<_, 
JsonMessage>(&diagram, JsonMessage::from(4)) .unwrap_err(); assert!(fixture.context.no_unhandled_errors()); assert!(matches!( *err.downcast_ref::().unwrap().cause, - CancellationCause::Unreachable(_) + CancellationCause::Unreachable(_), )); } @@ -1231,8 +1597,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 36); @@ -1249,8 +1615,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 4); @@ -1275,10 +1641,10 @@ mod tests { } "#; - let result = fixture + let result: JsonMessage = fixture .spawn_and_run( &Diagram::from_json_str(json_str).unwrap(), - serde_json::Value::from(4), + JsonMessage::from(4), ) .unwrap(); assert!(fixture.context.no_unhandled_errors()); @@ -1316,10 +1682,39 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); } + + #[test] + fn test_unknown_operation_detection() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "op1", + "ops": { + "op1": { + "type": "node", + "builder": "multiply3_5", + "next": "clone", + }, + "clone": { + "type": "fork_clone", + "next": [ + "unknown", + { "builtin": "terminate" }, + ], + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + + assert!(matches!(result.code, DiagramErrorCode::UnknownOperation(_),)); + } } diff --git 
a/src/diagram/buffer_schema.rs b/src/diagram/buffer_schema.rs index 79e7ec74..048c6998 100644 --- a/src/diagram/buffer_schema.rs +++ b/src/diagram/buffer_schema.rs @@ -21,8 +21,8 @@ use serde::{Deserialize, Serialize}; use crate::{Accessor, BufferSettings, Builder, JsonMessage}; use super::{ - type_info::TypeInfo, BufferInputs, BuildDiagramOperation, BuildStatus, DiagramContext, - DiagramErrorCode, NextOperation, OperationId, + type_info::TypeInfo, BufferSelection, BuildDiagramOperation, BuildStatus, DiagramContext, + DiagramErrorCode, NextOperation, OperationName, }; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -37,14 +37,14 @@ pub struct BufferSchema { impl BuildDiagramOperation for BufferSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { let message_info = if self.serialize.is_some_and(|v| v) { TypeInfo::of::() } else { - let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + let Some(inferred_type) = ctx.infer_input_type_into_target(id)? else { // There are no outputs ready for this target, so we can't do // anything yet. The builder should try again later. @@ -54,7 +54,7 @@ impl BuildDiagramOperation for BufferSchema { return Ok(BuildStatus::defer("waiting for an input")); }; - *inferred_type + inferred_type }; let buffer = @@ -72,31 +72,33 @@ pub struct BufferAccessSchema { pub(super) next: NextOperation, /// Map of buffer keys and buffers. - pub(super) buffers: BufferInputs, - - /// The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation. 
- pub(super) target_node: Option, + pub(super) buffers: BufferSelection, } impl BuildDiagramOperation for BufferAccessSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { + let Some(target_type) = ctx.infer_input_type_into_target(&self.next)? else { + return Ok(BuildStatus::defer( + "waiting to find out target message type", + )); + }; + let buffer_map = match ctx.create_buffer_map(&self.buffers) { Ok(buffer_map) => buffer_map, Err(reason) => return Ok(BuildStatus::defer(reason)), }; - let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; let node = ctx .registry .messages .with_buffer_access(&target_type, &buffer_map, builder)?; ctx.set_input_for_target(id, node.input)?; - ctx.add_output_into_target(self.next.clone(), node.output); + ctx.add_output_into_target(&self.next, node.output); Ok(BuildStatus::Finished) } } @@ -121,30 +123,35 @@ pub struct ListenSchema { pub(super) next: NextOperation, /// Map of buffer keys and buffers. - pub(super) buffers: BufferInputs, + pub(super) buffers: BufferSelection, /// The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation. - pub(super) target_node: Option, + pub(super) target_node: Option, } impl BuildDiagramOperation for ListenSchema { fn build_diagram_operation( &self, - _: &OperationId, + _: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { + let Some(target_type) = ctx.infer_input_type_into_target(&self.next)? 
else { + return Ok(BuildStatus::defer( + "waiting to find out target message type", + )); + }; + let buffer_map = match ctx.create_buffer_map(&self.buffers) { Ok(buffer_map) => buffer_map, Err(reason) => return Ok(BuildStatus::defer(reason)), }; - let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; let output = ctx .registry .messages .listen(&target_type, &buffer_map, builder)?; - ctx.add_output_into_target(self.next.clone(), output); + ctx.add_output_into_target(&self.next, output); Ok(BuildStatus::Finished) } } @@ -430,9 +437,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "hello world"); } @@ -490,9 +495,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 1); } @@ -541,8 +544,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::String("hello".to_owned())) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::String("hello".to_owned())) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 11); @@ -592,8 +595,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::String("hello".to_owned())) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::String("hello".to_owned())) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 11); @@ -648,7 +651,7 @@ mod tests { })) .unwrap(); - let result = fixture + let result: JsonMessage = fixture .spawn_and_run(&diagram, JsonMessage::Number(5_i64.into())) .unwrap(); 
assert!(fixture.context.no_unhandled_errors()); @@ -704,7 +707,7 @@ mod tests { })) .unwrap(); - let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 1); } @@ -732,7 +735,7 @@ mod tests { })) .unwrap(); - let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 1); } @@ -760,7 +763,7 @@ mod tests { })) .unwrap(); - let result = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 1); } @@ -875,7 +878,7 @@ mod tests { })) .unwrap(); - let result = fixture.spawn_and_run(&diagram, input).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, input).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, JsonMessage::Null); } @@ -969,7 +972,7 @@ mod tests { })) .unwrap(); - let result = fixture.spawn_and_run(&diagram, input).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, input).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, JsonMessage::Null); } diff --git a/src/diagram/fork_clone_schema.rs b/src/diagram/fork_clone_schema.rs index 3dda14d2..15e7eba0 100644 --- a/src/diagram/fork_clone_schema.rs +++ b/src/diagram/fork_clone_schema.rs @@ -22,7 +22,7 @@ use crate::{Builder, ForkCloneOutput}; use super::{ supported::*, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, - DynInputSlot, DynOutput, NextOperation, OperationId, + DynInputSlot, DynOutput, NextOperation, OperationName, TypeInfo, }; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -34,20 
+34,31 @@ pub struct ForkCloneSchema { impl BuildDiagramOperation for ForkCloneSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { - let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { - // There are no outputs ready for this target, so we can't do - // anything yet. The builder should try again later. - return Ok(BuildStatus::defer("waiting for an input")); + let inferred_type = 'inferred: { + match ctx.infer_input_type_into_target(id)? { + Some(inferred_type) => break 'inferred inferred_type, + None => { + for target in &self.next { + if let Some(inferred_type) = ctx.infer_input_type_into_target(target)? { + break 'inferred inferred_type; + } + } + + // There are no outputs or input slots ready for this target, + // so we can't do anything yet. The builder should try again later. + return Ok(BuildStatus::defer("waiting for an input")); + } + } }; - let fork = ctx.registry.messages.fork_clone(inferred_type, builder)?; + let fork = ctx.registry.messages.fork_clone(&inferred_type, builder)?; ctx.set_input_for_target(id, fork.input)?; for target in &self.next { - ctx.add_output_into_target(target.clone(), fork.outputs.clone_output(builder)); + ctx.add_output_into_target(target, fork.outputs.clone_output(builder)); } Ok(BuildStatus::Finished) @@ -60,11 +71,11 @@ pub trait PerformForkClone { fn perform_fork_clone(builder: &mut Builder) -> Result; } -impl PerformForkClone for NotSupported { +impl PerformForkClone for NotSupported { const CLONEABLE: bool = false; fn perform_fork_clone(_builder: &mut Builder) -> Result { - Err(DiagramErrorCode::NotCloneable) + Err(DiagramErrorCode::NotCloneable(TypeInfo::of::())) } } @@ -120,7 +131,7 @@ mod tests { use serde_json::json; use test_log::test; - use crate::{diagram::testing::DiagramTestFixture, Diagram}; + use crate::{diagram::testing::DiagramTestFixture, Diagram, JsonMessage}; use super::*; @@ -151,7 +162,7 @@ 
mod tests { .unwrap(); let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); assert!( - matches!(err.code, DiagramErrorCode::NotCloneable), + matches!(err.code, DiagramErrorCode::NotCloneable(_)), "{:?}", err ); @@ -183,8 +194,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 36); diff --git a/src/diagram/fork_result_schema.rs b/src/diagram/fork_result_schema.rs index 3d6c975d..54c3cdd7 100644 --- a/src/diagram/fork_result_schema.rs +++ b/src/diagram/fork_result_schema.rs @@ -23,7 +23,7 @@ use crate::Builder; use super::{ supported::*, type_info::TypeInfo, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, DynInputSlot, DynOutput, MessageRegistration, MessageRegistry, NextOperation, - OperationId, PerformForkClone, SerializeMessage, + OperationName, PerformForkClone, SerializeMessage, }; pub struct DynForkResult { @@ -42,20 +42,24 @@ pub struct ForkResultSchema { impl BuildDiagramOperation for ForkResultSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { - let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + let Some(inferred_type) = ctx.infer_input_type_into_target(id)? else { + // TODO(@mxgrey): For each result type we can register a tuple of + // (T, E) for the Ok and Err types as a key so we could infer the + // operation type using the expected types for ok and err. + // There are no outputs ready for this target, so we can't do // anything yet. The builder should try again later. 
return Ok(BuildStatus::defer("waiting for an input")); }; - let fork = ctx.registry.messages.fork_result(inferred_type, builder)?; + let fork = ctx.registry.messages.fork_result(&inferred_type, builder)?; ctx.set_input_for_target(id, fork.input)?; - ctx.add_output_into_target(self.ok.clone(), fork.ok); - ctx.add_output_into_target(self.err.clone(), fork.err); + ctx.add_output_into_target(&self.ok, fork.ok); + ctx.add_output_into_target(&self.err, fork.err); Ok(BuildStatus::Finished) } } @@ -105,7 +109,9 @@ mod tests { use serde_json::json; use test_log::test; - use crate::{diagram::testing::DiagramTestFixture, Builder, Diagram, NodeBuilderOptions}; + use crate::{ + diagram::testing::DiagramTestFixture, Builder, Diagram, JsonMessage, NodeBuilderOptions, + }; #[test] fn test_fork_result() { @@ -164,14 +170,14 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "even"); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(3)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(3)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "odd"); diff --git a/src/diagram/join_schema.rs b/src/diagram/join_schema.rs index d312cd70..d0d86109 100644 --- a/src/diagram/join_schema.rs +++ b/src/diagram/join_schema.rs @@ -22,8 +22,8 @@ use smallvec::SmallVec; use crate::{Builder, JsonMessage}; use super::{ - BufferInputs, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, - NextOperation, OperationId, + BufferSelection, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, + NextOperation, OperationName, }; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -32,16 +32,13 @@ pub struct JoinSchema { pub(super) next: NextOperation, /// Map of buffer 
keys and buffers. - pub(super) buffers: BufferInputs, - - /// The id of an operation that this operation is for. The id must be a `node` operation. Optional if `next` is a node operation. - pub(super) target_node: Option, + pub(super) buffers: BufferSelection, } impl BuildDiagramOperation for JoinSchema { fn build_diagram_operation( &self, - _: &OperationId, + _: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { @@ -49,18 +46,22 @@ impl BuildDiagramOperation for JoinSchema { return Err(DiagramErrorCode::EmptyJoin); } + let Some(target_type) = ctx.infer_input_type_into_target(&self.next)? else { + return Ok(BuildStatus::defer( + "waiting to find out target message type", + )); + }; + let buffer_map = match ctx.create_buffer_map(&self.buffers) { Ok(buffer_map) => buffer_map, Err(reason) => return Ok(BuildStatus::defer(reason)), }; - let target_type = ctx.get_node_request_type(self.target_node.as_ref(), &self.next)?; - let output = ctx .registry .messages .join(&target_type, &buffer_map, builder)?; - ctx.add_output_into_target(self.next.clone(), output); + ctx.add_output_into_target(&self.next, output); Ok(BuildStatus::Finished) } } @@ -71,13 +72,13 @@ pub struct SerializedJoinSchema { pub(super) next: NextOperation, /// Map of buffer keys and buffers. 
- pub(super) buffers: BufferInputs, + pub(super) buffers: BufferSelection, } impl BuildDiagramOperation for SerializedJoinSchema { fn build_diagram_operation( &self, - _: &OperationId, + _: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { @@ -91,7 +92,7 @@ impl BuildDiagramOperation for SerializedJoinSchema { }; let output = builder.try_join::(&buffer_map)?.output(); - ctx.add_output_into_target(self.next.clone(), output.into()); + ctx.add_output_into_target(&self.next, output.into()); Ok(BuildStatus::Finished) } @@ -110,8 +111,8 @@ mod tests { use super::*; use crate::{ - diagram::testing::DiagramTestFixture, Diagram, DiagramElementRegistry, DiagramError, - DiagramErrorCode, NodeBuilderOptions, + diagram::testing::DiagramTestFixture, Diagram, DiagramElementRegistry, DiagramErrorCode, + NodeBuilderOptions, }; fn foo(_: serde_json::Value) -> String { @@ -223,9 +224,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "foobar"); } @@ -277,16 +276,14 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "foobar"); } #[test] - /// when `target_node` is not given and next is not a node - fn test_join_infer_type_fail() { + /// join should be able to infer its output type when connected to terminate + fn test_join_infer_from_terminate() { let mut fixture = DiagramTestFixture::new(); register_join_nodes(&mut fixture.registry); @@ -330,12 +327,18 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap_err(); - 
assert!(fixture.context.no_unhandled_errors()); - let err_code = &result.downcast_ref::().unwrap().code; - assert!(matches!(err_code, DiagramErrorCode::UnknownTarget,)); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); + let expectation = serde_json::Value::Object(serde_json::Map::from_iter([ + ( + "bar".to_string(), + serde_json::Value::String("bar".to_string()), + ), + ( + "foo".to_string(), + serde_json::Value::String("foo".to_string()), + ), + ])); + assert_eq!(result, expectation); } #[test] @@ -382,9 +385,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, "foobar"); } @@ -465,9 +466,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result["foo"], "foo"); assert_eq!(result["bar"], "bar"); @@ -522,9 +521,7 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::Null) - .unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, JsonMessage::Null).unwrap(); assert!(fixture.context.no_unhandled_errors()); let object = result.as_object().unwrap(); assert_eq!(object["foobar_1"].as_object().unwrap()["foo"], "foo_1"); diff --git a/src/diagram/node.rs b/src/diagram/node.rs new file mode 100644 index 00000000..95b993e9 --- /dev/null +++ b/src/diagram/node.rs @@ -0,0 +1,43 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::Builder; + +use super::{ + workflow_builder::dyn_connect, BuilderId, DiagramElementRegistry, DiagramErrorCode, + NextOperation, WorkflowBuilder, +}; + +#[derive(Clone, Debug, Serialize, 
Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct NodeOp { + pub(super) builder: BuilderId, + #[serde(default)] + pub(super) config: serde_json::Value, + pub(super) next: NextOperation, +} + +impl NodeOp { + pub(super) fn add_vertices<'a>( + &'a self, + builder: &mut Builder, + wf_builder: &mut WorkflowBuilder<'a>, + op_id: String, + registry: &DiagramElementRegistry, + ) -> Result<(), DiagramErrorCode> { + let reg = registry.get_node_registration(&self.builder)?; + let node = reg.create_node(builder, self.config.clone())?; + + let mut edge_builder = + wf_builder.add_vertex(op_id.clone(), move |vertex, builder, registry, _| { + for edge in &vertex.in_edges { + let output = edge.take_output(); + dyn_connect(builder, output, node.input.into(), ®istry.messages)?; + } + Ok(true) + }); + edge_builder.add_output_edge(self.next.clone(), Some(node.output)); + + Ok(()) + } +} diff --git a/src/diagram/node_schema.rs b/src/diagram/node_schema.rs index c45d6598..09214b45 100644 --- a/src/diagram/node_schema.rs +++ b/src/diagram/node_schema.rs @@ -22,7 +22,7 @@ use crate::Builder; use super::{ BuildDiagramOperation, BuildStatus, BuilderId, DiagramContext, DiagramErrorCode, NextOperation, - OperationId, + OperationName, }; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -37,7 +37,7 @@ pub struct NodeSchema { impl BuildDiagramOperation for NodeSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { @@ -45,7 +45,7 @@ impl BuildDiagramOperation for NodeSchema { let node = node_registration.create_node(builder, self.config.clone())?; ctx.set_input_for_target(id, node.input.into())?; - ctx.add_output_into_target(self.next.clone(), node.output); + ctx.add_output_into_target(&self.next, node.output); Ok(BuildStatus::Finished) } } diff --git a/src/diagram/registration.rs b/src/diagram/registration.rs index 7859d92f..b52e1aa6 100644 --- 
a/src/diagram/registration.rs +++ b/src/diagram/registration.rs @@ -22,6 +22,7 @@ use std::{ collections::HashMap, fmt::Debug, marker::PhantomData, + sync::Arc, }; use crate::{ @@ -46,8 +47,8 @@ use super::{ buffer_schema::BufferAccessRequest, fork_clone_schema::PerformForkClone, fork_result_schema::RegisterForkResult, register_json, supported::*, type_info::TypeInfo, unzip_schema::PerformUnzip, BuilderId, DeserializeMessage, DiagramErrorCode, DynForkClone, - DynForkResult, DynSplit, DynType, JsonRegistration, RegisterJson, RegisterSplit, - SerializeMessage, SplitSchema, TransformError, + DynForkResult, DynSplit, DynType, JsonRegistration, RegisterJson, RegisterSplit, Section, + SectionMetadata, SectionMetadataProvider, SerializeMessage, SplitSchema, TransformError, }; /// A type erased [`crate::InputSlot`] @@ -97,6 +98,7 @@ impl From for DynInputSlot { } /// A type erased [`crate::Output`] +#[derive(Debug)] pub struct DynOutput { scope: Entity, target: Entity, @@ -160,16 +162,6 @@ impl DynOutput { } } -impl Debug for DynOutput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DynOutput") - .field("scope", &self.scope) - .field("target", &self.target) - .field("type_info", &self.message_info) - .finish() - } -} - impl From> for DynOutput where T: Send + Sync + 'static + Any, @@ -182,6 +174,7 @@ where } } } + /// A type erased [`bevy_impulse::Node`] pub struct DynNode { pub input: DynInputSlot, @@ -217,8 +210,9 @@ where #[derive(Serialize)] pub struct NodeRegistration { + #[serde(rename = "$key$")] pub(super) id: BuilderId, - pub(super) name: String, + pub(super) name: Arc, pub(super) request: TypeInfo, pub(super) response: TypeInfo, pub(super) config_schema: Schema, @@ -402,7 +396,10 @@ impl<'a, Message> MessageRegistrationBuilder<'a, Message> where Message: Send + Sync + 'static + Any, { - fn new(registry: &'a mut MessageRegistry) -> Self { + pub fn new(registry: &'a mut MessageRegistry) -> Self { + // Any message type can 
be joined into a Vec + registry.register_join::>(); + Self { data: registry, _ignore: Default::default(), @@ -707,9 +704,68 @@ pub trait IntoNodeRegistration { ) -> NodeRegistration; } +type CreateSectionFn = dyn FnMut(&mut Builder, serde_json::Value) -> Box; + +#[derive(Serialize)] +pub struct SectionRegistration { + pub(super) name: BuilderId, + pub(super) metadata: SectionMetadata, + pub(super) config_schema: Schema, + + #[serde(skip)] + create_section_impl: RefCell>, +} + +impl SectionRegistration { + pub(super) fn create_section( + &self, + builder: &mut Builder, + config: serde_json::Value, + ) -> Result, DiagramErrorCode> { + let section = (self.create_section_impl.borrow_mut())(builder, config); + Ok(section) + } +} + +pub trait IntoSectionRegistration +where + SectionT: Section, +{ + fn into_section_registration( + self, + name: BuilderId, + schema_generator: &mut SchemaGenerator, + ) -> SectionRegistration; +} + +impl IntoSectionRegistration for F +where + F: 'static + FnMut(&mut Builder, Config) -> SectionT, + SectionT: 'static + Section + SectionMetadataProvider, + Config: DeserializeOwned + JsonSchema, +{ + fn into_section_registration( + mut self, + name: BuilderId, + schema_generator: &mut SchemaGenerator, + ) -> SectionRegistration { + SectionRegistration { + name, + metadata: SectionT::metadata().clone(), + config_schema: schema_generator.subschema_for::<()>(), + create_section_impl: RefCell::new(Box::new(move |builder, config| { + let section = self(builder, serde_json::from_value::(config).unwrap()); + Box::new(section) + })), + } + } +} + #[derive(Serialize)] pub struct DiagramElementRegistry { pub(super) nodes: HashMap, + pub(super) sections: HashMap, + #[serde(flatten)] pub(super) messages: MessageRegistry, } @@ -871,7 +927,7 @@ impl MessageRegistry { /// Register a deserialize function if not already registered, returns true if the new /// function is registered. 
- pub fn register_deserialize(&mut self) + pub(super) fn register_deserialize(&mut self) where T: Send + Sync + 'static + Any, Deserializer: DeserializeMessage, @@ -916,7 +972,7 @@ impl MessageRegistry { /// Register a serialize function if not already registered, returns true if the new /// function is registered. - pub fn register_serialize(&mut self) + pub(super) fn register_serialize(&mut self) where T: Send + Sync + 'static + Any, Serializer: SerializeMessage, @@ -932,13 +988,13 @@ impl MessageRegistry { self.messages .get(message_info) .and_then(|reg| reg.operations.fork_clone_impl.as_ref()) - .ok_or(DiagramErrorCode::NotCloneable) + .ok_or(DiagramErrorCode::NotCloneable(*message_info)) .and_then(|f| f(builder)) } /// Register a fork_clone function if not already registered, returns true if the new /// function is registered. - pub fn register_fork_clone(&mut self) -> bool + pub(super) fn register_fork_clone(&mut self) -> bool where T: Send + Sync + 'static + Any, F: PerformForkClone, @@ -965,7 +1021,7 @@ impl MessageRegistry { .get(message_info) .and_then(|reg| reg.operations.unzip_impl.as_ref()) .map(|unzip| -> &'a (dyn PerformUnzip) { unzip.as_ref() }) - .ok_or(DiagramErrorCode::NotUnzippable) + .ok_or(DiagramErrorCode::NotUnzippable(*message_info)) } /// Register a unzip function if not already registered, returns true if the new @@ -1001,7 +1057,7 @@ impl MessageRegistry { self.messages .get(message_info) .and_then(|reg| reg.operations.fork_result_impl.as_ref()) - .ok_or(DiagramErrorCode::CannotForkResult) + .ok_or(DiagramErrorCode::CannotForkResult(*message_info)) .and_then(|f| f(builder)) } @@ -1023,7 +1079,7 @@ impl MessageRegistry { self.messages .get(message_info) .and_then(|reg| reg.operations.split_impl.as_ref()) - .ok_or(DiagramErrorCode::NotSplittable) + .ok_or(DiagramErrorCode::NotSplittable(*message_info)) .and_then(|f| f(split_op, builder)) } @@ -1072,7 +1128,7 @@ impl MessageRegistry { self.messages .get(joinable) .and_then(|reg| 
reg.operations.join_impl.as_ref()) - .ok_or_else(|| DiagramErrorCode::NotJoinable) + .ok_or_else(|| DiagramErrorCode::NotJoinable(*joinable)) .and_then(|f| f(buffers, builder)) } @@ -1222,6 +1278,7 @@ impl Default for DiagramElementRegistry { let mut registry = DiagramElementRegistry { nodes: Default::default(), + sections: Default::default(), messages: MessageRegistry::new(), }; @@ -1242,6 +1299,7 @@ impl DiagramElementRegistry { JsonBuffer::register_for::<()>(); DiagramElementRegistry { nodes: Default::default(), + sections: Default::default(), messages: MessageRegistry::new(), } } @@ -1302,6 +1360,29 @@ impl DiagramElementRegistry { self.opt_out().register_message() } + /// Register a section builder with the specified common operations. + /// + /// # Arguments + /// + /// * `id` - Id of the builder, this must be unique. + /// * `name` - Friendly name for the builder, this is only used for display purposes. + /// * `f` - The section builder to register. + pub fn register_section_builder( + &mut self, + options: SectionBuilderOptions, + section_builder: SectionBuilder, + ) where + SectionBuilder: IntoSectionRegistration, + SectionT: Section, + { + let reg = section_builder.into_section_registration( + options.name.unwrap_or_else(|| options.id.clone()), + &mut self.messages.schema_generator, + ); + self.sections.insert(options.id, reg); + SectionT::on_register(self); + } + /// In some cases the common operations of deserialization, serialization, /// and cloning cannot be performed for the input or output message of a node. 
/// When that happens you can still register your node builder by calling @@ -1362,7 +1443,19 @@ impl DiagramElementRegistry { let k = id.borrow(); self.nodes .get(k) - .ok_or(DiagramErrorCode::BuilderNotFound(k.to_string())) + .ok_or(DiagramErrorCode::BuilderNotFound(k.to_string().into())) + } + + pub fn get_section_registration( + &self, + id: &Q, + ) -> Result<&SectionRegistration, DiagramErrorCode> + where + Q: Borrow + ?Sized, + { + self.sections + .get(id.borrow()) + .ok_or_else(|| DiagramErrorCode::BuilderNotFound(id.borrow().to_string().into())) } pub fn get_message_registration(&self) -> Option<&MessageRegistration> @@ -1409,19 +1502,39 @@ impl DiagramElementRegistry { #[non_exhaustive] pub struct NodeBuilderOptions { pub id: BuilderId, - pub name: Option, + pub name: Option, } impl NodeBuilderOptions { pub fn new(id: impl ToString) -> Self { Self { - id: id.to_string(), + id: id.to_string().into(), + name: None, + } + } + + pub fn with_name(mut self, name: impl ToString) -> Self { + self.name = Some(name.to_string().into()); + self + } +} + +#[non_exhaustive] +pub struct SectionBuilderOptions { + pub id: BuilderId, + pub name: Option, +} + +impl SectionBuilderOptions { + pub fn new(id: impl ToString) -> Self { + Self { + id: id.to_string().into(), name: None, } } pub fn with_name(mut self, name: impl ToString) -> Self { - self.name = Some(name.to_string()); + self.name = Some(name.to_string().into()); self } } diff --git a/src/diagram/section_schema.rs b/src/diagram/section_schema.rs new file mode 100644 index 00000000..798bac5e --- /dev/null +++ b/src/diagram/section_schema.rs @@ -0,0 +1,1279 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use std::{collections::HashMap, sync::Arc}; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{ + AnyBuffer, AnyMessageBox, Buffer, Builder, InputSlot, JsonBuffer, JsonMessage, Output, +}; + +use super::{ + type_info::TypeInfo, BuildDiagramOperation, BuildStatus, BuilderId, DiagramContext, + DiagramElementRegistry, DiagramErrorCode, DynInputSlot, DynOutput, NamespacedOperation, + NextOperation, OperationName, Operations, +}; + +pub use bevy_impulse_derive::Section; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum SectionProvider { + Builder(BuilderId), + Template(OperationName), +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct SectionSchema { + #[serde(flatten)] + pub(super) provider: SectionProvider, + #[serde(default)] + pub(super) config: serde_json::Value, + #[serde(default)] + pub(super) connect: HashMap, NextOperation>, +} + +impl BuildDiagramOperation for SectionSchema { + fn build_diagram_operation( + &self, + id: &OperationName, + builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result { + match &self.provider { + SectionProvider::Builder(section_builder) => { + let section = ctx + .registry + .get_section_registration(section_builder)? + .create_section(builder, self.config.clone())? 
+ .into_slots(); + + for (op, input) in section.inputs { + ctx.set_input_for_target( + &NextOperation::Namespace(NamespacedOperation { + namespace: id.clone(), + operation: op.clone(), + }), + input, + )?; + } + + for expected_output in self.connect.keys() { + if !section.outputs.contains_key(expected_output) { + return Err(SectionError::UnknownOutput(Arc::clone(expected_output)).into()); + } + } + + for (target, output) in section.outputs { + if let Some(target) = self.connect.get(&target) { + ctx.add_output_into_target(target, output); + } + } + + for (op, buffer) in section.buffers { + ctx.set_buffer_for_operation( + &NextOperation::Namespace(NamespacedOperation { + namespace: id.clone(), + operation: op.clone(), + }), + buffer, + )?; + } + } + SectionProvider::Template(section_template) => { + let section = ctx.templates.get_template(section_template)?; + + for (child_id, op) in section.ops.iter() { + ctx.add_child_operation(id, child_id, op, §ion.ops); + } + + section + .inputs + .redirect(|op, next| ctx.redirect_to_child_input(id, op, next))?; + + section + .buffers + .redirect(|op, next| ctx.redirect_to_child_buffer(id, op, next))?; + + for expected_output in self.connect.keys() { + if !section.outputs.contains(expected_output) { + return Err(SectionError::UnknownOutput(Arc::clone(expected_output)).into()); + } + } + + for output in §ion.outputs { + if let Some(target) = self.connect.get(output) { + ctx.redirect_exposed_output_to_sibling(id, output, target)?; + } else { + ctx.redirect_exposed_output_to_sibling( + id, + output, + &NextOperation::dispose(), + )?; + } + } + } + } + + Ok(BuildStatus::Finished) + } +} + +#[derive(Serialize, Clone)] +pub struct SectionMetadata { + pub(super) inputs: HashMap, + pub(super) outputs: HashMap, + pub(super) buffers: HashMap, +} + +impl SectionMetadata { + pub fn new() -> Self { + Self { + inputs: HashMap::new(), + outputs: HashMap::new(), + buffers: HashMap::new(), + } + } +} + +pub trait SectionMetadataProvider { + fn 
metadata() -> &'static SectionMetadata; +} + +pub struct SectionSlots { + inputs: HashMap, + outputs: HashMap, + buffers: HashMap, +} + +impl SectionSlots { + pub fn new() -> Self { + Self { + inputs: HashMap::new(), + outputs: HashMap::new(), + buffers: HashMap::new(), + } + } +} + +pub trait Section { + fn into_slots(self: Box) -> SectionSlots; + + fn on_register(registry: &mut DiagramElementRegistry) + where + Self: Sized; +} + +pub trait SectionItem { + type MessageType; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str); + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots); +} + +impl SectionItem for InputSlot +where + T: Send + Sync + 'static, +{ + type MessageType = T; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str) { + metadata.inputs.insert( + key.into(), + SectionInput { + message_type: TypeInfo::of::(), + }, + ); + } + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots) { + slots.inputs.insert(key.into(), self.into()); + } +} + +impl SectionItem for Output +where + T: Send + Sync + 'static, +{ + type MessageType = T; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str) { + metadata.outputs.insert( + key.into(), + SectionOutput { + message_type: TypeInfo::of::(), + }, + ); + } + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots) { + slots.outputs.insert(key.into(), self.into()); + } +} + +impl SectionItem for Buffer +where + T: Send + Sync + 'static, +{ + type MessageType = T; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str) { + metadata.buffers.insert( + key.into(), + SectionBuffer { + item_type: Some(TypeInfo::of::()), + }, + ); + } + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots) { + slots.buffers.insert(key.into(), self.into()); + } +} + +impl SectionItem for AnyBuffer { + type MessageType = AnyMessageBox; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str) { + metadata + .buffers + .insert(key.into(), SectionBuffer { 
item_type: None }); + } + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots) { + slots.buffers.insert(key.into(), self); + } +} + +impl SectionItem for JsonBuffer { + type MessageType = JsonMessage; + + fn build_metadata(metadata: &mut SectionMetadata, key: &str) { + metadata + .buffers + .insert(key.into(), SectionBuffer { item_type: None }); + } + + fn insert_into_slots(self, key: &str, slots: &mut SectionSlots) { + slots.buffers.insert(key.into(), self.into()); + } +} + +#[derive(Serialize, Clone)] +pub struct SectionInput { + pub(super) message_type: TypeInfo, +} + +#[derive(Serialize, Clone)] +pub struct SectionOutput { + pub(super) message_type: TypeInfo, +} + +#[derive(Serialize, Clone)] +pub struct SectionBuffer { + pub(super) item_type: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct SectionTemplate { + /// These are the inputs that the section is exposing for its sibling + /// operations to send outputs into. + #[serde(default)] + pub inputs: InputRemapping, + /// These are the outputs that the section is exposing so you can connect + /// them into siblings of the section. + #[serde(default)] + pub outputs: Vec, + /// These are the buffers that the section is exposing for you to read, + /// write, join, or listen to. + #[serde(default)] + pub buffers: InputRemapping, + /// Operations that define the behavior of the section. + pub ops: Operations, +} + +/// This defines how sections remap their inner operations (inputs and buffers) +/// to expose them to operations that are siblings to the section. +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(untagged, rename_all = "snake_case")] +pub enum InputRemapping { + /// Do a simple 1:1 forwarding of the names listed in the array + Forward(Vec), + /// Rename an operation inside the section to expose it externally. 
The key + /// of the map is what siblings of the section can connect to, and the value + /// of the entry is the identifier of the input inside the section that is + /// being exposed. + /// + /// This allows a section to expose inputs and buffers that are provided + /// by inner sections. + Remap(HashMap), +} + +impl Default for InputRemapping { + fn default() -> Self { + Self::Forward(Vec::new()) + } +} + +impl InputRemapping { + pub fn get_inner(&self, op: &OperationName) -> Option { + match self { + Self::Forward(forward) => { + if forward.contains(op) { + return Some(NextOperation::Name(Arc::clone(op))); + } + } + Self::Remap(remap) => { + if let Some(next) = remap.get(op) { + return Some(next.clone()); + } + } + } + + None + } + + pub fn redirect( + &self, + mut f: impl FnMut(&OperationName, &NextOperation) -> Result<(), DiagramErrorCode>, + ) -> Result<(), DiagramErrorCode> { + match self { + Self::Forward(operations) => { + for op in operations { + f(op, &NextOperation::Name(Arc::clone(op)))?; + } + } + Self::Remap(remap) => { + for (op, next) in remap { + f(op, next)?; + } + } + } + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum SectionError { + #[error("operation has extra output [{0}] that is not in the section")] + UnknownOutput(OperationName), +} + +#[cfg(test)] +mod tests { + use bevy_ecs::system::In; + use serde_json::json; + + use crate::{ + diagram::testing::DiagramTestFixture, testing::TestingContext, BufferAccess, + BufferAccessMut, BufferKey, BufferSettings, Diagram, IntoBlockingCallback, JsonMessage, + Node, NodeBuilderOptions, RequestExt, RunCommandsOnWorldExt, SectionBuilderOptions, + }; + + use super::*; + + #[derive(Section)] + struct TestSection { + foo: InputSlot, + bar: Output, + baz: Buffer, + } + + #[test] + fn test_register_section() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, 
_config: ()| { + let node = builder.create_map_block(|_: i64| 1_f64); + let buffer = builder.create_buffer(BufferSettings::default()); + TestSection { + foo: node.input, + bar: node.output, + baz: buffer, + } + }, + ); + + let reg = registry.get_section_registration("test_section").unwrap(); + assert_eq!(reg.name.as_ref(), "TestSection"); + let metadata = ®.metadata; + assert_eq!(metadata.inputs.len(), 1); + assert_eq!(metadata.inputs["foo"].message_type, TypeInfo::of::()); + assert_eq!(metadata.outputs.len(), 1); + assert_eq!(metadata.outputs["bar"].message_type, TypeInfo::of::()); + assert_eq!(metadata.buffers.len(), 1); + assert_eq!( + metadata.buffers["baz"].item_type, + Some(TypeInfo::of::()) + ); + } + + struct OpaqueMessage; + + /// A test compile that opaque messages can be used in sections. + #[derive(Section)] + struct TestSectionNoDeserialize { + #[message(no_deserialize, no_serialize, no_clone)] + msg: InputSlot, + } + + #[derive(Section)] + struct TestSectionUnzip { + input: InputSlot<()>, + #[message(unzip)] + output: Output<(i64, i64)>, + } + + #[test] + fn test_section_unzip() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: ()| (1, 2)); + TestSectionUnzip { + input: node.input, + output: node.output, + } + }, + ); + let reg = registry.get_message_registration::<(i64, i64)>().unwrap(); + assert!(reg.operations.unzip_impl.is_some()); + } + + #[derive(Section)] + struct TestSectionForkResult { + input: InputSlot<()>, + #[message(result)] + output: Output>, + } + + #[test] + fn test_section_fork_result() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: ()| Ok(1)); + 
TestSectionForkResult { + input: node.input, + output: node.output, + } + }, + ); + let reg = registry + .get_message_registration::>() + .unwrap(); + assert!(reg.operations.fork_result_impl.is_some()); + } + + #[derive(Section)] + struct TestSectionSplit { + input: InputSlot<()>, + #[message(split)] + output: Output>, + } + + #[test] + fn test_section_split() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: ()| vec![]); + TestSectionSplit { + input: node.input, + output: node.output, + } + }, + ); + let reg = registry.get_message_registration::>().unwrap(); + assert!(reg.operations.split_impl.is_some()); + } + + #[derive(Section)] + struct TestSectionJoin { + #[message(join)] + input: InputSlot>, + output: Output<()>, + } + + #[test] + fn test_section_join() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: Vec| {}); + TestSectionJoin { + input: node.input, + output: node.output, + } + }, + ); + let reg = registry.get_message_registration::>().unwrap(); + assert!(reg.operations.join_impl.is_some()); + } + + #[derive(Section)] + struct TestSectionBufferAccess { + #[message(buffer_access, no_deserialize, no_serialize)] + input: InputSlot<(i64, Vec>)>, + output: Output<()>, + } + + #[test] + fn test_section_buffer_access() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: (i64, Vec>)| {}); + TestSectionBufferAccess { + input: node.input, + output: node.output, + } + }, + ); + let reg = 
registry + .get_message_registration::<(i64, Vec>)>() + .unwrap(); + assert!(reg.operations.buffer_access_impl.is_some()); + } + + #[derive(Section)] + struct TestSectionListen { + #[message(listen, no_deserialize, no_serialize)] + input: InputSlot>>, + output: Output<()>, + } + + #[test] + fn test_section_listen() { + let mut registry = DiagramElementRegistry::new(); + registry.register_section_builder( + SectionBuilderOptions::new("test_section").with_name("TestSection"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|_: Vec>| {}); + TestSectionListen { + input: node.input, + output: node.output, + } + }, + ); + let reg = registry + .get_message_registration::>>() + .unwrap(); + assert!(reg.operations.listen_impl.is_some()); + } + + #[derive(Section)] + struct TestAddOne { + test_input: InputSlot, + test_output: Output, + } + + fn register_add_one(registry: &mut DiagramElementRegistry) { + registry.register_section_builder( + SectionBuilderOptions::new("add_one"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|i: i64| i + 1); + TestAddOne { + test_input: node.input, + test_output: node.output, + } + }, + ); + } + + #[test] + fn test_section_workflow() { + let mut registry = DiagramElementRegistry::new(); + register_add_one(&mut registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": { "add_one": "test_input" }, + "ops": { + "add_one": { + "type": "section", + "builder": "add_one", + "connect": { + "test_output": { "builtin": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let mut context = TestingContext::minimal_plugins(); + let mut promise = context.app.world.command(|cmds| { + let workflow = diagram + .spawn_io_workflow::(cmds, ®istry) + .unwrap(); + cmds.request(serde_json::to_value(1).unwrap(), workflow) + .take_response() + }); + context.run_while_pending(&mut promise); + let result = promise.take().available().unwrap(); + assert_eq!(result, 2); + } + + 
#[test] + fn test_section_workflow_extra_output() { + let mut registry = DiagramElementRegistry::new(); + register_add_one(&mut registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": { "add_one": "test_input" }, + "ops": { + "add_one": { + "type": "section", + "builder": "add_one", + "connect": { + "extra": { "builtin": "dispose" }, + "test_output": { "builtin": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let mut context = TestingContext::minimal_plugins(); + let err = context + .app + .world + .command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry)) + .unwrap_err(); + let section_err = match err.code { + DiagramErrorCode::SectionError(section_err) => section_err, + _ => panic!("expected SectionError"), + }; + assert!(matches!(section_err, SectionError::UnknownOutput(_))); + } + + #[test] + fn test_section_workflow_extra_input() { + let mut fixture = DiagramTestFixture::new(); + register_add_one(&mut fixture.registry); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": "multiply3", + "ops": { + "multiply3": { + "type": "node", + "builder": "multiply3", + "next": "fork_clone", + }, + "fork_clone": { + "type": "fork_clone", + "next": [ + { "add_one": "test_input" }, + { "add_one": "extra_input" }, + ] + }, + "add_one": { + "type": "section", + "builder": "add_one", + "connect": { + "test_output": { "builtin": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let mut context = TestingContext::minimal_plugins(); + let err = context + .app + .world + .command(|cmds| { + diagram.spawn_io_workflow::(cmds, &fixture.registry) + }) + .unwrap_err(); + assert!(matches!(err.code, DiagramErrorCode::UnknownOperation(_))); + } + + #[derive(Section)] + struct TestSectionAddToBuffer { + test_input: InputSlot, + test_buffer: Buffer, + } + + #[test] + fn test_section_workflow_buffer() { + let mut fixture = DiagramTestFixture::new(); + + fixture.registry.register_section_builder( + 
SectionBuilderOptions::new("add_one_to_buffer"), + |builder: &mut Builder, _config: ()| { + let node = builder.create_map_block(|i: i64| i + 1); + let buffer = builder.create_buffer(BufferSettings::default()); + builder.connect(node.output, buffer.input_slot()); + TestSectionAddToBuffer { + test_input: node.input, + test_buffer: buffer, + } + }, + ); + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder( + NodeBuilderOptions::new("buffer_length"), + |builder: &mut Builder, _config: ()| -> Node>, usize, ()> { + { + builder.create_node( + (|In(request): In>>, access: BufferAccess| { + access.get(&request[0]).unwrap().len() + }) + .into_blocking_callback(), + ) + } + }, + ) + .with_listen() + .with_common_response(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "start": { "add_one_to_buffer": "test_input" }, + "ops": { + "add_one_to_buffer": { + "type": "section", + "builder": "add_one_to_buffer", + "connect": {}, + }, + "listen": { + "type": "listen", + "buffers": [{"add_one_to_buffer" : "test_buffer"}], + "next": "buffer_length", + }, + "buffer_length": { + "type": "node", + "builder": "buffer_length", + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let count: usize = fixture.spawn_and_run(&diagram, 1_i64).unwrap(); + assert_eq!(count, 1); + + fixture + .registry + .opt_out() + .no_serializing() + .no_deserializing() + .register_node_builder(NodeBuilderOptions::new("pull"), |builder, _: ()| { + builder.create_node(pull_from_buffer.into_blocking_callback()) + }) + .with_listen(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "multiply_10_to_buffer": { + "inputs": ["input"], + "buffers": ["buffer"], + "ops": { + "input": { + "type": "node", + "builder": "multiply_by", + "config": 10, + "next": "buffer", + }, + "buffer": { "type": "buffer" }, + } + } + }, + "start": { "multiply": "input" }, + "ops": { + "multiply": { + "type": 
"section", + "template": "multiply_10_to_buffer", + "connect": { + + } + }, + "listen": { + "type": "listen", + "buffers": { "multiply": "buffer" }, + "next": "pull", + }, + "pull": { + "type": "node", + "builder": "pull", + "next": { "builtin": "terminate" }, + } + } + })) + .unwrap(); + + let result: i64 = fixture.spawn_and_run(&diagram, 2_i64).unwrap(); + assert_eq!(result, 20); + } + + fn pull_from_buffer(In(key): In>, mut access: BufferAccessMut) -> i64 { + access.get_mut(&key).unwrap().pull().unwrap() + } + + #[test] + fn test_section_template() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "test_template": { + "inputs": ["input1", "input2"], + "outputs": ["output1", "output2"], + "ops": { + "input1": { + "type": "node", + "builder": "multiply3", + "next": "output1", + }, + "input2": { + "type": "node", + "builder": "multiply3", + "next": "output2", + }, + }, + }, + }, + "start": "fork_clone", + "ops": { + "fork_clone": { + "type": "fork_clone", + "next": [ + { "test_tmpl": "input1" }, + { "test_tmpl": "input2" }, + ], + }, + "test_tmpl": { + "type": "section", + "template": "test_template", + "connect": { + "output1": { "builtin": "terminate" }, + "output2": { "builtin": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) + .unwrap(); + assert_eq!(result, 12); + } + + #[test] + fn test_template_input_remap() { + let mut fixture = DiagramTestFixture::new(); + + // Testing that we can remap inputs to both internal operations within + // the template and also to builtin operations. 
+ let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "test_template": { + "inputs": { + "multiply": "multiply", + "terminate": { "builtin": "terminate" }, + }, + "outputs": ["output"], + "ops": { + "multiply": { + "type": "node", + "builder": "multiply3", + "next": "output", + } + } + } + }, + "start": { "section": "multiply" }, + "ops": { + "section": { + "type": "section", + "template": "test_template", + "connect": { + "output": { "section": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let result: i64 = fixture.spawn_and_run(&diagram, 4_i64).unwrap(); + + assert_eq!(result, 12); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "multiply_then_add": { + "inputs": { + "input": "multiply", + }, + "outputs": ["output"], + "ops": { + "multiply": { + "type": "node", + "builder": "multiply_by", + "config": 10, + "next": { "add": "input" }, + }, + "add": { + "type": "section", + "template": "adding_template", + "connect": { + "added": "output", + } + } + } + }, + "adding_template": { + "inputs": ["input"], + "outputs": ["added"], + "ops": { + "input": { + "type": "node", + "builder": "add_to", + "config": 1, + "next": "added", + } + } + } + }, + "start": { "multiply": "input" }, + "ops": { + "multiply": { + "type": "section", + "template": "multiply_then_add", + "connect": { + "output": { "builtin": "terminate" }, + } + } + } + })) + .unwrap(); + + let result: i64 = fixture.spawn_and_run(&diagram, 5_i64).unwrap(); + + assert_eq!(result, 51); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "calculate": { + "inputs": ["add", "multiply"], + "outputs": ["added", "multiplied"], + "ops": { + "multiply": { + "type": "node", + "builder": "multiply_by", + "config": 10, + "next": "multiplied", + }, + "add": { + "type": "node", + "builder": "add_to", + "config": 10, + "next": "added", + } + } + } + }, + "start": "start", + "ops": { + "start": { + "type": "fork_clone", + 
"next": [ + { "calc": "add" }, + { "calc": "multiply" }, + ], + }, + "calc": { + "type": "section", + "template": "calculate", + "connect": { + "added": "added", + "multiplied": "multiplied", + }, + }, + "added": { "type": "buffer" }, + "multiplied": { "type": "buffer" }, + "join": { + "type": "join", + "buffers": ["added", "multiplied"], + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result: Vec = fixture.spawn_and_run(&diagram, 2_i64).unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result[0], 12); + assert_eq!(result[1], 20); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "redirect_buffers": { + "buffers": { + "added": { "inner": "added" }, + "multiplied": { "inner": "multiplied" }, + }, + "ops": { + "inner": { + "type": "section", + "template": "calculate", + } + } + }, + "calculate": { + "inputs": ["add", "multiply"], + "buffers": ["added", "multiplied"], + "ops": { + "multiply": { + "type": "node", + "builder": "multiply_by", + "config": 10, + "next": "multiplied", + }, + "multiplied": { "type": "buffer" }, + "add": { + "type": "node", + "builder": "add_to", + "config": 10, + "next": "added", + }, + "added": { "type": "buffer" }, + }, + } + }, + "start": "start", + "ops": { + "start": { + "type": "fork_clone", + "next": [ + { "calc": "add" }, + { "calc": "multiply" }, + ], + }, + "calc": { + "type": "section", + "template": "calculate", + }, + "join": { + "type": "join", + "buffers": [ + { "calc": "added" }, + { "calc": "multiplied" }, + ], + "next": { "builtin": "terminate" }, + }, + }, + })) + .unwrap(); + + let result: Vec = fixture.spawn_and_run(&diagram, 3_i64).unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result[0], 13); + assert_eq!(result[1], 30); + } + + #[test] + fn test_detect_circular_redirect() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "test_template": { + "inputs": { + 
"input": "output" + }, + "outputs": ["output"], + "ops": { + } + } + }, + "start": "fork", + "ops": { + "fork": { + "type": "fork_clone", + "next": [ + { "section": "input" }, + { "builtin": "terminate" }, + ] + }, + "section": { + "type": "section", + "template": "test_template", + "connect": { + "output": { "section": "input" }, + }, + } + } + })) + .unwrap(); + + let result = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + + assert!(matches!(result.code, DiagramErrorCode::CircularRedirect(_))); + } + + #[test] + fn test_circular_template_dependency() { + let mut fixture = DiagramTestFixture::new(); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "recursive_template": { + "inputs": ["input"], + "outputs": ["output"], + "ops": { + "input": { + "type": "fork_clone", + "next": [ + { "recursive_self": "input" }, + "output", + ] + }, + "recursive_self": { + "type": "section", + "template": "recursive_template", + "connect": { + "output": "output", + } + } + } + } + }, + "start": { "start": "input" }, + "ops": { + "start": { + "type": "section", + "template": "recursive_template", + "connect": { + "output": { "builtin": "terminate" }, + } + } + } + })) + .unwrap(); + + let result = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + + assert!(matches!( + result.code, + DiagramErrorCode::CircularTemplateDependency(_), + )); + + let diagram = Diagram::from_json(json!({ + "version": "0.1.0", + "templates": { + "parent_template": { + "inputs": ["input"], + "outputs": ["output"], + "ops": { + "input": { + "type": "fork_clone", + "next": [ + { "child": "input" }, + "output", + ] + }, + "child": { + "type": "section", + "template": "child_template", + "connect": { + "output": "output", + } + } + }, + }, + "child_template": { + "inputs": ["input"], + "outputs": ["output"], + "ops": { + "input": { + "type": "node", + "builder": "multiply3", + "next": "grandchild", + }, + "grandchild": { + "type": "section", + "template": 
"grandchild_template", + "connect": { + "output": "output", + } + } + } + }, + "grandchild_template": { + "inputs": ["input"], + "outputs": ["output"], + "ops": { + "input": { + "type": "section", + "template": "parent_template", + "connect": { + "output": "output", + }, + }, + }, + }, + }, + "start": { "start": "input" }, + "ops": { + "start": { + "type": "section", + "template": "parent_template", + "connect": { + "output": { "builtin": "terminate" }, + }, + }, + }, + })) + .unwrap(); + + let result = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); + + assert!(matches!( + result.code, + DiagramErrorCode::CircularTemplateDependency(_), + )); + } +} diff --git a/src/diagram/serialization.rs b/src/diagram/serialization.rs index 746d634f..36ae6961 100644 --- a/src/diagram/serialization.rs +++ b/src/diagram/serialization.rs @@ -15,7 +15,10 @@ * */ -use std::collections::{hash_map::Entry, HashMap}; +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::Arc, +}; use schemars::{gen::SchemaGenerator, JsonSchema}; use serde::{de::DeserializeOwned, Serialize}; @@ -193,7 +196,7 @@ where pub struct ImplicitSerialization { incoming_types: HashMap, - serialized_input: DynInputSlot, + serialized_input: Arc, } impl ImplicitSerialization { @@ -206,7 +209,7 @@ impl ImplicitSerialization { } Ok(Self { - serialized_input, + serialized_input: Arc::new(serialized_input), incoming_types: Default::default(), }) } @@ -264,10 +267,14 @@ impl ImplicitSerialization { self.try_implicit_serialize(incoming, builder, ctx)? .map_err(|incoming| DiagramErrorCode::NotSerializable(*incoming.message_info())) } + + pub fn serialized_input_slot(&self) -> &Arc { + &self.serialized_input + } } pub struct ImplicitDeserialization { - deserialized_input: DynInputSlot, + deserialized_input: Arc, // The serialized input will only be created if a JsonMessage output // attempts to connect to this operation. Otherwise there is no need to // create it. 
@@ -286,7 +293,7 @@ impl ImplicitDeserialization { .is_some() { return Ok(Some(Self { - deserialized_input, + deserialized_input: Arc::new(deserialized_input), serialized_input: None, })); } @@ -335,6 +342,10 @@ impl ImplicitDeserialization { target_type: *self.deserialized_input.message_info(), }) } + + pub fn deserialized_input_slot(&self) -> &Arc { + &self.deserialized_input + } } pub struct ImplicitStringify { diff --git a/src/diagram/split_schema.rs b/src/diagram/split_schema.rs index a841051d..85555f64 100644 --- a/src/diagram/split_schema.rs +++ b/src/diagram/split_schema.rs @@ -29,7 +29,7 @@ use crate::{ use super::{ supported::*, type_info::TypeInfo, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, DynInputSlot, DynOutput, MessageRegistration, MessageRegistry, NextOperation, - OperationId, PerformForkClone, SerializeMessage, + OperationName, PerformForkClone, SerializeMessage, }; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -47,20 +47,20 @@ pub struct SplitSchema { impl BuildDiagramOperation for SplitSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { - let Some(sample_input) = ctx.infer_input_type_into_target(id) else { + let Some(sample_input) = ctx.infer_input_type_into_target(id)? else { // There are no outputs ready for this target, so we can't do // anything yet. The builder should try again later. 
return Ok(BuildStatus::defer("waiting for an input")); }; - let split = ctx.registry.messages.split(sample_input, self, builder)?; + let split = ctx.registry.messages.split(&sample_input, self, builder)?; ctx.set_input_for_target(id, split.input)?; for (target, output) in split.outputs { - ctx.add_output_into_target(target, output); + ctx.add_output_into_target(&target, output); } Ok(BuildStatus::Finished) } @@ -392,8 +392,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 1); @@ -432,8 +432,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); @@ -476,8 +476,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); @@ -517,8 +517,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); // "a" is "eaten" up by the keyed path, so we should be the result of "b". 
@@ -559,8 +559,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result[1], 2); @@ -596,7 +596,7 @@ mod tests { })) .unwrap(); - let result = fixture + let result: JsonMessage = fixture .spawn_and_run( &diagram, serde_json::to_value(HashMap::from([("test".to_string(), 1)])).unwrap(), diff --git a/src/diagram/testing.rs b/src/diagram/testing.rs index 975f4f4b..84f4b904 100644 --- a/src/diagram/testing.rs +++ b/src/diagram/testing.rs @@ -22,11 +22,22 @@ impl DiagramTestFixture { } } - /// Equivalent to `self.spawn_workflow::(diagram)` pub(super) fn spawn_json_io_workflow( &mut self, diagram: &Diagram, - ) -> Result, DiagramError> { + ) -> Result, DiagramError> { + self.spawn_io_workflow::(diagram) + } + + /// Equivalent to `self.spawn_workflow::(diagram)` + pub(super) fn spawn_io_workflow( + &mut self, + diagram: &Diagram, + ) -> Result, DiagramError> + where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, + { self.context .app .world @@ -35,12 +46,16 @@ impl DiagramTestFixture { /// Spawns a workflow from a diagram then run the workflow until completion. /// Returns the result of the workflow. 
- pub(super) fn spawn_and_run( + pub(super) fn spawn_and_run( &mut self, diagram: &Diagram, - request: serde_json::Value, - ) -> Result> { - let workflow = self.spawn_json_io_workflow(diagram)?; + request: Request, + ) -> Result> + where + Request: 'static + Send + Sync, + Response: 'static + Send + Sync, + { + let workflow = self.spawn_io_workflow(diagram)?; let mut promise = self .context .command(|cmds| cmds.request(request, workflow).take_response()); @@ -58,14 +73,15 @@ impl DiagramTestFixture { } #[derive(Serialize, Deserialize, JsonSchema)] +#[serde(transparent)] struct Uncloneable(T); fn multiply3(i: i64) -> i64 { i * 3 } -fn multiply3_uncloneable(i: i64) -> Uncloneable { - Uncloneable(i * 3) +fn multiply3_uncloneable(i: Uncloneable) -> Uncloneable { + Uncloneable(i.0 * 3) } fn multiply3_5(x: i64) -> (i64, i64) { @@ -107,6 +123,10 @@ fn new_registry_with_basic_nodes() -> DiagramElementRegistry { |builder: &mut Builder, config: i64| builder.create_map_block(move |a: i64| a * config), ); + registry.register_node_builder(NodeBuilderOptions::new("add_to"), |builder, config: i64| { + builder.create_map_block(move |a: i64| a + config) + }); + registry .opt_out() .no_deserializing() diff --git a/src/diagram/transform_schema.rs b/src/diagram/transform_schema.rs index 32669d3b..ff801cb2 100644 --- a/src/diagram/transform_schema.rs +++ b/src/diagram/transform_schema.rs @@ -26,7 +26,7 @@ use crate::{Builder, JsonMessage}; use super::{ BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, NextOperation, - OperationId, + OperationName, }; #[derive(Error, Debug)] @@ -60,7 +60,7 @@ pub struct TransformSchema { impl BuildDiagramOperation for TransformSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { @@ -80,11 +80,15 @@ impl BuildDiagramOperation for TransformSchema { }, ); - let error_target = self.on_error.clone().unwrap_or( - // If no error target was 
explicitly given then treat this as an - // implicit error. - ctx.get_implicit_error_target(), - ); + let error_target = self + .on_error + .as_ref() + .map(|on_error| ctx.into_operation_ref(on_error)) + .unwrap_or( + // If no error target was explicitly given then treat this as an + // implicit error. + ctx.get_implicit_error_target(), + ); let (ok, _) = node.output.chain(builder).fork_result( |ok| ok.output(), @@ -94,7 +98,7 @@ impl BuildDiagramOperation for TransformSchema { ); ctx.set_input_for_target(id, node.input.into())?; - ctx.add_output_into_target(self.next.clone(), ok.into()); + ctx.add_output_into_target(&self.next, ok.into()); Ok(BuildStatus::Finished) } } @@ -104,7 +108,7 @@ mod tests { use serde_json::json; use test_log::test; - use crate::{diagram::testing::DiagramTestFixture, Diagram}; + use crate::{diagram::testing::DiagramTestFixture, Diagram, JsonMessage}; #[test] fn test_transform_node_response() { @@ -128,8 +132,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); @@ -152,8 +156,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 777); @@ -176,8 +180,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 12); @@ -200,8 +204,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + 
.spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result["request"], 4); @@ -230,7 +234,7 @@ mod tests { "age": 40, }); - let result = fixture.spawn_and_run(&diagram, request).unwrap(); + let result: JsonMessage = fixture.spawn_and_run(&diagram, request).unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 40); } diff --git a/src/diagram/unzip_schema.rs b/src/diagram/unzip_schema.rs index b84828c0..89fad930 100644 --- a/src/diagram/unzip_schema.rs +++ b/src/diagram/unzip_schema.rs @@ -23,7 +23,7 @@ use crate::Builder; use super::{ supported::*, BuildDiagramOperation, BuildStatus, DiagramContext, DiagramErrorCode, - DynInputSlot, DynOutput, MessageRegistry, NextOperation, OperationId, PerformForkClone, + DynInputSlot, DynOutput, MessageRegistry, NextOperation, OperationName, PerformForkClone, SerializeMessage, TypeInfo, }; @@ -36,17 +36,17 @@ pub struct UnzipSchema { impl BuildDiagramOperation for UnzipSchema { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result { - let Some(inferred_type) = ctx.infer_input_type_into_target(id) else { + let Some(inferred_type) = ctx.infer_input_type_into_target(id)? else { // There are no outputs ready for this target, so we can't do // anything yet. The builder should try again later. 
return Ok(BuildStatus::defer("waiting for an input")); }; - let unzip = ctx.registry.messages.unzip(inferred_type)?; + let unzip = ctx.registry.messages.unzip(&inferred_type)?; let actual_output = unzip.output_types(); if actual_output.len() != self.next.len() { return Err(DiagramErrorCode::UnzipMismatch { @@ -60,7 +60,7 @@ impl BuildDiagramOperation for UnzipSchema { ctx.set_input_for_target(id, unzip.input)?; for (target, output) in self.next.iter().zip(unzip.outputs) { - ctx.add_output_into_target(target.clone(), output); + ctx.add_output_into_target(target, output); } Ok(BuildStatus::Finished) } @@ -132,7 +132,7 @@ mod tests { use serde_json::json; use test_log::test; - use crate::{diagram::testing::DiagramTestFixture, Diagram, DiagramErrorCode}; + use crate::{diagram::testing::DiagramTestFixture, Diagram, DiagramErrorCode, JsonMessage}; #[test] fn test_unzip_not_unzippable() { @@ -157,7 +157,7 @@ mod tests { let err = fixture.spawn_json_io_workflow(&diagram).unwrap_err(); assert!( - matches!(err.code, DiagramErrorCode::NotUnzippable), + matches!(err.code, DiagramErrorCode::NotUnzippable(_)), "{}", err ); @@ -231,8 +231,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 20); @@ -267,8 +267,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); assert_eq!(result, 36); @@ -300,8 +300,8 @@ mod tests { })) .unwrap(); - let result = fixture - .spawn_and_run(&diagram, serde_json::Value::from(4)) + let result: JsonMessage = fixture + .spawn_and_run(&diagram, JsonMessage::from(4)) .unwrap(); assert!(fixture.context.no_unhandled_errors()); 
assert_eq!(result, 60); diff --git a/src/diagram/workflow_builder.rs b/src/diagram/workflow_builder.rs index 49c16b60..b14d7bf6 100644 --- a/src/diagram/workflow_builder.rs +++ b/src/diagram/workflow_builder.rs @@ -17,66 +17,241 @@ use std::{ borrow::Cow, - collections::{hash_map::Entry, HashMap}, + collections::{hash_map::Entry, HashMap, HashSet}, + sync::Arc, }; use crate::{AnyBuffer, BufferIdentifier, BufferMap, Builder, JsonMessage, Scope, StreamPack}; use super::{ - BufferInputs, BuiltinTarget, Diagram, DiagramElementRegistry, DiagramError, DiagramErrorCode, - DiagramOperation, DynInputSlot, DynOutput, ImplicitDeserialization, ImplicitSerialization, - ImplicitStringify, NextOperation, OperationId, TypeInfo, + BufferSelection, BuiltinTarget, Diagram, DiagramElementRegistry, DiagramError, + DiagramErrorCode, DiagramOperation, DynInputSlot, DynOutput, ImplicitDeserialization, + ImplicitSerialization, ImplicitStringify, NamespacedOperation, NextOperation, OperationName, + Operations, Templates, TypeInfo, }; -#[derive(Default)] -struct DiagramConstruction { - connect_into_target: HashMap>, - // We use a separate hashmap for OperationId vs BuiltinTarget so we can - // efficiently fetch with an &OperationId - outputs_to_operation_target: HashMap>, - outputs_to_builtin_target: HashMap>, - buffers: HashMap, -} - -impl DiagramConstruction { - fn is_finished(&self) -> bool { - for outputs in self.outputs_to_builtin_target.values() { - if !outputs.is_empty() { - return false; - } +use bevy_ecs::prelude::Entity; + +use smallvec::SmallVec; + +type NamespaceList = SmallVec<[OperationName; 4]>; + +/// This key is used so we can do a clone-free .get(&NextOperation) of a hashmap +/// that uses this as a key. +// +// TODO(@mxgrey): With this struct we could apply a lifetime to +// DiagramConstruction and then borrow all the names used in this struct instead +// of using Cow. 
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum OperationRef { + Named(NamedOperationRef), + Builtin { builtin: BuiltinTarget }, +} + +impl OperationRef { + fn in_namespaces(self, parent_namespaces: &[Arc]) -> Self { + match self { + Self::Builtin { builtin } => Self::Builtin { builtin }, + Self::Named(named) => Self::Named(named.in_namespaces(parent_namespaces)), } + } +} - for outputs in self.outputs_to_operation_target.values() { - if !outputs.is_empty() { - return false; - } +impl<'a> From<&'a NextOperation> for OperationRef { + fn from(value: &'a NextOperation) -> Self { + match value { + NextOperation::Name(name) => OperationRef::Named(name.into()), + NextOperation::Namespace(id) => OperationRef::Named(id.into()), + NextOperation::Builtin { builtin } => OperationRef::Builtin { + builtin: builtin.clone(), + }, } + } +} - return true; +impl<'a> From<&'a OperationName> for OperationRef { + fn from(value: &'a OperationName) -> Self { + OperationRef::Named(value.into()) } } -pub struct DiagramContext<'a> { - construction: &'a mut DiagramConstruction, - pub diagram: &'a Diagram, - pub registry: &'a DiagramElementRegistry, +impl From for OperationRef { + fn from(value: NamedOperationRef) -> Self { + OperationRef::Named(value) + } } -impl<'a> DiagramContext<'a> { - /// Get all the currently known outputs that are aimed at this target operation. +impl std::fmt::Display for OperationRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Named(named) => write!(f, "{named}"), + Self::Builtin { builtin } => write!(f, "builtin:{builtin}"), + } + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NamedOperationRef { + pub namespaces: SmallVec<[Arc; 4]>, + /// If this references an exposed operation, such as an exposed input or + /// output of a session, this will contain the session name. Suppose we have + /// a section named `sec` and it has an exposed output named `out`. 
Then + /// there are two values of `NamedOperationRef` to consider: /// - /// During the [`BuildDiagramOperation`] phase this will eventually contain - /// all outputs targeting this operation that are explicitly listed in the - /// diagram. It will never contain outputs that implicitly target this - /// operation. + /// ``` + /// # use bevy_impulse::diagram::NamedOperationRef; + /// # use smallvec::smallvec; + /// + /// let op_id = NamedOperationRef { + /// namespaces: smallvec!["sec".into()], + /// exposed_namespace: None, + /// name: "out".into(), + /// }; + /// ``` + /// + /// is the internal reference to `sec:out` that will be used by other + /// operations inside of `sec`. On the other hand, operations that are + /// siblings of `sec` would instead connect to + /// + /// ``` + /// # use bevy_impulse::diagram::NamedOperationRef; + /// # use smallvec::smallvec; + /// + /// let op_id = NamedOperationRef { + /// namespaces: smallvec![], + /// exposed_namespace: Some("sec".into()), + /// name: "out".into(), + /// }; + /// ``` /// - /// During the [`ConnectIntoTarget`] phase this will not contain any outputs - /// except new outputs added during the current call to [`ConnectIntoTarget`], - /// so this function is generally not useful during that phase. - pub fn get_outputs_into_operation_target(&self, id: &OperationId) -> Option<&Vec> { - self.construction.outputs_to_operation_target.get(id) + /// We need to make this distinction because operations inside `sec` do not + /// know which of their siblings are exposed, and we don't want operations + /// outside of `sec` to accidentally connect to operations that are supposed + /// to be internal to `sec`. + pub exposed_namespace: Option>, + pub name: Arc, +} + +impl NamedOperationRef { + fn in_namespaces(mut self, parent_namespaces: &[Arc]) -> Self { + // Put the parent namespaces at the front and append the operation's + // existing namespaces at the back. 
+ let new_namespaces = parent_namespaces + .iter() + .cloned() + .chain(self.namespaces.drain(..)) + .collect(); + + self.namespaces = new_namespaces; + self + } +} + +impl<'a> From<&'a OperationName> for NamedOperationRef { + fn from(name: &'a OperationName) -> Self { + NamedOperationRef { + namespaces: SmallVec::new(), + exposed_namespace: None, + name: Arc::clone(name), + } + } +} + +impl From for NamedOperationRef { + fn from(name: OperationName) -> Self { + NamedOperationRef { + namespaces: SmallVec::new(), + exposed_namespace: None, + name, + } + } +} + +impl<'a> From<&'a NamespacedOperation> for NamedOperationRef { + fn from(id: &'a NamespacedOperation) -> Self { + NamedOperationRef { + namespaces: SmallVec::new(), + // This is referring to an exposed operation, so the namespace + // mentioned in the operation goes into the exposed_namespace field + exposed_namespace: Some(Arc::clone(&id.namespace)), + name: Arc::clone(&id.operation), + } + } +} + +impl<'a> From<&'a NamespacedOperation> for OperationRef { + fn from(value: &'a NamespacedOperation) -> Self { + OperationRef::Named(value.into()) } +} + +impl std::fmt::Display for NamedOperationRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for namespace in &self.namespaces { + write!(f, "{namespace}:")?; + } + + if let Some(exposed) = &self.exposed_namespace { + write!(f, "{exposed}:(exposed):")?; + } + + f.write_str(&self.name) + } +} + +#[derive(Default)] +struct DiagramConstruction<'a> { + /// Implementations that define how outputs can connect to their target operations + connect_into_target: HashMap>, + /// A map of what outputs are going into each target operation + outputs_into_target: HashMap>, + /// A map of what buffers exist in the diagram + buffers: HashMap, + /// Operations that were spawned by another operation. 
+ generated_operations: Vec>, +} + +impl<'a> DiagramConstruction<'a> { + fn transfer_generated_operations( + &mut self, + deferred_operations: &mut Vec>, + made_progress: &mut bool, + ) { + deferred_operations.extend(self.generated_operations.drain(..).map(|unfinished| { + *made_progress = true; + unfinished + })); + } +} + +#[derive(Clone, Debug)] +enum BufferRef { + Ref(OperationRef), + Value(AnyBuffer), +} + +impl<'a> DiagramConstruction<'a> { + fn has_outputs(&self) -> bool { + for outputs in self.outputs_into_target.values() { + if !outputs.is_empty() { + return true; + } + } + + return false; + } +} + +pub struct DiagramContext<'a, 'c> { + construction: &'c mut DiagramConstruction<'a>, + pub registry: &'a DiagramElementRegistry, + pub operations: &'a Operations, + pub templates: &'a Templates, + pub on_implicit_error: &'a OperationRef, + namespaces: NamespaceList, +} +impl<'a, 'c> DiagramContext<'a, 'c> { /// Infer the [`TypeInfo`] for the input messages into the specified operation. /// /// If this returns [`None`] then not enough of the diagram has been built @@ -92,10 +267,30 @@ impl<'a> DiagramContext<'a> { /// during the [`ConnectIntoTarget`] phase then you should capture the /// [`TypeInfo`] that you receive from this function during the /// [`BuildDiagramOperation`] phase. - pub fn infer_input_type_into_target(&self, id: &OperationId) -> Option<&TypeInfo> { - self.get_outputs_into_operation_target(id) + pub fn infer_input_type_into_target( + &self, + id: impl Into, + ) -> Result, DiagramErrorCode> { + let id = self.into_operation_ref(id); + if let Some(connect) = self.construction.connect_into_target.get(&id) { + let mut visited = HashSet::new(); + visited.insert(id.clone()); + let preferred = connect + .infer_input_type(self, &mut visited)? 
+ .map(|infer| infer.preferred_input_type()) + .flatten(); + if let Some(preferred) = preferred { + return Ok(Some(preferred)); + } + } + + let infer = self + .construction + .outputs_into_target + .get(&id) .and_then(|outputs| outputs.first()) - .map(|o| o.message_info()) + .map(|o| *o.message_info()); + Ok(infer) } /// Add an output to connect into a target. @@ -105,25 +300,15 @@ impl<'a> DiagramContext<'a> { /// /// # Arguments /// - /// * target - The operation that needs to receive the output - /// * output - The output channel that needs to be connected into the target. - pub fn add_output_into_target(&mut self, target: NextOperation, output: DynOutput) { - match target { - NextOperation::Target(id) => { - self.construction - .outputs_to_operation_target - .entry(id) - .or_default() - .push(output); - } - NextOperation::Builtin { builtin } => { - self.construction - .outputs_to_builtin_target - .entry(builtin) - .or_default() - .push(output); - } - } + /// * `target` - The operation that needs to receive the output + /// * `output` - The output channel that needs to be connected into the target. + pub fn add_output_into_target(&mut self, target: impl Into, output: DynOutput) { + let target = self.into_operation_ref(target); + self.construction + .outputs_into_target + .entry(target) + .or_default() + .push(output); } /// Set the input slot of an operation. This should not be called more than @@ -146,23 +331,12 @@ impl<'a> DiagramContext<'a> { /// all connection behaviors must already be set by then. 
pub fn set_input_for_target( &mut self, - operation: &OperationId, + operation: impl Into, input: DynInputSlot, ) -> Result<(), DiagramErrorCode> { - match self - .construction - .connect_into_target - .entry(NextOperation::Target(operation.clone())) - { - Entry::Occupied(_) => { - return Err(DiagramErrorCode::MultipleInputsCreated(operation.clone())); - } - Entry::Vacant(vacant) => { - vacant.insert(standard_input_connection(input, &self.registry)?); - } - } - - Ok(()) + let operation = self.into_operation_ref(operation); + let connect = standard_input_connection(input, &self.registry)?; + self.impl_connect_into_target(operation, connect) } /// Set the implementation for how outputs connect into this target. This is @@ -179,17 +353,27 @@ impl<'a> DiagramContext<'a> { /// phase because all connection behaviors should already be set by then. pub fn set_connect_into_target( &mut self, - operation: &OperationId, + operation: impl Into, connect: C, ) -> Result<(), DiagramErrorCode> { + let operation = self.into_operation_ref(operation); let connect = Box::new(connect); + self.impl_connect_into_target(operation, connect) + } + + /// Internal implementation of adding a connection into a target + fn impl_connect_into_target( + &mut self, + operation: OperationRef, + connect: Box, + ) -> Result<(), DiagramErrorCode> { match self .construction .connect_into_target - .entry(NextOperation::Target(operation.clone())) + .entry(operation.clone()) { Entry::Occupied(_) => { - return Err(DiagramErrorCode::MultipleInputsCreated(operation.clone())); + return Err(DiagramErrorCode::DuplicateInputsCreated(operation)); } Entry::Vacant(vacant) => { vacant.insert(connect); @@ -198,61 +382,61 @@ impl<'a> DiagramContext<'a> { Ok(()) } - /// Same as [`Self::set_connect_into_target`] but you can pass in a closure. - /// - /// This is equivalent to doing - /// `set_connect_into_target(operation, ConnectionCallback(connect))`. 
- pub fn set_connect_into_target_callback( - &mut self, - operation: &OperationId, - connect: F, - ) -> Result<(), DiagramErrorCode> - where - F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode> - + 'static, - { - self.set_connect_into_target(operation, ConnectionCallback(connect)) - } - /// Set the buffer that should be used for a certain operation. This will /// also set its connection callback. pub fn set_buffer_for_operation( &mut self, - operation: &OperationId, + operation: impl Into + Clone, buffer: AnyBuffer, ) -> Result<(), DiagramErrorCode> { + let input: DynInputSlot = buffer.into(); + self.set_input_for_target(operation.clone(), input)?; + + let operation = self.into_operation_ref(operation); match self.construction.buffers.entry(operation.clone()) { Entry::Occupied(_) => { - return Err(DiagramErrorCode::MultipleBuffersCreated(operation.clone())); + return Err(DiagramErrorCode::DuplicateBuffersCreated(operation)); } Entry::Vacant(vacant) => { - vacant.insert(buffer); + vacant.insert(BufferRef::Value(buffer)); } } - let input: DynInputSlot = buffer.into(); - self.set_input_for_target(operation, input) + Ok(()) } /// Create a buffer map based on the buffer inputs provided. If one or more /// of the buffers in BufferInputs is not available, get an error including /// the name of the missing buffer. 
- pub fn create_buffer_map(&self, inputs: &BufferInputs) -> Result { - let attempt_get_buffer = |name: &String| -> Result { - self.construction - .buffers - .get(name) - .copied() - .ok_or_else(|| format!("cannot find buffer named [{name}]")) + pub fn create_buffer_map(&self, inputs: &BufferSelection) -> Result { + let attempt_get_buffer = |buffer: &NextOperation| -> Result { + let mut buffer_ref: OperationRef = self.into_operation_ref(buffer); + let mut visited = HashSet::new(); + loop { + if !visited.insert(buffer_ref.clone()) { + return Err(format!("circular reference for buffer [{buffer}]")); + } + + let next = self + .construction + .buffers + .get(&buffer_ref) + .ok_or_else(|| format!("cannot find buffer named [{buffer}]"))?; + + buffer_ref = match next { + BufferRef::Value(value) => return Ok(*value), + BufferRef::Ref(reference) => reference.clone(), + }; + } }; match inputs { - BufferInputs::Single(op_id) => { + BufferSelection::Single(op_id) => { let mut buffer_map = BufferMap::with_capacity(1); buffer_map.insert(BufferIdentifier::Index(0), attempt_get_buffer(op_id)?); Ok(buffer_map) } - BufferInputs::Dict(mapping) => { + BufferSelection::Dict(mapping) => { let mut buffer_map = BufferMap::with_capacity(mapping.len()); for (k, op_id) in mapping { buffer_map.insert( @@ -262,7 +446,7 @@ impl<'a> DiagramContext<'a> { } Ok(buffer_map) } - BufferInputs::Array(arr) => { + BufferSelection::Array(arr) => { let mut buffer_map = BufferMap::with_capacity(arr.len()); for (i, op_id) in arr.into_iter().enumerate() { buffer_map.insert(BufferIdentifier::Index(i), attempt_get_buffer(op_id)?); @@ -272,49 +456,147 @@ impl<'a> DiagramContext<'a> { } } - /// Get the type information for the request message that goes into a node. 
+ pub fn get_implicit_error_target(&self) -> OperationRef { + self.on_implicit_error.clone() + } + + pub fn into_operation_ref(&self, id: impl Into) -> OperationRef { + let id: OperationRef = id.into(); + id.in_namespaces(&self.namespaces) + } + + /// Add an operation which exists as a child inside another operation. + /// + /// For example this is used by section templates to add their inner + /// operations into the workflow builder. + /// + /// Operations added this way will be internal to the parent operation, + /// which means they are not "exposed" to be connected to other operations + /// that are not inside the parent. + pub fn add_child_operation( + &mut self, + id: &OperationName, + child_id: &OperationName, + op: &'a DiagramOperation, + sibling_ops: &'a Operations, + ) { + let mut namespaces = self.namespaces.clone(); + namespaces.push(Arc::clone(id)); + + self.construction + .generated_operations + .push(UnfinishedOperation { + id: Arc::clone(child_id), + namespaces, + op, + sibling_ops, + }); + } + + /// Create a connection for an exposed input that allows it to redirect any + /// connections to an internal (child) input. /// /// # Arguments /// - /// * `target` - Optionally indicate a specific node in the diagram to treat - /// as the target node, even if it is not the actual target. Using [`Some`] - /// for this will override whatever is used for `next`. - /// * `next` - Indicate the next operation, i.e. the true target. - pub fn get_node_request_type( - &self, - target: Option<&OperationId>, - next: &NextOperation, - ) -> Result { - let target_node = if let Some(target) = target { - self.diagram.get_op(target)? 
- } else { - match next { - NextOperation::Target(op_id) => self.diagram.get_op(op_id)?, - NextOperation::Builtin { builtin } => match builtin { - BuiltinTarget::Terminate => return Ok(TypeInfo::of::()), - BuiltinTarget::Dispose => return Err(DiagramErrorCode::UnknownTarget), - BuiltinTarget::Cancel => return Ok(TypeInfo::of::()), - }, + /// * `section_id` - The ID of the parent operation (e.g. ID of the section) + /// that is exposing an operation to its siblings. + /// * `exposed_id` - The ID of the exposed operation that is being provided + /// by the section. E.g. siblings of the section will refer to the redirected + /// input by `{ section_id: exposed_id }` + /// * `child_id` - The ID of the operation inside the section which the exposed + /// operation ID is being redirected to. + /// + /// This should not be used to expose a buffer. For that, use + /// [`Self::redirect_to_child_buffer`]. + pub fn redirect_to_child_input( + &mut self, + section_id: &OperationName, + exposed_id: &OperationName, + child_id: &NextOperation, + ) -> Result<(), DiagramErrorCode> { + let (exposed, redirect_to) = + self.get_exposed_and_inner_ids(section_id, exposed_id, child_id); + + self.impl_connect_into_target(exposed, Box::new(RedirectConnection::new(redirect_to))) + } + + /// Same as [`Self::redirect_to_child_input`], but meant for buffers. 
+ pub fn redirect_to_child_buffer( + &mut self, + section_id: &OperationName, + exposed_id: &OperationName, + child_id: &NextOperation, + ) -> Result<(), DiagramErrorCode> { + let (exposed, redirect_to) = + self.get_exposed_and_inner_ids(section_id, exposed_id, child_id); + + self.impl_connect_into_target( + exposed.clone(), + Box::new(RedirectConnection::new(redirect_to.clone())), + )?; + + match self.construction.buffers.entry(exposed.clone()) { + Entry::Occupied(_) => { + return Err(DiagramErrorCode::DuplicateBuffersCreated(exposed)); } - }; - let node_op = match target_node { - DiagramOperation::Node(op) => op, - _ => return Err(DiagramErrorCode::UnknownTarget), - }; - let target_type = self - .registry - .get_node_registration(&node_op.builder)? - .request; - Ok(target_type) + Entry::Vacant(vacant) => { + vacant.insert(BufferRef::Ref(redirect_to)); + } + } + + Ok(()) + } + + /// Create a connection for an exposed output that allows it to redirect + /// any connections to an external (sibling) input. This is used to implement + /// the `"connect":` schema for section templates. + /// + /// # Arguments + /// + /// * `section_id` - The ID of the parent operation (e.g. ID of the section) + /// that is exposing an output to its siblings. + /// * `output_id` - The ID of the exposed output that is being provided by + /// the section. E.g. siblings of the section will refer to the redirected + /// output by `{ section_id: output_id }`. + /// * `sibling_id` - The sibling of the section that should receive the output. + pub fn redirect_exposed_output_to_sibling( + &mut self, + section_id: &OperationName, + output_id: &OperationName, + sibling_id: &NextOperation, + ) -> Result<(), DiagramErrorCode> { + // This is the slot that operations inside of the section will direct + // their outputs to. It will receive the outputs and then redirect them. 
+ let internal: OperationRef = NamedOperationRef { + namespaces: SmallVec::from_iter([Arc::clone(section_id)]), + exposed_namespace: None, + name: Arc::clone(output_id), + } + .in_namespaces(&self.namespaces) + .into(); + + let redirect_to = self.into_operation_ref(sibling_id); + self.impl_connect_into_target(internal, Box::new(RedirectConnection::new(redirect_to))) } - pub fn get_implicit_error_target(&self) -> NextOperation { - self.diagram - .on_implicit_error - .clone() - .unwrap_or(NextOperation::Builtin { - builtin: BuiltinTarget::Cancel, - }) + fn get_exposed_and_inner_ids( + &self, + section_id: &OperationName, + exposed_id: &OperationName, + child_id: &NextOperation, + ) -> (OperationRef, OperationRef) { + let mut child_namespaces = self.namespaces.clone(); + child_namespaces.push(Arc::clone(section_id)); + let inner = Into::::into(child_id).in_namespaces(&child_namespaces); + + let exposed: OperationRef = NamedOperationRef { + namespaces: self.namespaces.clone(), + exposed_namespace: Some(section_id.clone()), + name: exposed_id.clone(), + } + .into(); + + (exposed, inner) } } @@ -385,15 +667,12 @@ impl BuildStatus { /// This trait is used to instantiate operations in the workflow. This trait /// will be called on each operation in the diagram until it finishes building. -/// Each operation should use this to provide a [`ConnectOutput`] handle for +/// Each operation should use this to provide a [`ConnectIntoTarget`] handle for /// itself (if relevant) and deposit [`DynOutput`]s into [`DiagramContext`]. -/// -/// After all operations are fully built, [`ConnectIntoTarget`] will be used to -/// connect outputs into their target operations. pub trait BuildDiagramOperation { fn build_diagram_operation( &self, - id: &OperationId, + id: &OperationName, builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result; @@ -403,12 +682,7 @@ pub trait BuildDiagramOperation { /// will be called for each output produced by [`BuildDiagramOperation`]. 
/// /// You are allowed to generate new outputs during the [`ConnectIntoTarget`] -/// phase by calling [`DiagramContext::add_outputs_into_target`]. -/// -/// However you cannot add new [`ConnectIntoTarget`] instances for operations. -/// Any use of [`DiagramContext::set_input_for_target`], -/// [`DiagramContext::set_connect_into_target`], or -/// [`DiagramContext::set_connect_into_target_callback`] will be discarded. +/// phase by calling [`DiagramContext::add_output_into_target`]. pub trait ConnectIntoTarget { fn connect_into_target( &mut self, @@ -416,24 +690,26 @@ pub trait ConnectIntoTarget { builder: &mut Builder, ctx: &mut DiagramContext, ) -> Result<(), DiagramErrorCode>; -} -pub struct ConnectionCallback(pub F) -where - F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode>; + fn infer_input_type( + &self, + ctx: &DiagramContext, + visited: &mut HashSet, + ) -> Result>, DiagramErrorCode>; +} -impl ConnectIntoTarget for ConnectionCallback -where - F: FnMut(DynOutput, &mut Builder, &mut DiagramContext) -> Result<(), DiagramErrorCode>, -{ - fn connect_into_target( - &mut self, - output: DynOutput, - builder: &mut Builder, - ctx: &mut DiagramContext, - ) -> Result<(), DiagramErrorCode> { - (self.0)(output, builder, ctx) - } +/// This trait helps to determine what types of messages can go into an input +/// slot. +/// +/// For the first implementation of this, we are only considering the most +/// preferred message type, but in the future we should add support for +/// building a constraint graph to support inference in cases where an input +/// can accept multiple different message types. +/// +/// NOTE(@mxgrey): We may expand on this trait in the future to enable message +/// type negotiation for operations that can accept a range of message types. 
+pub trait InferMessageType { + fn preferred_input_type(&self) -> Option; } pub(super) fn create_workflow( @@ -447,64 +723,148 @@ where Response: 'static + Send + Sync, Streams: StreamPack, { + diagram.validate_operation_names()?; + diagram.validate_template_usage()?; + let mut construction = DiagramConstruction::default(); + let default_on_implicit_error = OperationRef::Builtin { + builtin: BuiltinTarget::Cancel, + }; + let opt_on_implicit_error: Option = + diagram.on_implicit_error.as_ref().map(Into::into); + + let on_implicit_error = opt_on_implicit_error + .as_ref() + .unwrap_or(&default_on_implicit_error); + initialize_builtin_operations( + diagram.start.clone(), scope, builder, &mut DiagramContext { construction: &mut construction, - diagram, registry, + operations: &diagram.ops, + templates: &diagram.templates, + on_implicit_error, + namespaces: NamespaceList::new(), }, )?; - let mut unfinished_operations: Vec<&OperationId> = diagram.ops.keys().collect(); - let mut deferred_operations: Vec<(&OperationId, BuildStatus)> = Vec::new(); + let mut unfinished_operations: Vec = diagram + .ops + .iter() + .map(|(id, op)| UnfinishedOperation::new(Arc::clone(id), op, &diagram.ops)) + .collect(); + let mut deferred_operations: Vec = Vec::new(); + let mut deferred_statuses: Vec<(OperationRef, BuildStatus)> = Vec::new(); + + let mut deferred_connections = HashMap::new(); + let mut connector_construction = DiagramConstruction::default(); let mut iterations = 0; const MAX_ITERATIONS: usize = 10_000; // Iteratively build all the operations in the diagram - while !unfinished_operations.is_empty() { + while !unfinished_operations.is_empty() || construction.has_outputs() { let mut made_progress = false; - for op in unfinished_operations.drain(..) { + for unfinished in unfinished_operations.drain(..) 
{ let mut ctx = DiagramContext { construction: &mut construction, - diagram, registry, + operations: &unfinished.sibling_ops, + templates: &diagram.templates, + on_implicit_error, + namespaces: unfinished.namespaces.clone(), }; // Attempt to build this operation - let status = diagram - .ops - .get(op) - .ok_or_else(|| { - DiagramErrorCode::UnknownOperation(NextOperation::Target(op.clone())) - })? - .build_diagram_operation(op, builder, &mut ctx) - .map_err(|code| DiagramError::in_operation(op.clone(), code))?; + let status = unfinished + .op + .build_diagram_operation(&unfinished.id, builder, &mut ctx) + .map_err(|code| code.in_operation(unfinished.as_operation_ref()))?; + + ctx.construction + .transfer_generated_operations(&mut deferred_operations, &mut made_progress); made_progress |= status.made_progress(); if !status.is_finished() { // The operation did not finish, so pass it into the deferred // operations list. - deferred_operations.push((op, status)); + deferred_statuses.push((unfinished.as_operation_ref(), status)); + deferred_operations.push(unfinished); } } - if made_progress { - // Try another iteration if needed since we made progress last time - unfinished_operations = deferred_operations.drain(..).map(|(op, _)| op).collect(); - } else { + unfinished_operations.extend(deferred_operations.drain(..)); + + // Transfer outputs into their connections. Sometimes this needs to be + // done before other operations can be built, e.g. a connection may need + // to be redirected before its target operation knows how to infer its + // message type. 
+ connector_construction.buffers = construction.buffers.clone(); + loop { + for (id, outputs) in construction.outputs_into_target.drain() { + let mut ctx = DiagramContext { + construction: &mut connector_construction, + registry, + operations: &diagram.ops, + templates: &diagram.templates, + on_implicit_error, + namespaces: Default::default(), + }; + + let Some(connect) = construction.connect_into_target.get_mut(&id) else { + if unfinished_operations.is_empty() { + return Err(DiagramErrorCode::UnknownOperation(id.into()).into()); + } else { + deferred_connections.insert(id, outputs); + continue; + } + }; + + for output in outputs { + made_progress = true; + connect + .connect_into_target(output, builder, &mut ctx) + .map_err(|code| code.in_operation(id.clone()))?; + } + } + + let new_connections = !connector_construction.outputs_into_target.is_empty(); + + construction + .outputs_into_target + .extend(connector_construction.outputs_into_target.drain()); + + construction + .outputs_into_target + .extend(deferred_connections.drain()); + + connector_construction + .transfer_generated_operations(&mut unfinished_operations, &mut made_progress); + + // TODO(@mxgrey): Consider draining new connect_into_target entries + // out of connector_construction. + + iterations += 1; + if iterations > MAX_ITERATIONS { + return Err(DiagramErrorCode::ExcessiveIterations.into()); + } + + if !new_connections { + break; + } + } + + if !made_progress { // No progress can be made any longer so return an error return Err(DiagramErrorCode::BuildHalted { - reasons: deferred_operations + reasons: deferred_statuses .drain(..) 
- .filter_map(|(op, status)| { - status - .into_deferral_reason() - .map(|reason| (op.clone(), reason)) + .filter_map(|(id, status)| { + status.into_deferral_reason().map(|reason| (id, reason)) }) .collect(), } @@ -517,61 +877,51 @@ where } } - let mut new_construction = DiagramConstruction::default(); - new_construction.buffers = construction.buffers.clone(); - - iterations = 0; - while !construction.is_finished() { - let mut ctx = DiagramContext { - construction: &mut new_construction, - diagram, - registry, - }; - - // Attempt to connect to all regular operations - for (op, outputs) in construction.outputs_to_operation_target.drain() { - let op = NextOperation::Target(op); - let connect = construction - .connect_into_target - .get_mut(&op) - .ok_or_else(|| DiagramErrorCode::UnknownOperation(op.clone()))?; + Ok(()) +} - for output in outputs { - connect.connect_into_target(output, builder, &mut ctx)?; - } - } +pub struct UnfinishedOperation<'a> { + /// Name of the operation within its scope + pub id: OperationName, + /// The namespaces that this operation takes place inside + pub namespaces: NamespaceList, + /// Description of the operation + pub op: &'a DiagramOperation, + /// The sibling operations of the one that is being built + pub sibling_ops: &'a Operations, +} - // Attempt to connect to all builtin operations - for (builtin, outputs) in construction.outputs_to_builtin_target.drain() { - let op = NextOperation::Builtin { builtin }; - let connect = construction - .connect_into_target - .get_mut(&op) - .ok_or_else(|| DiagramErrorCode::UnknownOperation(op.clone()))?; +impl<'a> std::fmt::Debug for UnfinishedOperation<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UnfinishedOperation") + .field("id", &self.id) + .field("namespaces", &self.namespaces) + .finish() + } +} - for output in outputs { - connect.connect_into_target(output, builder, &mut ctx)?; - } +impl<'a> UnfinishedOperation<'a> { + pub fn new(id: 
OperationName, op: &'a DiagramOperation, sibling_ops: &'a Operations) -> Self { + Self { + id, + op, + sibling_ops, + namespaces: Default::default(), } + } - construction - .outputs_to_builtin_target - .extend(new_construction.outputs_to_builtin_target.drain()); - - construction - .outputs_to_operation_target - .extend(new_construction.outputs_to_operation_target.drain()); - - iterations += 1; - if iterations > MAX_ITERATIONS { - return Err(DiagramErrorCode::ExcessiveIterations.into()); + pub fn as_operation_ref(&self) -> OperationRef { + NamedOperationRef { + namespaces: self.namespaces.clone(), + exposed_namespace: None, + name: self.id.clone(), } + .into() } - - Ok(()) } fn initialize_builtin_operations( + start: NextOperation, scope: Scope, builder: &mut Builder, ctx: &mut DiagramContext, @@ -582,11 +932,11 @@ where Streams: StreamPack, { // Put the input message into the diagram - ctx.add_output_into_target(ctx.diagram.start.clone(), scope.input.into()); + ctx.add_output_into_target(&start, scope.input.into()); // Add the terminate operation ctx.construction.connect_into_target.insert( - NextOperation::Builtin { + OperationRef::Builtin { builtin: BuiltinTarget::Terminate, }, standard_input_connection(scope.terminate.into(), &ctx.registry)?, @@ -594,18 +944,15 @@ where // Add the dispose operation ctx.construction.connect_into_target.insert( - NextOperation::Builtin { + OperationRef::Builtin { builtin: BuiltinTarget::Dispose, }, - Box::new(ConnectionCallback(move |_, _, _| { - // Do nothing since the output is being disposed - Ok(()) - })), + Box::new(ConnectToDispose), ); // Add the cancel operation ctx.construction.connect_into_target.insert( - NextOperation::Builtin { + OperationRef::Builtin { builtin: BuiltinTarget::Cancel, }, Box::new(ConnectToCancel::new(builder)?), @@ -614,6 +961,33 @@ where Ok(()) } +struct ConnectToDispose; + +impl ConnectIntoTarget for ConnectToDispose { + fn connect_into_target( + &mut self, + _output: DynOutput, + _builder: &mut 
Builder, + _ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + Ok(()) + } + + fn infer_input_type( + &self, + _ctx: &DiagramContext, + _visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + Ok(Some(Arc::new(ConnectToDispose))) + } +} + +impl InferMessageType for ConnectToDispose { + fn preferred_input_type(&self) -> Option { + None + } +} + /// This returns an opaque [`ConnectIntoTarget`] implementation that provides /// the standard behavior of an input slot that other operations are connecting /// into. @@ -632,7 +1006,9 @@ pub fn standard_input_connection( return Ok(Box::new(deserialization)); } - Ok(Box::new(BasicConnect { input_slot })) + Ok(Box::new(BasicConnect { + input_slot: Arc::new(input_slot), + })) } impl ConnectIntoTarget for ImplicitSerialization { @@ -644,6 +1020,15 @@ impl ConnectIntoTarget for ImplicitSerialization { ) -> Result<(), DiagramErrorCode> { self.implicit_serialize(output, builder, ctx) } + + fn infer_input_type( + &self, + _ctx: &DiagramContext, + _visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + let infer = Arc::clone(self.serialized_input_slot()); + Ok(Some(infer)) + } } impl ConnectIntoTarget for ImplicitDeserialization { @@ -655,10 +1040,19 @@ impl ConnectIntoTarget for ImplicitDeserialization { ) -> Result<(), DiagramErrorCode> { self.implicit_deserialize(output, builder, ctx) } + + fn infer_input_type( + &self, + _ctx: &DiagramContext, + _visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + let infer = Arc::clone(self.deserialized_input_slot()); + Ok(Some(infer)) + } } struct BasicConnect { - input_slot: DynInputSlot, + input_slot: Arc, } impl ConnectIntoTarget for BasicConnect { @@ -670,6 +1064,15 @@ impl ConnectIntoTarget for BasicConnect { ) -> Result<(), DiagramErrorCode> { output.connect_to(&self.input_slot, builder) } + + fn infer_input_type( + &self, + _ctx: &DiagramContext, + _visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + let infer = Arc::clone(&self.input_slot); + 
Ok(Some(infer)) + } } struct ConnectToCancel { @@ -737,4 +1140,145 @@ impl ConnectIntoTarget for ConnectToCancel { Ok(()) } + + fn infer_input_type( + &self, + ctx: &DiagramContext, + visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + self.implicit_serialization.infer_input_type(ctx, visited) + } +} + +impl InferMessageType for DynInputSlot { + fn preferred_input_type(&self) -> Option { + Some(*self.message_info()) + } +} + +#[derive(Debug)] +struct RedirectConnection { + redirect_to: OperationRef, + /// Keep track of which DynOutputs have been redirected in the past so we + /// can identify when a circular redirection is happening. + redirected: HashSet, +} + +impl RedirectConnection { + fn new(redirect_to: OperationRef) -> Self { + Self { + redirect_to, + redirected: Default::default(), + } + } +} + +impl ConnectIntoTarget for RedirectConnection { + fn connect_into_target( + &mut self, + output: DynOutput, + _builder: &mut Builder, + ctx: &mut DiagramContext, + ) -> Result<(), DiagramErrorCode> { + if self.redirected.insert(output.id()) { + // This DynOutput has not been redirected by this connector yet, so + // we should go ahead and redirect it. + ctx.add_output_into_target(self.redirect_to.clone(), output); + } else { + // This DynOutput has been redirected by this connector before, so + // we have a circular connection, making it impossible for the + // output to ever really be connected to anything. 
+ return Err(DiagramErrorCode::CircularRedirect(vec![self + .redirect_to + .clone()])); + } + Ok(()) + } + + fn infer_input_type( + &self, + ctx: &DiagramContext, + visited: &mut HashSet, + ) -> Result>, DiagramErrorCode> { + if visited.insert(self.redirect_to.clone()) { + if let Some(connect) = ctx.construction.connect_into_target.get(&self.redirect_to) { + return connect.infer_input_type(ctx, visited); + } else { + return Ok(None); + } + } else { + return Err(DiagramErrorCode::CircularRedirect( + visited.drain().collect(), + )); + } + } +} + +impl<'a, 'c> std::fmt::Debug for DiagramContext<'a, 'c> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DiagramContext") + .field("construction", &DebugDiagramConstruction(self)) + .finish() + } +} + +struct DebugDiagramConstruction<'a, 'c, 'd>(&'d DiagramContext<'a, 'c>); + +impl<'a, 'c, 'd> std::fmt::Debug for DebugDiagramConstruction<'a, 'c, 'd> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DiagramConstruction") + .field("connect_into_target", &DebugConnections(self.0)) + .field( + "outputs_into_target", + &self.0.construction.outputs_into_target, + ) + .field("buffers", &self.0.construction.buffers) + .field( + "generated_operations", + &self.0.construction.generated_operations, + ) + .finish() + } +} + +struct DebugConnections<'a, 'c, 'd>(&'d DiagramContext<'a, 'c>); + +impl<'a, 'c, 'd> std::fmt::Debug for DebugConnections<'a, 'c, 'd> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut debug = f.debug_map(); + for (op, connect) in &self.0.construction.connect_into_target { + debug.entry( + op, + &DebugConnection { + connect, + context: self.0, + }, + ); + } + + debug.finish() + } +} + +struct DebugConnection<'a, 'c, 'd> { + connect: &'d Box, + context: &'d DiagramContext<'a, 'c>, +} + +impl<'a, 'c, 'd> std::fmt::Debug for DebugConnection<'a, 'c, 'd> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { + let mut visited = HashSet::new(); + let mut debug = f.debug_struct("ConnectIntoTarget"); + match self.connect.infer_input_type(self.context, &mut visited) { + Ok(ok) => { + let inferred = ok.map(|infer| infer.preferred_input_type()).flatten(); + debug.field("inferred type", &inferred); + } + Err(err) => { + debug.field("infered type error", &err); + } + } + + debug.finish() + } } From 2e5e7162dad955dd0308e8aabb914650496e9725 Mon Sep 17 00:00:00 2001 From: Xiyu Oh Date: Wed, 23 Apr 2025 09:42:10 +0000 Subject: [PATCH 16/20] Cleanup App::world getters Signed-off-by: Xiyu Oh --- examples/diagram/calculator/src/main.rs | 2 +- src/diagram.rs | 4 ++-- src/diagram/section_schema.rs | 6 +++--- src/diagram/testing.rs | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/diagram/calculator/src/main.rs b/examples/diagram/calculator/src/main.rs index 2321f0cf..dc743d13 100644 --- a/examples/diagram/calculator/src/main.rs +++ b/examples/diagram/calculator/src/main.rs @@ -46,7 +46,7 @@ fn main() -> Result<(), Box> { let request = serde_json::Value::from_str(&args.request)?; let mut promise = - app.world + app.world_mut() .command(|cmds| -> Result, DiagramError> { let workflow = diagram.spawn_io_workflow(cmds, ®istry)?; Ok(cmds.request(request, workflow).take_response()) diff --git a/src/diagram.rs b/src/diagram.rs index 493fe735..7806a1ac 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -969,7 +969,7 @@ impl Diagram { /// "#; /// /// let diagram = Diagram::from_json_str(json_str)?; - /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; + /// let workflow = app.world_mut().command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; /// # Ok::<_, Box>(()) /// ``` // TODO(koonpeng): Support streams other than `()` #43. 
@@ -1045,7 +1045,7 @@ impl Diagram { /// "#; /// /// let diagram = Diagram::from_json_str(json_str)?; - /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; + /// let workflow = app.world_mut().command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry))?; /// # Ok::<_, Box>(()) /// ``` pub fn spawn_io_workflow( diff --git a/src/diagram/section_schema.rs b/src/diagram/section_schema.rs index 798bac5e..2a2aa5c6 100644 --- a/src/diagram/section_schema.rs +++ b/src/diagram/section_schema.rs @@ -626,7 +626,7 @@ mod tests { .unwrap(); let mut context = TestingContext::minimal_plugins(); - let mut promise = context.app.world.command(|cmds| { + let mut promise = context.app.world_mut().command(|cmds| { let workflow = diagram .spawn_io_workflow::(cmds, ®istry) .unwrap(); @@ -662,7 +662,7 @@ mod tests { let mut context = TestingContext::minimal_plugins(); let err = context .app - .world + .world_mut() .command(|cmds| diagram.spawn_io_workflow::(cmds, ®istry)) .unwrap_err(); let section_err = match err.code { @@ -707,7 +707,7 @@ mod tests { let mut context = TestingContext::minimal_plugins(); let err = context .app - .world + .world_mut() .command(|cmds| { diagram.spawn_io_workflow::(cmds, &fixture.registry) }) diff --git a/src/diagram/testing.rs b/src/diagram/testing.rs index 84f4b904..996777a2 100644 --- a/src/diagram/testing.rs +++ b/src/diagram/testing.rs @@ -40,7 +40,7 @@ impl DiagramTestFixture { { self.context .app - .world + .world_mut() .command(|cmds| diagram.spawn_workflow(cmds, &self.registry)) } From f3c961a22ed0571e34bffcace8f140cab5df8c0a Mon Sep 17 00:00:00 2001 From: Xiyu Oh Date: Fri, 2 May 2025 03:57:14 +0000 Subject: [PATCH 17/20] Style Signed-off-by: Xiyu Oh --- src/service/discovery.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/service/discovery.rs b/src/service/discovery.rs index e0f69f89..4d804868 100644 --- a/src/service/discovery.rs +++ b/src/service/discovery.rs @@ -17,7 +17,7 @@ use 
bevy_ecs::{ prelude::{Entity, Query, With}, - query::{QueryEntityError, QueryIter, QueryFilter}, + query::{QueryEntityError, QueryFilter, QueryIter}, system::SystemParam, }; From 26032b26b3e77f02ea07ca05631761c9b300f843 Mon Sep 17 00:00:00 2001 From: Xiyu Oh Date: Fri, 2 May 2025 04:02:46 +0000 Subject: [PATCH 18/20] Style Signed-off-by: Xiyu Oh --- src/channel.rs | 2 +- src/flush.rs | 2 +- src/input.rs | 2 +- src/service/continuous.rs | 2 +- src/service/discovery.rs | 2 +- src/stream.rs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 878d737c..1abacf12 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -17,8 +17,8 @@ use bevy_ecs::{ prelude::{Entity, Resource, World}, - world::CommandQueue, system::Commands, + world::CommandQueue, }; use tokio::sync::mpsc::{ diff --git a/src/flush.rs b/src/flush.rs index 15948624..eca26676 100644 --- a/src/flush.rs +++ b/src/flush.rs @@ -19,8 +19,8 @@ use bevy_derive::{Deref, DerefMut}; use bevy_ecs::{ prelude::{Added, Entity, Query, QueryState, Resource, With, World}, schedule::{IntoSystemConfigs, SystemConfigs}, - world::Command, system::SystemState, + world::Command, }; use bevy_hierarchy::{BuildWorldChildren, Children, DespawnRecursiveExt}; diff --git a/src/input.rs b/src/input.rs index bd73a664..78224936 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,7 @@ use bevy_ecs::{ prelude::{Bundle, Component, Entity}, - world::{EntityRef, EntityWorldMut, World, Command}, + world::{Command, EntityRef, EntityWorldMut, World}, }; use smallvec::SmallVec; diff --git a/src/service/continuous.rs b/src/service/continuous.rs index f28b3620..77f1d49b 100644 --- a/src/service/continuous.rs +++ b/src/service/continuous.rs @@ -19,7 +19,7 @@ use bevy_ecs::{ prelude::{Commands, Component, Entity, Event, EventReader, In, Local, Query, World}, schedule::IntoSystemConfigs, system::{IntoSystem, SystemParam}, - world::{EntityWorldMut, Command}, + world::{Command, EntityWorldMut}, }; 
use bevy_hierarchy::prelude::{BuildWorldChildren, DespawnRecursiveExt}; diff --git a/src/service/discovery.rs b/src/service/discovery.rs index e0f69f89..4d804868 100644 --- a/src/service/discovery.rs +++ b/src/service/discovery.rs @@ -17,7 +17,7 @@ use bevy_ecs::{ prelude::{Entity, Query, With}, - query::{QueryEntityError, QueryIter, QueryFilter}, + query::{QueryEntityError, QueryFilter, QueryIter}, system::SystemParam, }; diff --git a/src/stream.rs b/src/stream.rs index c57f734b..d7e7a768 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -18,7 +18,7 @@ use bevy_derive::{Deref, DerefMut}; use bevy_ecs::{ prelude::{Bundle, Commands, Component, Entity, With, World}, - query::{QueryFilter, WorldQuery, ReadOnlyQueryData}, + query::{QueryFilter, ReadOnlyQueryData, WorldQuery}, world::Command, }; use bevy_hierarchy::BuildChildren; From 4e8b2241dcdeca5a08aef24bbabfe0459eb41cae Mon Sep 17 00:00:00 2001 From: Xiyu Oh Date: Fri, 2 May 2025 04:07:47 +0000 Subject: [PATCH 19/20] Style again Signed-off-by: Xiyu Oh --- src/stream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stream.rs b/src/stream.rs index e5cc875b..42bd773e 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -18,7 +18,7 @@ use bevy_derive::{Deref, DerefMut}; use bevy_ecs::{ prelude::{Bundle, Commands, Component, Entity, With, World}, - query::{QueryFilter, WorldQuery, ReadOnlyQueryData}, + query::{QueryFilter, ReadOnlyQueryData, WorldQuery}, system::Command, }; use bevy_hierarchy::BuildChildren; From 2598131ae1aa74ce6b9d38e281720c98d6a0e4f5 Mon Sep 17 00:00:00 2001 From: Xiyu Oh Date: Fri, 2 May 2025 06:06:08 +0000 Subject: [PATCH 20/20] Update CI to 1.76 Signed-off-by: Xiyu Oh --- .github/workflows/ci_linux.yaml | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index 660d6933..e3faa5b4 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -18,7 +18,7 @@ jobs: 
build: strategy: matrix: - rust-version: [stable, 1.75] + rust-version: [stable, 1.81] runs-on: ubuntu-latest @@ -28,14 +28,6 @@ jobs: - name: Setup rust run: rustup default ${{ matrix.rust-version }} - # As new versions of our dependencies come out, they might depend on newer - # versions of the Rust compiler. When that happens, we'll use this step to - # lock down the dependency to a version that is known to be compatible with - # compiler version 1.75. - - name: Patch dependencies - if: ${{ matrix.rust-version == 1.75 }} - run: ./scripts/patch-versions-msrv-1_75.sh - - name: Build default features run: cargo build --workspace - name: Test default features