Skip to content

Commit a96fd4a

Browse files
191220029genedna
authored andcommitted
Parser: auto_node macro
1 parent f4d2466 commit a96fd4a

File tree

8 files changed

+289
-4
lines changed

8 files changed

+289
-4
lines changed

Cargo.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,25 @@ repository = "https://github.com/open-rust-initiative/dagrs"
1010
keywords = ["DAG", "task", "async", "parallel", "concurrent"]
1111

1212
[workspace]
13-
members = ["."]
13+
members = [".", "derive"]
1414

1515
[dependencies]
1616
tokio = { version = "1.28", features = ["rt", "sync", "rt-multi-thread"] }
1717
log = "0.4"
1818
env_logger = "0.10.1"
1919
async-trait = "0.1.83"
20+
derive = { path = "derive", optional = true }
2021

2122
[dev-dependencies]
2223
simplelog = "0.12"
2324
criterion = { version = "0.5.1", features = ["html_reports"] }
2425

2526
[target.'cfg(unix)'.dev-dependencies]
2627

27-
2828
[features]
29+
default = ["derive"]
30+
derive = ["derive/derive"]
31+
32+
[[example]]
33+
name = "auto_node"
34+
required-features = ["derive"]

derive/Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "derive"
3+
version = "0.1.0"
4+
edition = "2021"
5+
license = "MIT OR Apache-2.0"
6+
7+
[dependencies]
8+
syn = { version = "2.0", features = ["full"] }
9+
quote = "1.0"
10+
proc-macro2= "1.0"
11+
12+
[lib]
13+
proc-macro = true
14+
15+
[features]
16+
default = ["derive"]
17+
derive = []

derive/src/auto_node.rs

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
use proc_macro::TokenStream;
2+
use quote::quote;
3+
use syn::parse::Parser;
4+
use syn::{parse, parse_macro_input, Field, Generics, Ident, ItemStruct};
5+
6+
/// Generate fields & implements of `Node` trait.
7+
///
8+
/// Step 1: generate fields (`id`, `name`, `input_channel`, `output_channel`, `action`)
9+
///
10+
/// Step 2: generates methods for `Node` implementation.
11+
///
12+
/// Step 3: append the generated fields to the input struct.
13+
///
14+
/// Step 4: return tokens of the input struct & the generated methods.
15+
pub(crate) fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
16+
let mut item_struct = parse_macro_input!(input as ItemStruct);
17+
let _ = parse_macro_input!(args as parse::Nothing);
18+
19+
let generics = &item_struct.generics;
20+
21+
let field_id = syn::Field::parse_named
22+
.parse2(quote! {
23+
id: dagrs::NodeId
24+
})
25+
.unwrap();
26+
27+
let field_name = syn::Field::parse_named
28+
.parse2(quote! {
29+
name: String
30+
})
31+
.unwrap();
32+
33+
let field_in_channels = syn::Field::parse_named
34+
.parse2(quote! {
35+
input_channels: dagrs::InChannels
36+
})
37+
.unwrap();
38+
39+
let field_out_channels = syn::Field::parse_named
40+
.parse2(quote! {
41+
output_channels: dagrs::OutChannels
42+
})
43+
.unwrap();
44+
45+
let field_action = syn::Field::parse_named
46+
.parse2(quote! {
47+
action: Box<dyn dagrs::Action>
48+
})
49+
.unwrap();
50+
51+
let auto_impl = auto_impl_node(
52+
&item_struct.ident,
53+
generics,
54+
&field_id,
55+
&field_name,
56+
&field_in_channels,
57+
&field_out_channels,
58+
&field_action,
59+
);
60+
61+
match item_struct.fields {
62+
syn::Fields::Named(ref mut fields) => {
63+
fields.named.push(field_id);
64+
fields.named.push(field_name);
65+
fields.named.push(field_in_channels);
66+
fields.named.push(field_out_channels);
67+
fields.named.push(field_action);
68+
}
69+
syn::Fields::Unit => {
70+
item_struct.fields = syn::Fields::Named(syn::FieldsNamed {
71+
named: [
72+
field_id,
73+
field_name,
74+
field_in_channels,
75+
field_out_channels,
76+
field_action,
77+
]
78+
.into_iter()
79+
.collect(),
80+
brace_token: Default::default(),
81+
});
82+
}
83+
_ => {
84+
return syn::Error::new_spanned(
85+
item_struct.ident,
86+
"`auto_node` macro can only be annotated on named struct or unit struct.",
87+
)
88+
.into_compile_error()
89+
.into()
90+
}
91+
};
92+
93+
return quote! {
94+
#item_struct
95+
#auto_impl
96+
}
97+
.into();
98+
}
99+
100+
fn auto_impl_node(
101+
struct_ident: &Ident,
102+
generics: &Generics,
103+
field_id: &Field,
104+
field_name: &Field,
105+
field_in_channels: &Field,
106+
field_out_channels: &Field,
107+
field_action: &Field,
108+
) -> proc_macro2::TokenStream {
109+
let mut impl_tokens = proc_macro2::TokenStream::new();
110+
impl_tokens.extend([
111+
impl_id(field_id),
112+
impl_name(field_name),
113+
impl_in_channels(field_in_channels),
114+
impl_out_channels(field_out_channels),
115+
impl_run(field_action, field_in_channels, field_out_channels),
116+
]);
117+
118+
quote::quote!(
119+
impl #generics dagrs::Node for #struct_ident #generics {
120+
#impl_tokens
121+
}
122+
unsafe impl #generics Send for #struct_ident #generics{}
123+
unsafe impl #generics Sync for #struct_ident #generics{}
124+
)
125+
}
126+
127+
fn impl_id(field: &Field) -> proc_macro2::TokenStream {
128+
let ident = &field.ident;
129+
quote::quote!(
130+
fn id(&self) -> dagrs::NodeId {
131+
self.#ident
132+
}
133+
)
134+
}
135+
136+
fn impl_name(field: &Field) -> proc_macro2::TokenStream {
137+
let ident = &field.ident;
138+
quote::quote!(
139+
fn name(&self) -> dagrs::NodeName {
140+
self.#ident.clone()
141+
}
142+
)
143+
}
144+
145+
fn impl_in_channels(field: &Field) -> proc_macro2::TokenStream {
146+
let ident = &field.ident;
147+
quote::quote!(
148+
fn input_channels(&mut self) -> &mut dagrs::InChannels {
149+
&mut self.#ident
150+
}
151+
)
152+
}
153+
154+
fn impl_out_channels(field: &Field) -> proc_macro2::TokenStream {
155+
let ident = &field.ident;
156+
quote::quote!(
157+
fn output_channels(&mut self) -> &mut dagrs::OutChannels {
158+
&mut self.#ident
159+
}
160+
)
161+
}
162+
163+
fn impl_run(
164+
field: &Field,
165+
field_in_channels: &Field,
166+
field_out_channels: &Field,
167+
) -> proc_macro2::TokenStream {
168+
let ident = &field.ident;
169+
let in_channels_ident = &field_in_channels.ident;
170+
let out_channels_ident = &field_out_channels.ident;
171+
quote::quote!(
172+
fn run(&mut self, env: std::sync::Arc<dagrs::EnvVar>) -> dagrs::Output {
173+
tokio::runtime::Runtime::new().unwrap().block_on(async {
174+
self.#ident
175+
.run(&mut self.#in_channels_ident, &self.#out_channels_ident, env)
176+
.await
177+
})
178+
}
179+
)
180+
}

derive/src/lib.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use proc_macro::TokenStream;
2+
3+
#[cfg(feature = "derive")]
4+
mod auto_node;
5+
6+
/// [`auto_node`] is a macro that may be used when customizing nodes. It can only be
7+
/// marked on named struct or unit struct.
8+
///
9+
/// The macro [`auto_node`] generates essential fields and implementation of traits for
10+
/// structs intended to represent `Node` in **Dagrs**.
11+
/// By applying this macro to a struct, it appends fields including `id: dagrs::NodeId`,
12+
/// `name: dagrs::NodeName`, `input_channels: dagrs::InChannels`, `output_channels: dagrs::OutChannels`,
13+
/// and `action: dagrs::Action`, and implements the required `dagrs::Node` trait.
14+
///
15+
/// ## Example
16+
/// - Mark `auto_node` on a struct with customized fields.
17+
/// ```ignore
18+
/// use dagrs::auto_node;
19+
/// #[auto_node]
20+
/// struct MyNode {/*Put your customized fields here.*/}
21+
/// ```
22+
///
23+
/// - Mark `auto_node` on a struct with generic & lifetime params.
24+
/// ```ignore
25+
/// use dagrs::auto_node;
26+
/// #[auto_node]
27+
/// struct MyNode<T, 'a> {/*Put your customized fields here.*/}
28+
/// ```
29+
/// - Mark `auto_node` on a unit struct.
30+
/// ```ignore
31+
/// use dagrs::auto_node;
32+
/// #[auto_node]
33+
/// struct MyNode()
34+
/// ```
35+
#[cfg(feature = "derive")]
36+
#[proc_macro_attribute]
37+
pub fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
38+
use crate::auto_node::auto_node;
39+
auto_node(args, input).into()
40+
}

