diff --git a/pytest_factoryboy/fixture.py b/pytest_factoryboy/fixture.py index 782806e..6c7263a 100644 --- a/pytest_factoryboy/fixture.py +++ b/pytest_factoryboy/fixture.py @@ -210,6 +210,51 @@ def make_declaration_fixturedef( related: list[str], ) -> Callable[..., Any]: """Create the FixtureDef for a factory declaration.""" + if isinstance(value, factory.Maybe): + if value.FACTORY_BUILDER_PHASE != factory.enums.BuilderPhase.ATTRIBUTE_RESOLUTION: + raise NotImplementedError("Maybe declarations are not supported with post-generation declarations.") + + if not isinstance(value.decider, factory.SelfAttribute): + raise NotImplementedError("Maybe declarations are only supported with SelfAttribute deciders.") + + if value.decider.depth != 0: + raise NotImplementedError("Maybe declarations are only supported with SelfAttributes of depth 0.") + + yes_declaration = value.yes + if yes_declaration != factory.declarations.SKIP: + yes_declaration = make_declaration_fixturedef( + attr_name + "__yes_declaration", + value=value.yes_declaration, + factory_class=factory_class, + ) + no_declaration = value.no + if no_declaration != factory.declarations.SKIP: + no_declaration = make_declaration_fixturedef( + attr_name + "__no_declaration", + value=value.no_declaration, + factory_class=factory_class, + ) + + new_maybe = factory.Maybe( + decider=value.decider, + yes_declaration=yes_declaration, + no_declaration=no_declaration, + ) + + # we want to generate a fixture like + # @pytest.fixture + # def user__company__decider(user__is_staff): + # return user__is_staff + # @pytest.fixture + # def user__company__yes_declaration(user__company__decider): + # return ... + # @pytest.fixture + # def user__company__no_declaration(user__company__decider): + # return None + # @pytest.fixture + # def user__company(request, user__company__decider, user__company__yes_declaration, user__company__no_declaration): + # declaration = yes_declaration if user__company__decider else no_declaration + # return evaluate(request, request.getfixturevalue(declaration)) if isinstance(value, (factory.SubFactory, factory.RelatedFactory)): subfactory_class = value.get_factory() subfactory_deps = get_deps(subfactory_class, factory_class) diff --git a/tests/maybe/__init__.py b/tests/maybe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/maybe/test_subfactory.py b/tests/maybe/test_subfactory.py new file mode 100644 index 0000000..bfd2cff --- /dev/null +++ b/tests/maybe/test_subfactory.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from dataclasses import * + +import pytest +from factory import * + +from pytest_factoryboy import register + + +@dataclass +class Company: + name: str + + +@dataclass +class User: + is_staff: bool + company: Company | None + + +@register +class CompanyFactory(Factory): + class Meta: + model = Company + + name = "foo" + + +@register +class UserFactory(Factory): + class Meta: + model = User + + is_staff = False + company = Maybe("is_staff", yes_declaration=None, no_declaration=SubFactory(CompanyFactory)) + + +@pytest.mark.parametrize("user__is_staff", [False]) +def test_staff_user_has_no_company_by_default(user): + assert user.company is None + + +@pytest.mark.parametrize("user__is_staff", [False]) +def test_normal_user_has_company_by_default(user, company): + assert user.company is company