Skip to content

Commit

Permalink
fix: temporary fix to ensure compat with litestar filters (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin authored Sep 18, 2023
1 parent a520293 commit 1753709
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 56 deletions.
72 changes: 48 additions & 24 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT

# pyright: reportMissingImports=false
try:
from litestar.repository.filters import BeforeAfter as BeforeAfterLitestar
from litestar.repository.filters import CollectionFilter as CollectionFilterLitestar
from litestar.repository.filters import FilterTypes as FilterTypesLitestar
from litestar.repository.filters import LimitOffset as LimitOffsetLitestar
from litestar.repository.filters import NotInCollectionFilter as NotInCollectionFilterLitestar
from litestar.repository.filters import NotInSearchFilter as NotInSearchFilterLitestar
from litestar.repository.filters import OnBeforeAfter as OnBeforeAfterLitestar
from litestar.repository.filters import OrderBy as OrderByLitestar
from litestar.repository.filters import SearchFilter as SearchFilterLitestar
except ImportError:
from advanced_alchemy.filters import BeforeAfter as BeforeAfterLitestar
from advanced_alchemy.filters import CollectionFilter as CollectionFilterLitestar
from advanced_alchemy.filters import FilterTypes as FilterTypesLitestar
from advanced_alchemy.filters import LimitOffset as LimitOffsetLitestar
from advanced_alchemy.filters import NotInCollectionFilter as NotInCollectionFilterLitestar
from advanced_alchemy.filters import NotInSearchFilter as NotInSearchFilterLitestar
from advanced_alchemy.filters import OnBeforeAfter as OnBeforeAfterLitestar
from advanced_alchemy.filters import OrderBy as OrderByLitestar
from advanced_alchemy.filters import SearchFilter as SearchFilterLitestar


if TYPE_CHECKING:
from collections import abc
from datetime import datetime
Expand Down Expand Up @@ -299,7 +322,11 @@ async def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

async def exists(self, *filters: FilterTypes | ColumnElement[bool], **kwargs: Any) -> bool:
async def exists(
self,
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
Expand Down Expand Up @@ -541,7 +568,7 @@ async def get_or_upsert(

async def count(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
) -> int:
Expand Down Expand Up @@ -675,7 +702,7 @@ def _get_update_many_statement(model_type: type[ModelT], supports_returning: boo

async def list_and_count(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
force_basic_query_mode: bool | None = None,
Expand Down Expand Up @@ -729,7 +756,7 @@ async def _refresh(

async def _list_and_count_window(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -765,7 +792,7 @@ async def _list_and_count_window(

async def _list_and_count_basic(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -851,17 +878,14 @@ async def upsert(
return instance

def _supports_merge_operations(self, force_disable_merge: bool = False) -> bool:
return bool(
return (
(
(
self._dialect.server_version_info is not None
and self._dialect.server_version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
and self._dialect.name == "postgresql"
)
or self._dialect.name == "oracle"
self._dialect.server_version_info is not None
and self._dialect.server_version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
and self._dialect.name == "postgresql"
)
and not force_disable_merge,
)
or self._dialect.name == "oracle"
) and not force_disable_merge

def _get_merge_stmt(
self,
Expand Down Expand Up @@ -931,7 +955,7 @@ async def upsert_many(

async def list(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -1034,7 +1058,7 @@ def _apply_limit_offset_pagination(

def _apply_filters(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
apply_pagination: bool = True,
statement: StatementLambdaElement,
) -> StatementLambdaElement:
Expand All @@ -1052,38 +1076,38 @@ def _apply_filters(
The select with filters applied.
"""
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if isinstance(filter_, (LimitOffset, LimitOffsetLitestar)):
if apply_pagination:
statement = self._apply_limit_offset_pagination(filter_.limit, filter_.offset, statement=statement)
elif isinstance(filter_, BeforeAfter):
elif isinstance(filter_, (BeforeAfter, BeforeAfterLitestar)):
statement = self._filter_on_datetime_field(
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
statement=statement,
)
elif isinstance(filter_, OnBeforeAfter):
elif isinstance(filter_, (OnBeforeAfter, OnBeforeAfterLitestar)):
statement = self._filter_on_datetime_field(
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
statement=statement,
)

elif isinstance(filter_, NotInCollectionFilter):
elif isinstance(filter_, (NotInCollectionFilter, NotInCollectionFilterLitestar)):
statement = self._filter_not_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, CollectionFilter):
elif isinstance(filter_, (CollectionFilter, CollectionFilterLitestar)):
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, OrderBy):
elif isinstance(filter_, (OrderBy, OrderByLitestar)):
statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc")
elif isinstance(filter_, SearchFilter):
elif isinstance(filter_, (SearchFilter, SearchFilterLitestar)):
statement = self._filter_by_like(
statement,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, NotInSearchFilter):
elif isinstance(filter_, (NotInSearchFilter, NotInSearchFilterLitestar)):
statement = self._filter_by_not_like(
statement,
filter_.field_name,
Expand Down
72 changes: 48 additions & 24 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT

# pyright: reportMissingImports=false
try:
from litestar.repository.filters import BeforeAfter as BeforeAfterLitestar
from litestar.repository.filters import CollectionFilter as CollectionFilterLitestar
from litestar.repository.filters import FilterTypes as FilterTypesLitestar
from litestar.repository.filters import LimitOffset as LimitOffsetLitestar
from litestar.repository.filters import NotInCollectionFilter as NotInCollectionFilterLitestar
from litestar.repository.filters import NotInSearchFilter as NotInSearchFilterLitestar
from litestar.repository.filters import OnBeforeAfter as OnBeforeAfterLitestar
from litestar.repository.filters import OrderBy as OrderByLitestar
from litestar.repository.filters import SearchFilter as SearchFilterLitestar
except ImportError:
from advanced_alchemy.filters import BeforeAfter as BeforeAfterLitestar
from advanced_alchemy.filters import CollectionFilter as CollectionFilterLitestar
from advanced_alchemy.filters import FilterTypes as FilterTypesLitestar
from advanced_alchemy.filters import LimitOffset as LimitOffsetLitestar
from advanced_alchemy.filters import NotInCollectionFilter as NotInCollectionFilterLitestar
from advanced_alchemy.filters import NotInSearchFilter as NotInSearchFilterLitestar
from advanced_alchemy.filters import OnBeforeAfter as OnBeforeAfterLitestar
from advanced_alchemy.filters import OrderBy as OrderByLitestar
from advanced_alchemy.filters import SearchFilter as SearchFilterLitestar


if TYPE_CHECKING:
from collections import abc
from datetime import datetime
Expand Down Expand Up @@ -300,7 +323,11 @@ def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

def exists(self, *filters: FilterTypes | ColumnElement[bool], **kwargs: Any) -> bool:
def exists(
self,
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
Expand Down Expand Up @@ -542,7 +569,7 @@ def get_or_upsert(

def count(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
) -> int:
Expand Down Expand Up @@ -676,7 +703,7 @@ def _get_update_many_statement(model_type: type[ModelT], supports_returning: boo

def list_and_count(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
force_basic_query_mode: bool | None = None,
Expand Down Expand Up @@ -730,7 +757,7 @@ def _refresh(

def _list_and_count_window(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -766,7 +793,7 @@ def _list_and_count_window(

def _list_and_count_basic(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -852,17 +879,14 @@ def upsert(
return instance

def _supports_merge_operations(self, force_disable_merge: bool = False) -> bool:
return bool(
return (
(
(
self._dialect.server_version_info is not None
and self._dialect.server_version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
and self._dialect.name == "postgresql"
)
or self._dialect.name == "oracle"
self._dialect.server_version_info is not None
and self._dialect.server_version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
and self._dialect.name == "postgresql"
)
and not force_disable_merge,
)
or self._dialect.name == "oracle"
) and not force_disable_merge

def _get_merge_stmt(
self,
Expand Down Expand Up @@ -932,7 +956,7 @@ def upsert_many(

def list(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
auto_expunge: bool | None = None,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -1035,7 +1059,7 @@ def _apply_limit_offset_pagination(

def _apply_filters(
self,
*filters: FilterTypes | ColumnElement[bool],
*filters: FilterTypes | FilterTypesLitestar | ColumnElement[bool],
apply_pagination: bool = True,
statement: StatementLambdaElement,
) -> StatementLambdaElement:
Expand All @@ -1053,38 +1077,38 @@ def _apply_filters(
The select with filters applied.
"""
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if isinstance(filter_, (LimitOffset, LimitOffsetLitestar)):
if apply_pagination:
statement = self._apply_limit_offset_pagination(filter_.limit, filter_.offset, statement=statement)
elif isinstance(filter_, BeforeAfter):
elif isinstance(filter_, (BeforeAfter, BeforeAfterLitestar)):
statement = self._filter_on_datetime_field(
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
statement=statement,
)
elif isinstance(filter_, OnBeforeAfter):
elif isinstance(filter_, (OnBeforeAfter, OnBeforeAfterLitestar)):
statement = self._filter_on_datetime_field(
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
statement=statement,
)

elif isinstance(filter_, NotInCollectionFilter):
elif isinstance(filter_, (NotInCollectionFilter, NotInCollectionFilterLitestar)):
statement = self._filter_not_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, CollectionFilter):
elif isinstance(filter_, (CollectionFilter, CollectionFilterLitestar)):
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, OrderBy):
elif isinstance(filter_, (OrderBy, OrderByLitestar)):
statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc")
elif isinstance(filter_, SearchFilter):
elif isinstance(filter_, (SearchFilter, SearchFilterLitestar)):
statement = self._filter_by_like(
statement,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, NotInSearchFilter):
elif isinstance(filter_, (NotInSearchFilter, NotInSearchFilterLitestar)):
statement = self._filter_by_not_like(
statement,
filter_.field_name,
Expand Down
10 changes: 5 additions & 5 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions sonar-project.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ sonar.sources=advanced_alchemy
sonar.tests=tests
sonar.coverage.exclusions=\
**/__init__.py
sonar.cpd.exclusions=\
advanced_alchemy/repository/_sync.py
4 changes: 2 additions & 2 deletions tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ async def test_repo_filter_on_before_after(author_repo: AuthorRepository) -> Non
on_or_after=None,
)
existing_obj = await maybe_async(
author_repo.list(*[before_filter, OrderBy(field_name="created_at", sort_order="desc")]), # type: ignore
author_repo.list(*[before_filter, OrderBy(field_name="created_at", sort_order="desc")]),
)
assert existing_obj[0].name == "Agatha Christie"

Expand All @@ -915,7 +915,7 @@ async def test_repo_filter_on_before_after(author_repo: AuthorRepository) -> Non
on_or_before=None,
)
existing_obj = await maybe_async(
author_repo.list(*[after_filter, OrderBy(field_name="created_at", sort_order="desc")]), # type: ignore
author_repo.list(*[after_filter, OrderBy(field_name="created_at", sort_order="desc")]),
)
assert existing_obj[0].name == "Agatha Christie"

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ async def test_sqlalchemy_repo_list_with_not_in_collection_filter(
async def test_sqlalchemy_repo_unknown_filter_type_raises(mock_repo: SQLAlchemyAsyncRepository) -> None:
"""Test that repo raises exception if list receives unknown filter type."""
with pytest.raises(RepositoryError):
await maybe_async(mock_repo.list("not a filter")) # type:ignore[arg-type]
await maybe_async(mock_repo.list("not a filter"))


async def test_sqlalchemy_repo_update(
Expand Down

0 comments on commit 1753709

Please sign in to comment.