Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Enums #29 #44

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
7 changes: 6 additions & 1 deletion src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from dataclasses import MISSING, Field, asdict, dataclass, fields, is_dataclass, replace
from datetime import date, datetime, time, timedelta
from enum import Enum
from functools import reduce
from importlib import import_module
from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract
Expand Down Expand Up @@ -49,7 +50,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
return object
elif not isinstance(field_type, type):
raise TypeError(f"Annotation for field '{context}' is not a type")
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path):
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path | Enum):
return field_type
elif field_type is type:
# https://github.com/python/mypy/issues/13026
Expand Down Expand Up @@ -314,6 +315,8 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) ->
if not isinstance(value, str):
raise TypeError(f"Expected TOML string for path '{context}', got '{type(value).__name__}'")
return field_type(value)
elif issubclass(field_type, Enum):
return field_type(value)
elif isinstance(value, field_type) and (
type(value) is not bool or field_type is bool or field_type is object
):
Expand Down Expand Up @@ -697,6 +700,8 @@ def _to_toml_pair(value: object) -> tuple[str | None, Any]:
return "-weeks", days // 7
else:
return "-days", days
case Enum():
return None, value.value
case ModuleType():
return None, value.__name__
case Mapping():
Expand Down
27 changes: 27 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum
from io import BytesIO
from pathlib import Path
from types import ModuleType, NoneType, UnionType
Expand Down Expand Up @@ -859,3 +860,29 @@ def test_format_template_no_module(sourceless_class: type[Any]) -> None:
value = 0
""".strip()
)


class IssueStatus(Enum):
OPEN = "open"
REJECTED = "rejected"
COMPLETED = "completed"


def test_format_with_enums() -> None:
@dataclass
class Issue:
issue_id: int
title: str
status: IssueStatus

issue = Issue(1, "Test", IssueStatus.OPEN)

template = "\n".join(Binder(issue).format_toml())

assert template == (
"""
issue-id = 1
title = 'Test'
status = 'open'
""".strip()
)
58 changes: 58 additions & 0 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from dataclasses import FrozenInstanceError, dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -1084,3 +1085,60 @@ def test_bind_merge() -> None:
assert merged_config.flag is True
assert merged_config.nested1.value == "sun"
assert merged_config.nested2.value == "cheese"


class Color(Enum):
RED = "red"
BLUE = "blue"
GREEN = "green"


class Number(Enum):
ONE = 1
TWO = 2
THREE = 3


@dataclass
class EnumEntry:
name: str
color: Color
number: Number


def test_enums() -> None:
@dataclass
class Config:
best_colors: list[Color]
best_numbers: list[Number]
entries: list[EnumEntry]

with stream_text(
"""
best-colors = ["red", "green", "blue"]
best-numbers = [1, 2, 3]

[[entries]]
name = "Entry 1"
color = "blue"
number = 2

[[entries]]
name = "Entry 2"
color = "red"
number = 1
"""
) as stream:
config = Binder(Config).parse_toml(stream)

assert len(config.best_colors) == 3
assert len(config.best_numbers) == 3
assert config.best_colors.index(Color.RED) == 0
assert config.best_colors.index(Color.GREEN) == 1
assert config.best_colors.index(Color.BLUE) == 2
assert all(num in config.best_numbers for num in Number)
assert len(config.entries) == 2
assert config.entries[0].color == Color.BLUE
assert config.entries[0].number == Number.TWO
assert config.entries[1].color == Color.RED
assert config.entries[1].number == Number.ONE