Skip to content

Commit

Permalink
rework complex enum field conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Nov 10, 2024
1 parent 3a6296e commit 7c62f7d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
22 changes: 18 additions & 4 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1233,9 +1233,16 @@ fn impl_complex_enum_struct_variant_cls(
complex_enum_variant_field_getter(&variant_cls_type, field_name, field.span, ctx)?;

let field_getter_impl = quote! {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
#[allow(unused_imports)]
use #pyo3_path::impl_::pyclass::Probe;
let py = slf.py();
match &*slf.into_super() {
#enum_name::#variant_ident { #field_name, .. } => ::std::result::Result::Ok(::std::clone::Clone::clone(&#field_name)),
#enum_name::#variant_ident { #field_name, .. } =>
#pyo3_path::impl_::pyclass::ConvertField::<
{ #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
{ #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
>::convert_field::<#field_type>(#field_name, py),
_ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
}
}
Expand Down Expand Up @@ -1302,9 +1309,16 @@ fn impl_complex_enum_tuple_variant_field_getters(
})
.collect();
let field_getter_impl: syn::ImplItemFn = parse_quote! {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
#[allow(unused_imports)]
use #pyo3_path::impl_::pyclass::Probe;
let py = slf.py();
match &*slf.into_super() {
#enum_name::#variant_ident ( #(#field_access_tokens), *) => ::std::result::Result::Ok(::std::clone::Clone::clone(&val)),
#enum_name::#variant_ident ( #(#field_access_tokens), *) =>
#pyo3_path::impl_::pyclass::ConvertField::<
{ #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
{ #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
>::convert_field::<#field_type>(val, py),
_ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
}
}
Expand Down
32 changes: 32 additions & 0 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,38 @@ fn pyo3_get_value<
Ok((unsafe { &*value }).clone().into_py(py).into_ptr())
}

pub struct ConvertField<
const IMPLEMENTS_INTOPYOBJECT_REF: bool,
const IMPLEMENTS_INTOPYOBJECT: bool,
>;

impl<const IMPLEMENTS_INTOPYOBJECT: bool> ConvertField<true, IMPLEMENTS_INTOPYOBJECT> {
#[inline]
pub fn convert_field<'a, 'py, T>(obj: &'a T, py: Python<'py>) -> PyResult<Py<PyAny>>
where
&'a T: IntoPyObject<'py>,
{
obj.into_pyobject(py)
.map(BoundObject::into_any)
.map(BoundObject::unbind)
.map_err(Into::into)
}
}

impl<const IMPLEMENTS_INTOPYOBJECT: bool> ConvertField<false, IMPLEMENTS_INTOPYOBJECT> {
#[inline]
pub fn convert_field<'py, T>(obj: &T, py: Python<'py>) -> PyResult<Py<PyAny>>
where
T: PyO3GetField<'py>,
{
obj.clone()
.into_pyobject(py)
.map(BoundObject::into_any)
.map(BoundObject::unbind)
.map_err(Into::into)
}
}

/// Marker trait whether a class implemented a custom comparison. Used to
/// silence deprecation of autogenerated `__richcmp__` for enums.
pub trait HasCustomRichCmp {}
Expand Down
3 changes: 2 additions & 1 deletion tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,10 @@ fn test_renaming_all_enum_variants() {
}

#[pyclass(module = "custom_module")]
#[derive(Debug, Clone)]
#[derive(Debug)]
enum CustomModuleComplexEnum {
Variant(),
Py(Py<PyAny>),
}

#[test]
Expand Down

0 comments on commit 7c62f7d

Please sign in to comment.