examples/auto_node.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use std::sync::Arc;
2+
3+
use dagrs::{auto_node, EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels};
4+
5+
#[auto_node]
6+
struct MyNode {/*Put customized fields here.*/}
7+
8+
#[auto_node]
9+
struct _MyNodeGeneric<T, 'a> {
10+
my_field: Vec<T>,
11+
my_name: &'a str,
12+
}
13+
14+
#[auto_node]
15+
struct _MyUnitNode;
16+
17+
fn main() {
18+
let mut node_table = NodeTable::default();
19+
20+
let node_name = "auto_node".to_string();
21+
22+
let mut s = MyNode {
23+
id: node_table.alloc_id_for(&node_name),
24+
name: node_name.clone(),
25+
input_channels: InChannels::default(),
26+
output_channels: OutChannels::default(),
27+
action: Box::new(EmptyAction),
28+
};
29+
30+
assert_eq!(&s.id(), node_table.get(&node_name).unwrap());
31+
assert_eq!(&s.name(), &node_name);
32+
33+
let output = s.run(Arc::new(EnvVar::new(NodeTable::default())));
34+
match output {
35+
dagrs::Output::Out(content) => assert!(content.is_none()),
36+
_ => panic!(),
37+
}
38+
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,8 @@ pub use node::{
1212
default_node::DefaultNode,
1313
node::*,
1414
};
15+
pub use tokio;
1516
pub use utils::{env::EnvVar, output::Output};
17+
18+
#[cfg(feature = "derive")]
19+
pub use derive::*;

src/node/default_node.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub struct DefaultNode {
5353

5454
impl Node for DefaultNode {
5555
fn id(&self) -> NodeId {
56-
self.id.clone()
56+
self.id
5757
}
5858

5959
fn name(&self) -> NodeName {

src/node/node.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub trait Node: Send + Sync {
2929
fn run(&mut self, env: Arc<EnvVar>) -> Output;
3030
}
3131

32-
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
32+
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
3333
pub struct NodeId(pub(crate) usize);
3434

3535
pub type NodeName = String;

0 commit comments

Comments
 (0)