diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 9bb94e00d6f..2dd4cbfab2e 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -1233,9 +1233,16 @@ fn impl_complex_enum_struct_variant_cls( complex_enum_variant_field_getter(&variant_cls_type, field_name, field.span, ctx)?; let field_getter_impl = quote! { - fn #field_name(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult<#field_type> { + fn #field_name(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult<#pyo3_path::PyObject> { + #[allow(unused_imports)] + use #pyo3_path::impl_::pyclass::Probe; + let py = slf.py(); match &*slf.into_super() { - #enum_name::#variant_ident { #field_name, .. } => ::std::result::Result::Ok(::std::clone::Clone::clone(&#field_name)), + #enum_name::#variant_ident { #field_name, .. } => + #pyo3_path::impl_::pyclass::ConvertField::< + { #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE }, + { #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE }, + >::convert_field::<#field_type>(#field_name, py), _ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"), } } @@ -1302,9 +1309,16 @@ fn impl_complex_enum_tuple_variant_field_getters( }) .collect(); let field_getter_impl: syn::ImplItemFn = parse_quote! { - fn #field_name(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult<#field_type> { + fn #field_name(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult<#pyo3_path::PyObject> { + #[allow(unused_imports)] + use #pyo3_path::impl_::pyclass::Probe; + let py = slf.py(); match &*slf.into_super() { - #enum_name::#variant_ident ( #(#field_access_tokens), *) => ::std::result::Result::Ok(::std::clone::Clone::clone(&val)), + #enum_name::#variant_ident ( #(#field_access_tokens), *) => + #pyo3_path::impl_::pyclass::ConvertField::< + { #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE }, + { #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE }, + >::convert_field::<#field_type>(val, py), _ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"), } } diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index c947df6e432..8e7e8cf844f 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -1521,6 +1521,38 @@ fn pyo3_get_value< Ok((unsafe { &*value }).clone().into_py(py).into_ptr()) } +pub struct ConvertField< + const IMPLEMENTS_INTOPYOBJECT_REF: bool, + const IMPLEMENTS_INTOPYOBJECT: bool, +>; + +impl ConvertField { + #[inline] + pub fn convert_field<'a, 'py, T>(obj: &'a T, py: Python<'py>) -> PyResult> + where + &'a T: IntoPyObject<'py>, + { + obj.into_pyobject(py) + .map(BoundObject::into_any) + .map(BoundObject::unbind) + .map_err(Into::into) + } +} + +impl ConvertField { + #[inline] + pub fn convert_field<'py, T>(obj: &T, py: Python<'py>) -> PyResult> + where + T: PyO3GetField<'py>, + { + obj.clone() + .into_pyobject(py) + .map(BoundObject::into_any) + .map(BoundObject::unbind) + .map_err(Into::into) + } +} + /// Marker trait whether a class implemented a custom comparison. Used to /// silence deprecation of autogenerated `__richcmp__` for enums. pub trait HasCustomRichCmp {} diff --git a/tests/test_enum.rs b/tests/test_enum.rs index c0a8f8b1e35..40c5f4681a8 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -202,9 +202,10 @@ fn test_renaming_all_enum_variants() { } #[pyclass(module = "custom_module")] -#[derive(Debug, Clone)] +#[derive(Debug)] enum CustomModuleComplexEnum { Variant(), + Py(Py), } #[test]