Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cfg features for enum variants #4509

Merged
merged 8 commits into from
Sep 16, 2024
Merged
1 change: 1 addition & 0 deletions newsfragments/4509.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix compile failure when using `#[cfg]` attributes for simple enum variants.
48 changes: 41 additions & 7 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::attributes::{
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pyimpl::{gen_py_const, get_cfg_attributes, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
SlotDef, __GETITEM__, __HASH__, __INT__, __LEN__, __REPR__, __RICHCMP__, __STR__,
Expand Down Expand Up @@ -533,7 +533,12 @@ impl<'a> PyClassSimpleEnum<'a> {
_ => bail_spanned!(variant.span() => "Must be a unit variant."),
};
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok(PyClassEnumUnitVariant { ident, options })
let attrs = get_cfg_attributes(&variant.attrs);
Ok(PyClassEnumUnitVariant {
ident,
options,
attrs,
})
}

let ident = &enum_.ident;
Expand Down Expand Up @@ -693,6 +698,7 @@ impl<'a> EnumVariant for PyClassEnumVariant<'a> {
struct PyClassEnumUnitVariant<'a> {
ident: &'a syn::Ident,
options: EnumVariantPyO3Options,
attrs: Vec<&'a syn::Attribute>,
}

impl<'a> EnumVariant for PyClassEnumUnitVariant<'a> {
Expand Down Expand Up @@ -880,13 +886,14 @@ fn impl_simple_enum(
let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
let attrs = &variant.attrs;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!(
"{}.{}",
get_class_python_name(cls, args),
variant.get_python_name(args),
);
quote! { #cls::#variant_name => #repr, }
quote! { #(#attrs)* #cls::#variant_name => #repr, }
});
let mut repr_impl: syn::ImplItemFn = syn::parse_quote! {
fn __pyo3__repr__(&self) -> &'static str {
Expand All @@ -908,7 +915,8 @@ fn impl_simple_enum(
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
let variants_to_int = variants.iter().map(|variant| {
let variant_name = variant.ident;
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
let attrs = &variant.attrs;
quote! { #(#attrs)* #cls::#variant_name => #cls::#variant_name as #repr_type, }
});
let mut int_impl: syn::ImplItemFn = syn::parse_quote! {
fn __pyo3__int__(&self) -> #repr_type {
Expand Down Expand Up @@ -936,7 +944,9 @@ fn impl_simple_enum(
methods_type,
simple_enum_default_methods(
cls,
variants.iter().map(|v| (v.ident, v.get_python_name(args))),
variants
.iter()
.map(|v| (v.ident, v.get_python_name(args), &v.attrs)),
ctx,
),
default_slots,
Expand Down Expand Up @@ -1474,7 +1484,13 @@ fn generate_default_protocol_slot(

fn simple_enum_default_methods<'a>(
cls: &'a syn::Ident,
unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
unit_variant_names: impl IntoIterator<
Item = (
&'a syn::Ident,
Cow<'a, syn::Ident>,
&'a Vec<&'a syn::Attribute>,
),
>,
ctx: &Ctx,
) -> Vec<MethodAndMethodDef> {
let cls_type = syn::parse_quote!(#cls);
Expand All @@ -1490,7 +1506,25 @@ fn simple_enum_default_methods<'a>(
};
unit_variant_names
.into_iter()
.map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name), ctx))
.map(|(var, py_name, attrs)| {
let method = gen_py_const(&cls_type, &variant_to_attribute(var, &py_name), ctx);
let associated_method_tokens = method.associated_method;
let method_def_tokens = method.method_def;

let associated_method = quote! {
#(#attrs)*
#associated_method_tokens
};
let method_def = quote! {
#(#attrs)*
#method_def_tokens
};

MethodAndMethodDef {
associated_method,
method_def,
}
})
.collect()
}

Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ fn submit_methods_inventory(
}
}

fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
pub(crate) fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
Expand Down
24 changes: 24 additions & 0 deletions tests/test_field_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ struct CfgClass {
pub b: u32,
}

#[pyclass(eq, eq_int)]
#[derive(PartialEq)]
enum CfgSimpleEnum {
#[cfg(any())]
DisabledVariant,
#[cfg(not(any()))]
EnabledVariant,
}

#[test]
fn test_cfg() {
Python::with_gil(|py| {
Expand All @@ -27,3 +36,18 @@ fn test_cfg() {
assert_eq!(b, 3);
});
}

#[test]
fn test_cfg_simple_enum() {
Python::with_gil(|py| {
let simple = py.get_type::<CfgSimpleEnum>();
pyo3::py_run!(
py,
simple,
r#"
assert hasattr(simple, "EnabledVariant")
assert not hasattr(simple, "DisabledVariant")
"#
);
})
}
Loading