Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(factories): added an ability to use the Coroutine as a factory field #641

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions docs/examples/fields/test_example_sqla_pre_fetched_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from sqlalchemy import ForeignKey, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

async_engine = create_async_engine("sqlite+aiosqlite:///:memory:")


class Base(DeclarativeBase): ...


class User(Base):
__tablename__ = "users"

id: Mapped[int] = mapped_column(primary_key=True)


class Department(Base):
__tablename__ = "departments"

id: Mapped[int] = mapped_column(primary_key=True)
director_id: Mapped[str] = mapped_column(ForeignKey("users.id"))


class UserFactory(SQLAlchemyFactory[User]): ...


class DepartmentFactory(SQLAlchemyFactory[Department]): ...


async def get_director_ids() -> int:
async with AsyncSession(async_engine) as session:
result = (await session.scalars(select(User.id))).all()
return UserFactory.__random__.choice(result)


async def test_factory_with_pre_fetched_async_data() -> None:
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)

async with AsyncSession(async_engine) as session:
UserFactory.__async_session__ = session
await UserFactory.create_batch_async(3)

async with AsyncSession(async_engine) as session:
DepartmentFactory.__async_session__ = session
department = await DepartmentFactory.create_async(director_id=await get_director_ids())
user = await session.scalar(select(User).where(User.id == department.director_id))
assert isinstance(user, User)
12 changes: 11 additions & 1 deletion docs/usage/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,20 @@ callable should be: ``name: str, values: dict[str, Any], *args, **defaults``. Th
name in the values dictionary.

Factories as Fields
---------------------------
-------------------

Factories themselves can be used as fields. In this usage, build parameters will be passed to the declared factory.

.. literalinclude:: /examples/fields/test_example_8.py
:caption: Using a factory as a field
:language: python

Handling Asynchronous Data in Factory Fields
--------------------------------------------

If you need to populate a factory field with data pre-fetched asynchronously (e.g., from a database using an ORM like SQLAlchemy or Beanie), the recommended approach is to handle the asynchronous call outside the factory and pass the resolved value as a regular argument.

.. literalinclude:: /examples/fields/test_example_sqla_pre_fetched_data.py
:caption: SQLAlchemy example
:language: python
:emphasize-lines: 34-37, 51
86 changes: 62 additions & 24 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,60 @@ def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
msg = f"unsupported model type {model.__name__}"
raise ParameterException(msg) # pragma: no cover

@classmethod
def _get_initial_variables(cls, kwargs: Any) -> tuple[dict[str, Any], dict[str, PostGenerated], BuildContext]:
"""Prepare the given kwargs and generate initial variables for further usage.

:param kwargs: Any build kwargs.

:returns: A tuple of build results.

"""
_build_context = cls._get_build_context(kwargs.pop("_build_context", None))
_build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

return result, generate_post, _build_context

@classmethod
def _check_special_field(
cls,
field_meta: FieldMeta,
result: dict[str, Any],
generate_post: dict[str, PostGenerated],
field_build_parameters: Any,
build_context: BuildContext,
) -> Any:
"""Check if a field value is a special type field or get a value defined on the factory class itself.

:param field_meta: FieldMeta instance.
:param result: A dict with result field values.
:param generate_post: A dict with post generating values.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
:param build_context: BuildContext data for current build.

:returns: None or a value defined on the factory class itself.
"""
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
return None

if isinstance(field_value, Require) and field_meta.name not in result:
msg = f"Require kwarg {field_meta.name} is missing"
raise MissingBuildKwargException(msg)

if isinstance(field_value, PostGenerated):
generate_post[field_meta.name] = field_value
return None

return cls._handle_factory_field(
field_value=field_value,
field_build_parameters=field_build_parameters,
build_context=build_context,
)

# Public Methods

@classmethod
Expand Down Expand Up @@ -974,33 +1028,21 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
:returns: A dictionary of build results.

