Skip to content

Commit f64871e

Browse files
koonpengmxgrey
andauthored
add derive JoinedValue to implement trait automatically (#53)
Signed-off-by: Teo Koon Peng <[email protected]> Signed-off-by: Michael X. Grey <[email protected]> Co-authored-by: Michael X. Grey <[email protected]>
1 parent 5dd9a30 commit f64871e

File tree

8 files changed

+580
-150
lines changed

8 files changed

+580
-150
lines changed

macros/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ proc-macro = true
1616
[dependencies]
1717
syn = "2.0"
1818
quote = "1.0"
19+
proc-macro2 = "1.0.93"

macros/src/buffer.rs

+266
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
use proc_macro2::TokenStream;
2+
use quote::{format_ident, quote};
3+
use syn::{parse_quote, Field, Generics, Ident, ItemStruct, Type, TypePath};
4+
5+
use crate::Result;
6+
7+
pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream> {
8+
let struct_ident = &input_struct.ident;
9+
let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl();
10+
let StructConfig {
11+
buffer_struct_name: buffer_struct_ident,
12+
} = StructConfig::from_data_struct(&input_struct);
13+
let buffer_struct_vis = &input_struct.vis;
14+
15+
let (field_ident, _, field_config) = get_fields_map(&input_struct.fields)?;
16+
let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect();
17+
let noncopy = field_config.iter().any(|config| config.noncopy);
18+
19+
let buffer_struct: ItemStruct = parse_quote! {
20+
#[allow(non_camel_case_types, unused)]
21+
#buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause {
22+
#(
23+
#buffer_struct_vis #field_ident: #buffer,
24+
)*
25+
}
26+
};
27+
28+
let buffer_clone_impl = if noncopy {
29+
// Clone impl for structs with a buffer that is not copyable
30+
quote! {
31+
impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause {
32+
fn clone(&self) -> Self {
33+
Self {
34+
#(
35+
#field_ident: self.#field_ident.clone(),
36+
)*
37+
}
38+
}
39+
}
40+
}
41+
} else {
42+
// Clone and copy impl for structs with buffers that are all copyable
43+
quote! {
44+
impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause {
45+
fn clone(&self) -> Self {
46+
*self
47+
}
48+
}
49+
50+
impl #impl_generics ::std::marker::Copy for #buffer_struct_ident #ty_generics #where_clause {}
51+
}
52+
};
53+
54+
let impl_buffer_map_layout = impl_buffer_map_layout(&buffer_struct, &input_struct)?;
55+
let impl_joined = impl_joined(&buffer_struct, &input_struct)?;
56+
57+
let gen = quote! {
58+
impl #impl_generics ::bevy_impulse::JoinedValue for #struct_ident #ty_generics #where_clause {
59+
type Buffers = #buffer_struct_ident #ty_generics;
60+
}
61+
62+
#buffer_struct
63+
64+
#buffer_clone_impl
65+
66+
impl #impl_generics #struct_ident #ty_generics #where_clause {
67+
fn select_buffers(
68+
#(
69+
#field_ident: #buffer,
70+
)*
71+
) -> #buffer_struct_ident #ty_generics {
72+
#buffer_struct_ident {
73+
#(
74+
#field_ident,
75+
)*
76+
}
77+
}
78+
}
79+
80+
#impl_buffer_map_layout
81+
82+
#impl_joined
83+
};
84+
85+
Ok(gen.into())
86+
}
87+
88+
/// Code that are currently unused but could be used in the future, move them out of this mod if
89+
/// they are ever used.
90+
#[allow(unused)]
91+
mod _unused {
92+
use super::*;
93+
94+
/// Converts a list of generics to a [`PhantomData`] TypePath.
95+
/// e.g. `::std::marker::PhantomData<fn(T,)>`
96+
fn to_phantom_data(generics: &Generics) -> TypePath {
97+
let lifetimes: Vec<Type> = generics
98+
.lifetimes()
99+
.map(|lt| {
100+
let lt = &lt.lifetime;
101+
let ty: Type = parse_quote! { & #lt () };
102+
ty
103+
})
104+
.collect();
105+
let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect();
106+
parse_quote! { ::std::marker::PhantomData<fn(#(#lifetimes,)* #(#ty_params,)*)> }
107+
}
108+
}
109+
110+
struct StructConfig {
111+
buffer_struct_name: Ident,
112+
}
113+
114+
impl StructConfig {
115+
fn from_data_struct(data_struct: &ItemStruct) -> Self {
116+
let mut config = Self {
117+
buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident),
118+
};
119+
120+
let attr = data_struct
121+
.attrs
122+
.iter()
123+
.find(|attr| attr.path().is_ident("joined"));
124+
125+
if let Some(attr) = attr {
126+
attr.parse_nested_meta(|meta| {
127+
if meta.path.is_ident("buffers_struct_name") {
128+
config.buffer_struct_name = meta.value()?.parse()?;
129+
}
130+
Ok(())
131+
})
132+
// panic if attribute is malformed, this will result in a compile error which is intended.
133+
.unwrap();
134+
}
135+
136+
config
137+
}
138+
}
139+
140+
struct FieldConfig {
141+
buffer: Type,
142+
noncopy: bool,
143+
}
144+
145+
impl FieldConfig {
146+
fn from_field(field: &Field) -> Self {
147+
let ty = &field.ty;
148+
let mut config = Self {
149+
buffer: parse_quote! { ::bevy_impulse::Buffer<#ty> },
150+
noncopy: false,
151+
};
152+
153+
for attr in field
154+
.attrs
155+
.iter()
156+
.filter(|attr| attr.path().is_ident("joined"))
157+
{
158+
attr.parse_nested_meta(|meta| {
159+
if meta.path.is_ident("buffer") {
160+
config.buffer = meta.value()?.parse()?;
161+
}
162+
if meta.path.is_ident("noncopy_buffer") {
163+
config.noncopy = true;
164+
}
165+
Ok(())
166+
})
167+
// panic if attribute is malformed, this will result in a compile error which is intended.
168+
.unwrap();
169+
}
170+
171+
config
172+
}
173+
}
174+
175+
fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>, Vec<FieldConfig>)> {
176+
match fields {
177+
syn::Fields::Named(data) => {
178+
let mut idents = Vec::new();
179+
let mut types = Vec::new();
180+
let mut configs = Vec::new();
181+
for field in &data.named {
182+
let ident = field
183+
.ident
184+
.as_ref()
185+
.ok_or("expected named fields".to_string())?;
186+
idents.push(ident);
187+
types.push(&field.ty);
188+
configs.push(FieldConfig::from_field(field));
189+
}
190+
Ok((idents, types, configs))
191+
}
192+
_ => return Err("expected named fields".to_string()),
193+
}
194+
}
195+
196+
/// Params:
197+
/// buffer_struct: The struct to implement `BufferMapLayout`.
198+
/// item_struct: The struct which `buffer_struct` is derived from.
199+
fn impl_buffer_map_layout(
200+
buffer_struct: &ItemStruct,
201+
item_struct: &ItemStruct,
202+
) -> Result<proc_macro2::TokenStream> {
203+
let struct_ident = &buffer_struct.ident;
204+
let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl();
205+
let (field_ident, _, field_config) = get_fields_map(&item_struct.fields)?;
206+
let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect();
207+
let map_key: Vec<String> = field_ident.iter().map(|v| v.to_string()).collect();
208+
209+
Ok(quote! {
210+
impl #impl_generics ::bevy_impulse::BufferMapLayout for #struct_ident #ty_generics #where_clause {
211+
fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> {
212+
use smallvec::smallvec;
213+
smallvec![#(
214+
self.#field_ident.as_any_buffer(),
215+
)*]
216+
}
217+
218+
fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result<Self, ::bevy_impulse::IncompatibleLayout> {
219+
let mut compatibility = ::bevy_impulse::IncompatibleLayout::default();
220+
#(
221+
let #field_ident = if let Ok(buffer) = compatibility.require_buffer_type::<#buffer>(#map_key, buffers) {
222+
buffer
223+
} else {
224+
return Err(compatibility);
225+
};
226+
)*
227+
228+
Ok(Self {
229+
#(
230+
#field_ident,
231+
)*
232+
})
233+
}
234+
}
235+
}
236+
.into())
237+
}
238+
239+
/// Params:
240+
/// joined_struct: The struct to implement `Joined`.
241+
/// item_struct: The associated `Item` type to use for the `Joined` implementation.
242+
fn impl_joined(
243+
joined_struct: &ItemStruct,
244+
item_struct: &ItemStruct,
245+
) -> Result<proc_macro2::TokenStream> {
246+
let struct_ident = &joined_struct.ident;
247+
let item_struct_ident = &item_struct.ident;
248+
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
249+
let (field_ident, _, _) = get_fields_map(&item_struct.fields)?;
250+
251+
Ok(quote! {
252+
impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause {
253+
type Item = #item_struct_ident #ty_generics;
254+
255+
fn pull(&self, session: ::bevy_ecs::prelude::Entity, world: &mut ::bevy_ecs::prelude::World) -> Result<Self::Item, ::bevy_impulse::OperationError> {
256+
#(
257+
let #field_ident = self.#field_ident.pull(session, world)?;
258+
)*
259+
260+
Ok(Self::Item {#(
261+
#field_ident,
262+
)*})
263+
}
264+
}
265+
}.into())
266+
}

