diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index 27aea356..a768f5bb 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -26,7 +26,7 @@ from fastapi_jsonapi.data_typing import TypeModel, TypeSchema from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType from fastapi_jsonapi.exceptions.json_api import HTTPException -from fastapi_jsonapi.schema import get_model_field, get_relationships +from fastapi_jsonapi.schema import JSONAPISchemaIntrospectionError, get_model_field, get_relationships log = logging.getLogger(__name__) @@ -44,7 +44,7 @@ class RelationshipFilteringInfo(BaseModel): target_schema: Type[TypeSchema] model: Type[TypeModel] aliased_model: AliasedClass - column: InstrumentedAttribute + join_column: InstrumentedAttribute class Config: arbitrary_types_allowed = True @@ -288,7 +288,10 @@ def get_model_column( schema: Type[TypeSchema], field_name: str, ) -> InstrumentedAttribute: - model_field = get_model_field(schema, field_name) + try: + model_field = get_model_field(schema, field_name) + except JSONAPISchemaIntrospectionError as e: + raise InvalidFilters(str(e)) try: return getattr(model, model_field) @@ -327,8 +330,9 @@ def gather_relationships_info( model: Type[TypeModel], schema: Type[TypeSchema], relationship_path: List[str], - collected_info: dict, + collected_info: dict[RelationshipPath, RelationshipFilteringInfo], target_relationship_idx: int = 0, + prev_aliased_model: Optional[Any] = None, ) -> dict[RelationshipPath, RelationshipFilteringInfo]: is_last_relationship = target_relationship_idx == len(relationship_path) - 1 target_relationship_path = RELATIONSHIP_SPLITTER.join( @@ -342,25 +346,36 @@ def gather_relationships_info( target_schema = schema.__fields__[target_relationship_name].type_ target_model = getattr(model, target_relationship_name).property.mapper.class_ - target_column = get_model_column( - model, - schema, - target_relationship_name, - ) + + if prev_aliased_model: + join_column = get_model_column( + model=prev_aliased_model, + schema=schema, + field_name=target_relationship_name, + ) + else: + join_column = get_model_column( + model, + schema, + target_relationship_name, + ) + + aliased_model = aliased(target_model) collected_info[target_relationship_path] = RelationshipFilteringInfo( target_schema=target_schema, model=target_model, - aliased_model=aliased(target_model), - column=target_column, + aliased_model=aliased_model, + join_column=join_column, ) if not is_last_relationship: return gather_relationships_info( - target_model, - target_schema, - relationship_path, - collected_info, - target_relationship_idx + 1, + model=target_model, + schema=target_schema, + relationship_path=relationship_path, + collected_info=collected_info, + target_relationship_idx=target_relationship_idx + 1, + prev_aliased_model=aliased_model, ) return collected_info @@ -553,5 +568,5 @@ def create_filters_and_joins( target_schema=schema, relationships_info=relationships_info, ) - joins = [(info.aliased_model, info.column) for info in relationships_info.values()] + joins = [(info.aliased_model, info.join_column) for info in relationships_info.values()] return expressions, joins diff --git a/fastapi_jsonapi/schema.py b/fastapi_jsonapi/schema.py index ca1b62ee..d9bf31d6 100644 --- a/fastapi_jsonapi/schema.py +++ b/fastapi_jsonapi/schema.py @@ -122,6 +122,10 @@ class JSONAPIResultDetailSchema(BaseJSONAPIResultSchema): ] +class JSONAPISchemaIntrospectionError(Exception): + pass + + def get_model_field(schema: Type["TypeSchema"], field: str) -> str: """ Get the model field of a schema field. @@ -145,7 +149,7 @@ class ComputerSchema(pydantic_base): schema=schema.__name__, field=field, ) - raise Exception(msg) + raise JSONAPISchemaIntrospectionError(msg) return field diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index eadb8fe7..1bf87887 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -1,11 +1,13 @@ from pathlib import Path -from typing import Type +from typing import Optional, Type import pytest from fastapi import APIRouter, FastAPI +from pydantic import BaseModel from fastapi_jsonapi import RoutersJSONAPI, init from fastapi_jsonapi.atomic import AtomicOperations +from fastapi_jsonapi.data_typing import TypeModel from fastapi_jsonapi.views.detail_view import DetailViewBase from fastapi_jsonapi.views.list_view import ListViewBase from tests.fixtures.views import ( @@ -13,9 +15,13 @@ ListViewBaseGeneric, ) from tests.models import ( + Alpha, + Beta, Child, Computer, CustomUUIDItem, + Delta, + Gamma, Parent, ParentToChildAssociation, Post, @@ -25,6 +31,8 @@ UserBio, ) from tests.schemas import ( + AlphaSchema, + BetaSchema, ChildInSchema, ChildPatchSchema, ChildSchema, @@ -32,6 +40,8 @@ ComputerPatchSchema, ComputerSchema, CustomUUIDItemSchema, + DeltaSchema, + GammaSchema, ParentPatchSchema, ParentSchema, ParentToChildAssociationSchema, @@ -245,3 +255,74 @@ def build_app_custom( app.include_router(atomic.router, prefix="") init(app) return app + + +def build_alphabet_app() -> FastAPI: + return build_custom_app_by_schemas( + [ + ResourceInfoDTO( + path="/alpha", + resource_type="alpha", + model=Alpha, + schema_=AlphaSchema, + ), + ResourceInfoDTO( + path="/beta", + resource_type="beta", + model=Beta, + schema_=BetaSchema, + ), + ResourceInfoDTO( + path="/gamma", + resource_type="gamma", + model=Gamma, + schema_=GammaSchema, + ), + ResourceInfoDTO( + path="/delta", + resource_type="delta", + model=Delta, + schema_=DeltaSchema, + ), + ], + ) + + +class ResourceInfoDTO(BaseModel): + path: str + resource_type: str + model: Type[TypeModel] + schema_: Type[BaseModel] + schema_in_patch: Optional[BaseModel] = None + schema_in_post: Optional[BaseModel] = None + class_list: Type[ListViewBase] = ListViewBaseGeneric + class_detail: Type[DetailViewBase] = DetailViewBaseGeneric + + class Config: + arbitrary_types_allowed = True + + +def build_custom_app_by_schemas(resources_info: list[ResourceInfoDTO]): + router: APIRouter = APIRouter() + + for info in resources_info: + RoutersJSONAPI( + router=router, + path=info.path, + tags=["Misc"], + class_list=info.class_list, + class_detail=info.class_detail, + schema=info.schema_, + resource_type=info.resource_type, + schema_in_patch=info.schema_in_patch, + schema_in_post=info.schema_in_post, + model=info.model, + ) + + app = build_app_plain() + app.include_router(router, prefix="") + + atomic = AtomicOperations() + app.include_router(atomic.router, prefix="") + init(app) + return app diff --git a/tests/models.py b/tests/models.py index 5eaea38e..baf4cd84 100644 --- a/tests/models.py +++ b/tests/models.py @@ -312,3 +312,81 @@ class SelfRelationship(Base): class ContainsTimestamp(Base): id = Column(Integer, primary_key=True) timestamp = Column(DateTime(True), nullable=False) + + +class Alpha(Base): + __tablename__ = "alpha" + + id = Column(Integer, primary_key=True, autoincrement=True) + beta_id = Column( + Integer, + ForeignKey("beta.id"), + nullable=False, + index=True, + ) + beta = relationship("Beta", back_populates="alphas") + gamma_id = Column(Integer, ForeignKey("gamma.id"), nullable=False) + gamma: "Gamma" = relationship("Gamma") + + +class BetaGammaBinding(Base): + __tablename__ = "beta_gamma_binding" + + id: int = Column(Integer, primary_key=True) + beta_id: int = Column(ForeignKey("beta.id", ondelete="CASCADE"), nullable=False) + gamma_id: int = Column(ForeignKey("gamma.id", ondelete="CASCADE"), nullable=False) + + +class Beta(Base): + __tablename__ = "beta" + + id = Column(Integer, primary_key=True, autoincrement=True) + gammas: List["Gamma"] = relationship( + "Gamma", + secondary="beta_gamma_binding", + back_populates="betas", + lazy="noload", + ) + alphas = relationship("Alpha") + deltas: List["Delta"] = relationship( + "Delta", + secondary="beta_delta_binding", + lazy="noload", + ) + + +class Gamma(Base): + __tablename__ = "gamma" + + id = Column(Integer, primary_key=True, autoincrement=True) + betas: List["Beta"] = relationship( + "Beta", + secondary="beta_gamma_binding", + back_populates="gammas", + lazy="raise", + ) + delta_id: int = Column( + Integer, + ForeignKey("delta.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + alpha = relationship("Alpha") + delta: "Delta" = relationship("Delta") + + +class BetaDeltaBinding(Base): + __tablename__ = "beta_delta_binding" + + id: int = Column(Integer, primary_key=True) + beta_id: int = Column(ForeignKey("beta.id", ondelete="CASCADE"), nullable=False) + delta_id: int = Column(ForeignKey("delta.id", ondelete="CASCADE"), nullable=False) + + +class Delta(Base): + __tablename__ = "delta" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String) + gammas: List["Gamma"] = relationship("Gamma", back_populates="delta", lazy="noload") + betas: List["Beta"] = relationship("Beta", secondary="beta_delta_binding", back_populates="deltas", lazy="noload") diff --git a/tests/schemas.py b/tests/schemas.py index ba5824f3..afad1c3c 100644 --- a/tests/schemas.py +++ b/tests/schemas.py @@ -412,3 +412,72 @@ class SelfRelationshipSchema(BaseModel): class CustomUserAttributesSchema(UserBaseSchema): spam: str eggs: str + + +class AlphaSchema(BaseModel): + beta: Optional["BetaSchema"] = Field( + relationship=RelationshipInfo( + resource_type="beta", + ), + ) + gamma: Optional["GammaSchema"] = Field( + relationship=RelationshipInfo( + resource_type="gamma", + ), + ) + + +class BetaSchema(BaseModel): + alphas: Optional["AlphaSchema"] = Field( + relationship=RelationshipInfo( + resource_type="alpha", + ), + ) + gammas: Optional["GammaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="gamma", + many=True, + ), + ) + deltas: Optional["DeltaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="delta", + many=True, + ), + ) + + +class GammaSchema(BaseModel): + betas: Optional["BetaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="beta", + many=True, + ), + ) + delta: Optional["DeltaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="Delta", + ), + ) + + +class DeltaSchema(BaseModel): + name: str + gammas: Optional["GammaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="gamma", + many=True, + ), + ) + betas: Optional["BetaSchema"] = Field( + None, + relationship=RelationshipInfo( + resource_type="beta", + many=True, + ), + ) diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index ec182ff4..d3c31d92 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -19,13 +19,17 @@ from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests -from tests.fixtures.app import build_app_custom +from tests.fixtures.app import build_alphabet_app, build_app_custom from tests.fixtures.entities import build_workplace, create_user from tests.misc.utils import fake from tests.models import ( + Alpha, + Beta, Computer, ContainsTimestamp, CustomUUIDItem, + Delta, + Gamma, Post, PostComment, SelfRelationship, @@ -2740,6 +2744,64 @@ async def test_join_by_relationships_works_correctly_with_many_filters_for_one_f "meta": {"count": 0, "totalPages": 1}, } + async def test_join_by_relationships_for_one_model_by_different_join_chains( + self, + async_session: AsyncSession, + ): + app = build_alphabet_app() + + delta_1 = Delta(name="delta_1") + delta_1.betas = [beta_1 := Beta()] + + gamma_1 = Gamma(delta=delta_1) + gamma_1.betas = [beta_1] + + delta_2 = Delta(name="delta_2") + gamma_2 = Gamma(delta=delta_2) + + alpha_1 = Alpha(beta=beta_1, gamma=gamma_2) + + async_session.add_all( + [ + delta_1, + delta_2, + beta_1, + gamma_1, + gamma_2, + alpha_1, + ], + ) + await async_session.commit() + + async with AsyncClient(app=app, base_url="http://test") as client: + params = { + "filter": json.dumps( + [ + { + "name": "beta.gammas.delta.name", + "op": "eq", + "val": delta_1.name, + }, + { + "name": "gamma.delta.name", + "op": "eq", + "val": delta_2.name, + }, + ], + ), + } + + resource_type = "alpha" + url = app.url_path_for(f"get_{resource_type}_list") + response = await client.get(url, params=params) + + assert response.status_code == status.HTTP_200_OK, response.text + assert response.json() == { + "data": [{"attributes": {}, "id": str(alpha_1.id), "type": "alpha"}], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 1, "totalPages": 1}, + } + ASCENDING = "" DESCENDING = "-" @@ -2809,4 +2871,36 @@ async def test_sort( } +class TestFilteringErrors: + async def test_incorrect_field_name( + self, + app: FastAPI, + client: AsyncClient, + ): + url = app.url_path_for("get_user_list") + params = { + "filter": json.dumps( + [ + { + "name": "fake_field_name", + "op": "eq", + "val": "", + }, + ], + ), + } + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text + assert response.json() == { + "errors": [ + { + "detail": "UserSchema has no attribute fake_field_name", + "source": {"parameter": "filters"}, + "status_code": 400, + "title": "Invalid filters querystring parameter.", + }, + ], + } + + # todo: test errors