"""
_build_context = cls._get_build_context(kwargs.pop("_build_context", None))
_build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}
result, generate_post, _build_context = cls._get_initial_variables(kwargs)

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
continue

if isinstance(field_value, Require) and field_meta.name not in kwargs:
msg = f"Require kwarg {field_meta.name} is missing"
raise MissingBuildKwargException(msg)

if isinstance(field_value, PostGenerated):
generate_post[field_meta.name] = field_value
continue

result[field_meta.name] = cls._handle_factory_field(
field_value=field_value,
field_value = cls._check_special_field(
field_meta=field_meta,
result=result,
generate_post=generate_post,
field_build_parameters=field_build_parameters,
build_context=_build_context,
)
if field_value is not None:
result[field_meta.name] = field_value
continue

field_result = cls.get_field_value(
Expand Down Expand Up @@ -1028,11 +1070,7 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
:returns: A dictionary of build results.

"""
_build_context = cls._get_build_context(kwargs.pop("_build_context", None))
_build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}
result, generate_post, _build_context = cls._get_initial_variables(kwargs)

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ pydantic = [
]
msgspec = ["msgspec",]
odmantic = ["odmantic<1.0.0", "pydantic[email]",]
beanie = ["beanie", "pydantic[email]",]
beanie = [
"beanie",
"pydantic[email]",
"pymongo<4.9",
]
attrs = ["attrs>=22.2.0",]
full = ["pydantic", "odmantic", "msgspec", "beanie", "attrs", "sqlalchemy"]

Expand Down
Empty file.
26 changes: 26 additions & 0 deletions tests/sqlalchemy_factory/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import AsyncIterator

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from tests.sqlalchemy_factory.models import Base


@pytest.fixture()
def engine() -> Engine:
return create_engine("sqlite:///:memory:")


@pytest.fixture()
def async_engine() -> AsyncEngine:
return create_async_engine("sqlite+aiosqlite:///:memory:")


@pytest.fixture(autouse=True)
async def fx_drop_create_meta(async_engine: AsyncEngine) -> AsyncIterator[None]:
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield
120 changes: 120 additions & 0 deletions tests/sqlalchemy_factory/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from dataclasses import dataclass
from typing import Any, Optional

from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
func,
orm,
text,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry

_registry = registry()


@dataclass
class NonSQLAchemyClass:
id: int


class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata


class Author(Base):
__tablename__ = "authors"

id: Any = Column(Integer(), primary_key=True)
books: Any = orm.relationship(
"Book",
collection_class=list,
uselist=True,
back_populates="author",
)


class Book(Base):
__tablename__ = "books"

id: Any = Column(Integer(), primary_key=True)
author_id: Any = Column(
Integer(),
ForeignKey(Author.id),
nullable=False,
)
author: Any = orm.relationship(
Author,
uselist=False,
back_populates="books",
)


class AsyncModel(Base):
__tablename__ = "async_model"

id: Any = Column(Integer(), primary_key=True)


class AsyncRefreshModel(Base):
__tablename__ = "server_default_test"

id: Any = Column(Integer(), primary_key=True)
test_datetime: Any = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
test_str: Any = Column(String, nullable=False, server_default=text("test_str"))
test_int: Any = Column(Integer, nullable=False, server_default=text("123"))
test_bool: Any = Column(Boolean, nullable=False, server_default=text("False"))


class User(Base):
__tablename__ = "users"

id = Column(Integer, primary_key=True)
name = Column(String)

user_keyword_associations = relationship(
"UserKeywordAssociation",
back_populates="user",
lazy="selectin", # codespell:ignore selectin
)
keywords = association_proxy(
"user_keyword_associations", "keyword", creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj)
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"

user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
keyword_id = Column(Integer, ForeignKey("keywords.id"), primary_key=True)

user = relationship(User, back_populates="user_keyword_associations")
keyword = relationship("Keyword", lazy="selectin") # codespell:ignore selectin

# for prevent mypy error: Unexpected keyword argument "keyword" for "UserKeywordAssociation" [call-arg]
def __init__(self, keyword: Optional["Keyword"] = None):
self.keyword = keyword


class Keyword(Base):
__tablename__ = "keywords"

id = Column(Integer, primary_key=True)
keyword = Column(String)


class Department(Base):
__tablename__ = "departments"

id = Column(Integer, primary_key=True)
director_id = Column(Integer, ForeignKey("users.id"))
Loading
Loading