macros/src/lib.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
*
1616
*/
1717

18+
mod buffer;
19+
use buffer::impl_joined_value;
20+
1821
use proc_macro::TokenStream;
1922
use quote::quote;
20-
use syn::DeriveInput;
23+
use syn::{parse_macro_input, DeriveInput, ItemStruct};
2124

2225
#[proc_macro_derive(Stream)]
2326
pub fn simple_stream_macro(item: TokenStream) -> TokenStream {
@@ -58,3 +61,18 @@ pub fn delivery_label_macro(item: TokenStream) -> TokenStream {
5861
}
5962
.into()
6063
}
64+
65+
/// The result error is the compiler error message to be displayed.
66+
type Result<T> = std::result::Result<T, String>;
67+
68+
#[proc_macro_derive(JoinedValue, attributes(joined))]
69+
pub fn derive_joined_value(input: TokenStream) -> TokenStream {
70+
let input = parse_macro_input!(input as ItemStruct);
71+
match impl_joined_value(&input) {
72+
Ok(tokens) => tokens.into(),
73+
Err(msg) => quote! {
74+
compile_error!(#msg);
75+
}
76+
.into(),
77+
}
78+
}

src/buffer/any_buffer.rs

+15
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ impl AnyBuffer {
121121
.ok()
122122
.map(|x| *x)
123123
}
124+
125+
pub fn as_any_buffer(&self) -> Self {
126+
self.clone().into()
127+
}
124128
}
125129

126130
impl<T: 'static + Send + Sync + Any> From<Buffer<T>> for AnyBuffer {
@@ -857,6 +861,17 @@ impl<T: 'static + Send + Sync> AnyBufferAccessImpl<T> {
857861
})),
858862
);
859863

864+
// Allow downcasting back to the original Buffer<T>
865+
buffer_downcasts.insert(
866+
TypeId::of::<Buffer<T>>(),
867+
Box::leak(Box::new(|location| -> Box<dyn Any> {
868+
Box::new(Buffer::<T> {
869+
location,
870+
_ignore: Default::default(),
871+
})
872+
})),
873+
);
874+
860875
let mut key_downcasts: HashMap<_, KeyDowncastRef> = HashMap::new();
861876

862877
// Automatically register a downcast to AnyBufferKey

0 commit comments

Comments
 (0)