Skip to content

Commit 7c62f7d

Browse files
committed
rework complex enum field conversion
1 parent 3a6296e commit 7c62f7d

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,9 +1233,16 @@ fn impl_complex_enum_struct_variant_cls(
12331233
complex_enum_variant_field_getter(&variant_cls_type, field_name, field.span, ctx)?;
12341234

12351235
let field_getter_impl = quote! {
1236-
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
1236+
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1237+
#[allow(unused_imports)]
1238+
use #pyo3_path::impl_::pyclass::Probe;
1239+
let py = slf.py();
12371240
match &*slf.into_super() {
1238-
#enum_name::#variant_ident { #field_name, .. } => ::std::result::Result::Ok(::std::clone::Clone::clone(&#field_name)),
1241+
#enum_name::#variant_ident { #field_name, .. } =>
1242+
#pyo3_path::impl_::pyclass::ConvertField::<
1243+
{ #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
1244+
{ #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
1245+
>::convert_field::<#field_type>(#field_name, py),
12391246
_ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
12401247
}
12411248
}
@@ -1302,9 +1309,16 @@ fn impl_complex_enum_tuple_variant_field_getters(
13021309
})
13031310
.collect();
13041311
let field_getter_impl: syn::ImplItemFn = parse_quote! {
1305-
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
1312+
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1313+
#[allow(unused_imports)]
1314+
use #pyo3_path::impl_::pyclass::Probe;
1315+
let py = slf.py();
13061316
match &*slf.into_super() {
1307-
#enum_name::#variant_ident ( #(#field_access_tokens), *) => ::std::result::Result::Ok(::std::clone::Clone::clone(&val)),
1317+
#enum_name::#variant_ident ( #(#field_access_tokens), *) =>
1318+
#pyo3_path::impl_::pyclass::ConvertField::<
1319+
{ #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
1320+
{ #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
1321+
>::convert_field::<#field_type>(val, py),
13081322
_ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
13091323
}
13101324
}

src/impl_/pyclass.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,38 @@ fn pyo3_get_value<
15211521
Ok((unsafe { &*value }).clone().into_py(py).into_ptr())
15221522
}
15231523

1524+
pub struct ConvertField<
1525+
const IMPLEMENTS_INTOPYOBJECT_REF: bool,
1526+
const IMPLEMENTS_INTOPYOBJECT: bool,
1527+
>;
1528+
1529+
impl<const IMPLEMENTS_INTOPYOBJECT: bool> ConvertField<true, IMPLEMENTS_INTOPYOBJECT> {
1530+
#[inline]
1531+
pub fn convert_field<'a, 'py, T>(obj: &'a T, py: Python<'py>) -> PyResult<Py<PyAny>>
1532+
where
1533+
&'a T: IntoPyObject<'py>,
1534+
{
1535+
obj.into_pyobject(py)
1536+
.map(BoundObject::into_any)
1537+
.map(BoundObject::unbind)
1538+
.map_err(Into::into)
1539+
}
1540+
}
1541+
1542+
impl<const IMPLEMENTS_INTOPYOBJECT: bool> ConvertField<false, IMPLEMENTS_INTOPYOBJECT> {
1543+
#[inline]
1544+
pub fn convert_field<'py, T>(obj: &T, py: Python<'py>) -> PyResult<Py<PyAny>>
1545+
where
1546+
T: PyO3GetField<'py>,
1547+
{
1548+
obj.clone()
1549+
.into_pyobject(py)
1550+
.map(BoundObject::into_any)
1551+
.map(BoundObject::unbind)
1552+
.map_err(Into::into)
1553+
}
1554+
}
1555+
15241556
/// Marker trait whether a class implemented a custom comparison. Used to
15251557
/// silence deprecation of autogenerated `__richcmp__` for enums.
15261558
pub trait HasCustomRichCmp {}

tests/test_enum.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,10 @@ fn test_renaming_all_enum_variants() {
202202
}
203203

204204
#[pyclass(module = "custom_module")]
205-
#[derive(Debug, Clone)]
205+
#[derive(Debug)]
206206
enum CustomModuleComplexEnum {
207207
Variant(),
208+
Py(Py<PyAny>),
208209
}
209210

210211
#[test]

0 commit comments

Comments
 (0)