diff --git a/src/dataclass_binder/_impl.py b/src/dataclass_binder/_impl.py index e89b101..8c94883 100644 --- a/src/dataclass_binder/_impl.py +++ b/src/dataclass_binder/_impl.py @@ -19,6 +19,7 @@ ) from dataclasses import MISSING, Field, asdict, dataclass, fields, is_dataclass, replace from datetime import date, datetime, time, timedelta +from enum import Enum from functools import reduce from importlib import import_module from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract @@ -28,10 +29,21 @@ from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload from weakref import WeakKeyDictionary -if sys.version_info < (3, 11): - import tomli as tomllib # pragma: no cover -else: - import tomllib # pragma: no cover +if sys.version_info < (3, 11): # pragma: no cover + import tomli as tomllib + + if TYPE_CHECKING: + + class ReprEnum(Enum): + ... + + else: + from enum import IntEnum, IntFlag + + ReprEnum = IntEnum | IntFlag +else: # pragma: no cover + import tomllib # noqa: I001 + from enum import ReprEnum def _collect_type(field_type: type, context: str) -> type | Binder[Any]: @@ -49,7 +61,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]: return object elif not isinstance(field_type, type): raise TypeError(f"Annotation for field '{context}' is not a type") - elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path): + elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path | Enum): return field_type elif field_type is type: # https://github.com/python/mypy/issues/13026 @@ -209,7 +221,6 @@ def _check_field(field: Field, field_type: type, context: str) -> None: @dataclass(slots=True) class _ClassInfo(Generic[T]): - _cache: ClassVar[MutableMapping[type[Any], _ClassInfo[Any]]] = WeakKeyDictionary() dataclass: type[T] @@ -314,6 +325,24 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) -> if not isinstance(value, str): raise TypeError(f"Expected TOML string for path '{context}', got '{type(value).__name__}'") return field_type(value) + elif issubclass(field_type, ReprEnum): + if issubclass(field_type, int) and not isinstance(value, int): + raise TypeError(f"Value for '{context}': '{value}' is not of type int") + if issubclass(field_type, str) and not isinstance(value, str): + raise TypeError(f"Value for '{context}': '{value}' is not of type str") + return field_type(value) + elif issubclass(field_type, Enum): + if not isinstance(value, str): + raise TypeError( + f"Value for '{context}': '{value}' is not a valid key for enum '{field_type}', " + f"must be of type str" + ) + for enum_value in field_type: + if enum_value.name.lower() == value.lower(): + return enum_value + raise TypeError( + f"Value for '{context}': '{value}' is not a valid key for enum '{field_type}', could not be found" + ) elif isinstance(value, field_type) and ( type(value) is not bool or field_type is bool or field_type is object ): @@ -668,6 +697,13 @@ def format_toml_pair(key: str, value: object) -> str: def _to_toml_pair(value: object) -> tuple[str | None, Any]: """Return a TOML-compatible suffix and value pair with the data from the given rich value object.""" match value: + # enums have to be checked before basic types because for instance + # IntEnum is also of type int + case Enum(): + if isinstance(value, ReprEnum): + return None, value.value + else: + return None, value.name.lower() case str() | int() | float() | date() | time() | Path(): # note: 'bool' is a subclass of 'int' return None, value case timedelta(): diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 48d39c5..2057276 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from datetime import date, datetime, time, timedelta +from enum import Enum, IntEnum, auto from io import BytesIO from pathlib import Path from types import ModuleType, NoneType, UnionType @@ -859,3 +860,35 @@ def test_format_template_no_module(sourceless_class: type[Any]) -> None: value = 0 """.strip() ) + + +class Verbosity(Enum): + QUIET = auto() + NORMAL = auto() + DETAILED = auto() + + +class IntVerbosity(IntEnum): + QUIET = 0 + NORMAL = 1 + DETAILED = 2 + + +def test_format_with_enums() -> None: + @dataclass + class Log: + message: str + verbosity: Verbosity + verbosity_level: IntVerbosity + + log = Log("Hello, World", Verbosity.DETAILED, IntVerbosity.DETAILED) + + template = "\n".join(Binder(log).format_toml()) + + assert template == ( + """ +message = 'Hello, World' +verbosity = 'detailed' +verbosity-level = 2 +""".strip() + ) diff --git a/tests/test_parsing.py b/tests/test_parsing.py index 79a4a93..28e4495 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from dataclasses import FrozenInstanceError, dataclass, field from datetime import date, datetime, time, timedelta +from enum import Enum, IntEnum from io import BytesIO from pathlib import Path from types import ModuleType @@ -1084,3 +1085,144 @@ def test_bind_merge() -> None: assert merged_config.flag is True assert merged_config.nested1.value == "sun" assert merged_config.nested2.value == "cheese" + + +class Color(Enum): + RED = "#FF0000" + GREEN = "#00FF00" + BLUE = "#0000FF" + + +class Number(IntEnum): + ONE = 1 + TWO = 2 + THREE = 3 + + +class Weekday(Enum): + MONDAY = 0 + TUESDAY = 1 + WEDNESDAY = 2 + THURSDAY = 3 + FRIDAY = 4 + SATURDAY = 5 + SUNDAY = 6 + + +@dataclass +class EnumEntry: + name: str + color: Color + number: Number + + +def test_enums() -> None: + @dataclass + class Config: + best_colors: list[Color] + best_numbers: list[Number] + entries: list[EnumEntry] + + with stream_text( + """ + best-colors = ["red", "green", "blue"] + best-numbers = [1, 2, 3] + + [[entries]] + name = "Entry 1" + color = "blue" + number = 2 + + [[entries]] + name = "Entry 2" + color = "red" + number = 1 + """ + ) as stream: + config = Binder(Config).parse_toml(stream) + + assert len(config.best_colors) == 3 + assert len(config.best_numbers) == 3 + assert config.best_colors.index(Color.RED) == 0 + assert config.best_colors.index(Color.GREEN) == 1 + assert config.best_colors.index(Color.BLUE) == 2 + assert all(num in config.best_numbers for num in Number) + assert len(config.entries) == 2 + assert config.entries[0].color is Color.BLUE + assert config.entries[0].number is Number.TWO + assert config.entries[1].color is Color.RED + assert config.entries[1].number is Number.ONE + + +def test_enum_with_invalid_value() -> None: + @dataclass + class UserFavorites: + favorite_number: Number + favorite_color: Color + + with stream_text( + """ + favorite-number = "one" + favorite-color = "red" + """ + ) as stream, pytest.raises(TypeError): + Binder(UserFavorites).parse_toml(stream) + + +def test_enum_keys_being_case_insensitive() -> None: + @dataclass + class Theme: + primary: Color + secondary: Color + accent: Color + + with stream_text( + """ + primary = "RED" + secondary = "green" + accent = "blUE" + """ + ) as stream: + theme = Binder(Theme).parse_toml(stream) + + assert theme.primary is Color.RED + assert theme.secondary is Color.GREEN + assert theme.accent is Color.BLUE + + +def test_key_based_enum_while_using_value_ident() -> None: + @dataclass + class UserColorPreference: + primary: Color + secondary: Color + + with stream_text( + """ + primary = "#FF0000" + seconadry = "blue" + """ + ) as stream, pytest.raises(TypeError): + Binder(UserColorPreference).parse_toml(stream) + + +def test_enum_parsing_with_invalid_key_type() -> None: + @dataclass + class UserPrefs: + name: str + start_of_the_week: Weekday + + with stream_text( + """ + name = "Peter Testuser" + start-of-the-week = "sunday" + """ + ) as stream: + Binder(UserPrefs).parse_toml(stream) + + with stream_text( + """ + name = "Peter Testuser" + start-of-the-week = 1 + """ + ) as stream, pytest.raises(TypeError): + Binder(UserPrefs).parse_toml(stream)