Skip to content

Commit 880f76a

Browse files
committed
support enums
1 parent 5ae8d9f commit 880f76a

File tree

1 file changed

+167
-31
lines changed

1 file changed

+167
-31
lines changed

pyo3-macros-backend/src/intopyobject.rs

Lines changed: 167 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use crate::attributes::{self, get_pyo3_options, CrateAttribute};
22
use crate::utils::Ctx;
33
use proc_macro2::{Span, TokenStream};
4-
use quote::quote;
4+
use quote::{format_ident, quote};
55
use syn::ext::IdentExt;
66
use syn::parse::{Parse, ParseStream};
77
use syn::spanned::Spanned as _;
8-
use syn::{parse_quote, Attribute, DeriveInput, Fields, Index, Result, Token};
8+
use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result, Token};
99

1010
/// Attributes for deriving FromPyObject scoped on containers.
1111
enum ContainerPyO3Attribute {
@@ -112,14 +112,21 @@ enum ContainerType<'a> {
112112
///
113113
/// Either describes a struct or an enum variant.
114114
struct Container<'a> {
115+
path: syn::Path,
116+
receiver: Option<Ident>,
115117
ty: ContainerType<'a>,
116118
}
117119

120+
/// Construct a container based on fields, identifier and attributes.
118121
impl<'a> Container<'a> {
119-
/// Construct a container based on fields, identifier and attributes.
120122
///
121123
/// Fails if the variant has no fields or incompatible attributes.
122-
fn new(fields: &'a Fields, options: ContainerOptions) -> Result<Self> {
124+
fn new(
125+
receiver: Option<Ident>,
126+
fields: &'a Fields,
127+
path: syn::Path,
128+
options: ContainerOptions,
129+
) -> Result<Self> {
123130
let style = match fields {
124131
Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
125132
if unnamed.unnamed.iter().count() == 1 {
@@ -171,10 +178,41 @@ impl<'a> Container<'a> {
171178
),
172179
};
173180

174-
let v = Container { ty: style };
181+
let v = Container {
182+
path,
183+
receiver,
184+
ty: style,
185+
};
175186
Ok(v)
176187
}
177188

189+
fn match_pattern(&self) -> TokenStream {
190+
let path = &self.path;
191+
let pattern = match &self.ty {
192+
ContainerType::Struct(fields) => fields
193+
.iter()
194+
.enumerate()
195+
.map(|(i, f)| {
196+
let ident = f.ident;
197+
let new_ident = format_ident!("arg{i}");
198+
quote! {#ident: #new_ident,}
199+
})
200+
.collect::<TokenStream>(),
201+
ContainerType::StructNewtype(field) => {
202+
let ident = field.ident.as_ref().unwrap();
203+
quote!(#ident: arg0)
204+
}
205+
ContainerType::Tuple(fields) => {
206+
let i = (0..fields.len()).map(Index::from);
207+
let idents = (0..fields.len()).map(|i| format_ident!("arg{i}"));
208+
quote! { #(#i: #idents,)* }
209+
}
210+
ContainerType::TupleNewtype(_) => quote!(0: arg0),
211+
};
212+
213+
quote! { #path{ #pattern } }
214+
}
215+
178216
/// Build derivation body for a struct.
179217
fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
180218
match &self.ty {
@@ -189,30 +227,47 @@ impl<'a> Container<'a> {
189227
fn build_newtype_struct(&self, field: &syn::Field, ctx: &Ctx) -> IntoPyObjectImpl {
190228
let Ctx { pyo3_path, .. } = ctx;
191229
let ty = &field.ty;
192-
let ident = if let Some(ident) = &field.ident {
193-
quote! {self.#ident}
194-
} else {
195-
quote! {self.0}
196-
};
230+
231+
let unpack = self
232+
.receiver
233+
.as_ref()
234+
.map(|i| {
235+
let pattern = self.match_pattern();
236+
quote! { let #pattern = #i;}
237+
})
238+
.unwrap_or_default();
197239

198240
IntoPyObjectImpl {
199241
target: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Target},
200242
output: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Output},
201243
error: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Error},
202-
body: quote! { <#ty as #pyo3_path::conversion::IntoPyObject<'py>>::into_pyobject(#ident, py) },
244+
body: quote! {
245+
#unpack
246+
<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::into_pyobject(arg0, py)
247+
},
203248
}
204249
}
205250

206251
fn build_struct(&self, fields: &[NamedStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
207252
let Ctx { pyo3_path, .. } = ctx;
208253

254+
let unpack = self
255+
.receiver
256+
.as_ref()
257+
.map(|i| {
258+
let pattern = self.match_pattern();
259+
quote! { let #pattern = #i;}
260+
})
261+
.unwrap_or_default();
262+
209263
let setter = fields
210264
.iter()
211-
.map(|f| {
265+
.enumerate()
266+
.map(|(i, f)| {
212267
let key = f.ident.unraw().to_string();
213-
let ident = f.ident;
268+
let value = format_ident!("arg{i}");
214269
quote! {
215-
dict.set_item(#key, self.#ident)?;
270+
#pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?;
216271
}
217272
})
218273
.collect::<TokenStream>();
@@ -222,26 +277,35 @@ impl<'a> Container<'a> {
222277
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
223278
error: quote!(#pyo3_path::PyErr),
224279
body: quote! {
280+
#unpack
225281
let dict = #pyo3_path::types::PyDict::new(py);
226282
#setter
227-
Ok(dict)
283+
::std::result::Result::Ok::<_, Self::Error>(dict)
228284
},
229285
}
230286
}
231287

232288
fn build_tuple_struct(&self, fields: &[TupleStructField], ctx: &Ctx) -> IntoPyObjectImpl {
233289
let Ctx { pyo3_path, .. } = ctx;
234290

291+
let unpack = self
292+
.receiver
293+
.as_ref()
294+
.map(|i| {
295+
let pattern = self.match_pattern();
296+
quote! { let #pattern = #i;}
297+
})
298+
.unwrap_or_default();
299+
235300
let setter = fields
236301
.iter()
237302
.enumerate()
238-
.map(|(index, _)| {
239-
let i = Index {
240-
index: index as u32,
241-
span: Span::call_site(),
242-
};
303+
.map(|(i, _)| {
304+
let value = format_ident!("arg{i}");
243305
quote! {
244-
#pyo3_path::conversion::IntoPyObject::into_pyobject(self.#i, py)?,
306+
#pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py)
307+
.map(#pyo3_path::BoundObject::into_any)
308+
.map(#pyo3_path::BoundObject::into_bound)?,
245309
}
246310
})
247311
.collect::<TokenStream>();
@@ -251,7 +315,75 @@ impl<'a> Container<'a> {
251315
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
252316
error: quote!(#pyo3_path::PyErr),
253317
body: quote! {
254-
Ok(#pyo3_path::types::PyTuple::new(py, [#setter]))
318+
#unpack
319+
::std::result::Result::Ok::<_, Self::Error>(#pyo3_path::types::PyTuple::new(py, [#setter]))
320+
},
321+
}
322+
}
323+
}
324+
325+
/// Describes derivation input of an enum.
326+
struct Enum<'a> {
327+
variants: Vec<Container<'a>>,
328+
}
329+
330+
impl<'a> Enum<'a> {
331+
/// Construct a new enum representation.
332+
///
333+
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
334+
/// `Identifier` of the enum.
335+
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
336+
ensure_spanned!(
337+
!data_enum.variants.is_empty(),
338+
ident.span() => "cannot derive `IntoPyObject` for empty enum"
339+
);
340+
let variants = data_enum
341+
.variants
342+
.iter()
343+
.map(|variant| {
344+
let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
345+
let var_ident = &variant.ident;
346+
Container::new(
347+
None,
348+
&variant.fields,
349+
parse_quote!(#ident::#var_ident),
350+
attrs,
351+
)
352+
})
353+
.collect::<Result<Vec<_>>>()?;
354+
355+
Ok(Enum { variants })
356+
}
357+
358+
/// Build derivation body for enums.
359+
fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
360+
let Ctx { pyo3_path, .. } = ctx;
361+
362+
let variants = self
363+
.variants
364+
.iter()
365+
.map(|v| {
366+
let IntoPyObjectImpl { body, .. } = v.build(ctx);
367+
let pattern = v.match_pattern();
368+
quote! {
369+
#pattern => {
370+
{#body}
371+
.map(#pyo3_path::BoundObject::into_any)
372+
.map(#pyo3_path::BoundObject::into_bound)
373+
.map_err(::std::convert::Into::<PyErr>::into)
374+
}
375+
}
376+
})
377+
.collect::<TokenStream>();
378+
379+
IntoPyObjectImpl {
380+
target: quote!(#pyo3_path::types::PyAny),
381+
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
382+
error: quote!(#pyo3_path::PyErr),
383+
body: quote! {
384+
match self {
385+
#variants
386+
}
255387
},
256388
}
257389
}
@@ -291,20 +423,24 @@ pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
291423
body,
292424
} = match &tokens.data {
293425
syn::Data::Enum(en) => {
294-
// if options.transparent || options.annotation.is_some() {
295-
// bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
296-
// at top level for enums");
297-
// }
298-
// let en = Enum::new(en, &tokens.ident)?;
299-
// en.build(ctx)
300-
todo!()
426+
if options.transparent.is_some() {
427+
bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums");
428+
}
429+
let en = Enum::new(en, &tokens.ident)?;
430+
en.build(ctx)
301431
}
302432
syn::Data::Struct(st) => {
303-
let st = Container::new(&st.fields, options)?;
433+
let ident = &tokens.ident;
434+
let st = Container::new(
435+
Some(Ident::new("self", Span::call_site())),
436+
&st.fields,
437+
parse_quote!(#ident),
438+
options,
439+
)?;
304440
st.build(ctx)
305441
}
306442
syn::Data::Union(_) => bail_spanned!(
307-
tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
443+
tokens.span() => "#[derive(`IntoPyObject`)] is not supported for unions"
308444
),
309445
};
310446

0 commit comments

Comments
 (0)