diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 104f32206d..90e442acca 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1569,6 +1569,59 @@ inline str enum_name(handle arg) { return "???"; } +class enum_meta_info { +public: + static pybind11::object enum_meta_cls() { + return get().enum_meta_cls_; + } + + static pybind11::object enum_base_cls() { + return get().enum_base_cls_; + } + +private: + template + friend T& pybind11::get_or_create_shared_data(const std::string&); + + static const enum_meta_info& get() { + return pybind11::get_or_create_shared_data( + "_pybind11_enum_meta_info"); + } + + enum_meta_info() { + handle copy = pybind11::module::import("copy").attr("copy"); + locals_ = copy(pybind11::globals()); + locals_["pybind11_meta_cls"] = reinterpret_borrow( + reinterpret_cast(get_internals().default_metaclass)); + locals_["pybind11_base_cls"] = reinterpret_borrow( + get_internals().instance_base); + // TODO: Make the base class work. + const char code[] = R"""( +pybind11_enum_base_cls = None + +class pybind11_enum_meta_cls(pybind11_meta_cls): + is_pybind11_enum = True + + def __iter__(cls): + return iter(cls.__members__.values()) + + def __len__(cls): + return len(cls.__members__) +)"""; + PyObject *result = PyRun_String( + code, Py_file_input, locals_.ptr(), locals_.ptr()); + if (result == nullptr) { + throw error_already_set(); + } + enum_meta_cls_ = locals_["pybind11_enum_meta_cls"]; + enum_base_cls_ = locals_["pybind11_enum_base_cls"]; + } + + pybind11::object enum_meta_cls_; + pybind11::object enum_base_cls_; + pybind11::dict locals_; +}; + struct enum_base { enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { } @@ -1725,12 +1778,19 @@ template class enum_ : public class_ { template enum_(const handle &scope, const char *name, const Extra&... extra) - : class_(scope, name, extra...), m_base(*this, scope) { + : class_( + scope, name, + // Can't re-declare base type??? + // detail::enum_meta_info::enum_base_cls(), + pybind11::metaclass(detail::enum_meta_info::enum_meta_cls()), + extra...), + m_base(*this, scope) { constexpr bool is_arithmetic = detail::any_of...>::value; constexpr bool is_convertible = std::is_convertible::value; m_base.init(is_arithmetic, is_convertible); def(init([](Scalar i) { return static_cast(i); }), arg("value")); + def_property_readonly("value", [](Type value) { return (Scalar) value; }); def("__int__", [](Type value) { return (Scalar) value; }); #if PY_MAJOR_VERSION < 3 def("__long__", [](Type value) { return (Scalar) value; }); diff --git a/tests/test_enum.py b/tests/test_enum.py index f3cce8bce5..080440b5bf 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -1,7 +1,44 @@ # -*- coding: utf-8 -*- import pytest + +import env # noqa: F401 + from pybind11_tests import enums as m +if env.PY2: + enum = None +else: + import enum + + +def is_enum(cls): + """Example showing how to recognize a class as pybind11 enum or a + PEP 345 enum.""" + if enum is not None: + if issubclass(cls, enum.Enum): + return True + return getattr(cls, "is_pybind11_enum", False) + + +def test_pep435(): + # See #2332. + cls = m.UnscopedEnum + names = ("EOne", "ETwo", "EThree") + values = (cls.EOne, cls.ETwo, cls.EThree) + raw_values = (1, 2, 3) + + assert len(cls) == len(names) + assert list(cls) == list(values) + assert is_enum(cls) + if enum: + assert not issubclass(cls, enum.Enum) + for name, value, raw_value in zip(names, values, raw_values): + assert isinstance(value, cls) + if enum: + assert not isinstance(value, enum.Enum) + assert value.name == name + assert value.value == raw_value + def test_unscoped_enum(): assert str(m.UnscopedEnum.EOne) == "UnscopedEnum.EOne" @@ -13,15 +50,24 @@ def test_unscoped_enum(): # name property assert m.UnscopedEnum.EOne.name == "EOne" + assert m.UnscopedEnum.EOne.value == 1 assert m.UnscopedEnum.ETwo.name == "ETwo" - assert m.EOne.name == "EOne" - # name readonly + assert m.UnscopedEnum.ETwo.value == 2 + assert m.EOne is m.UnscopedEnum.EOne + # name, value readonly with pytest.raises(AttributeError): m.UnscopedEnum.EOne.name = "" - # name returns a copy - foo = m.UnscopedEnum.EOne.name - foo = "bar" + with pytest.raises(AttributeError): + m.UnscopedEnum.EOne.value = 10 + # name, value returns a copy + # TODO: Neither the name nor value tests actually check against aliasing. + # Use a mutable type that has reference semantics. + nonaliased_name = m.UnscopedEnum.EOne.name + nonaliased_name = "bar" # noqa: F841 assert m.UnscopedEnum.EOne.name == "EOne" + nonaliased_value = m.UnscopedEnum.EOne.value + nonaliased_value = 10 # noqa: F841 + assert m.UnscopedEnum.EOne.value == 1 # __members__ property assert m.UnscopedEnum.__members__ == { @@ -33,8 +79,8 @@ def test_unscoped_enum(): with pytest.raises(AttributeError): m.UnscopedEnum.__members__ = {} # __members__ returns a copy - foo = m.UnscopedEnum.__members__ - foo["bar"] = "baz" + nonaliased_members = m.UnscopedEnum.__members__ + nonaliased_members["bar"] = "baz" assert m.UnscopedEnum.__members__ == { "EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo,