diff --git a/api/openapi.generated.yml b/api/openapi.generated.yml
index 9bfd3359c..9ff0658b1 100644
--- a/api/openapi.generated.yml
+++ b/api/openapi.generated.yml
@@ -900,6 +900,10 @@ components:
             enum:
             - opportunity_id
             - opportunity_number
+            - opportunity_title
+            - post_date
+            - close_date
+            - agency_code
             description: The field to sort the response by
           sort_direction:
             description: Whether to sort the response ascending or descending
diff --git a/api/src/api/opportunities_v0_1/opportunity_schemas.py b/api/src/api/opportunities_v0_1/opportunity_schemas.py
index 09b4cf72f..f73d75917 100644
--- a/api/src/api/opportunities_v0_1/opportunity_schemas.py
+++ b/api/src/api/opportunities_v0_1/opportunity_schemas.py
@@ -283,7 +283,15 @@ class OpportunitySearchRequestSchema(Schema):
 
     pagination = fields.Nested(
         generate_pagination_schema(
-            "OpportunityPaginationSchema", ["opportunity_id", "opportunity_number"]
+            "OpportunityPaginationSchema",
+            [
+                "opportunity_id",
+                "opportunity_number",
+                "opportunity_title",
+                "post_date",
+                "close_date",
+                "agency_code",
+            ],
         ),
         required=True,
     )
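Note (illustration, not part of the patch): with the enum and schema changes above, a client can request any of the new sort fields. A minimal sketch of such a request follows, assuming the pagination shape implied by the schema (order_by, sort_direction, page_offset, page_size) and the /v0.1/opportunities/search route exercised by the tests later in this diff; the base URL and token value are placeholders.

import requests

# Illustrative request only - host, port, and the auth token value are placeholders.
payload = {
    "pagination": {
        "order_by": "post_date",        # one of the newly allowed sort fields
        "sort_direction": "descending",
        "page_offset": 1,
        "page_size": 25,
    }
}
resp = requests.post(
    "http://localhost:8080/v0.1/opportunities/search",
    json=payload,
    headers={"X-Auth": "example-auth-token"},
)
resp.raise_for_status()
# The response body exposes the matched records under "data", per the tests below.
print([record["opportunity_id"] for record in resp.json()["data"]])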
diff --git a/api/src/services/opportunities_v0_1/search_opportunities.py b/api/src/services/opportunities_v0_1/search_opportunities.py
index 443739477..7808fd279 100644
--- a/api/src/services/opportunities_v0_1/search_opportunities.py
+++ b/api/src/services/opportunities_v0_1/search_opportunities.py
@@ -1,9 +1,9 @@
 import logging
-from typing import Sequence, Tuple
+from typing import Any, Sequence, Tuple
 
 from pydantic import BaseModel, Field
-from sqlalchemy import Select, asc, desc, or_, select
-from sqlalchemy.orm import noload, selectinload
+from sqlalchemy import Select, asc, desc, nulls_last, or_, select
+from sqlalchemy.orm import InstrumentedAttribute, noload, selectinload
 
 import src.adapters.db as db
 from src.db.models.opportunity_models import (
@@ -35,9 +35,19 @@ class SearchOpportunityParams(BaseModel):
     filters: SearchOpportunityFilters | None = Field(default=None)
 
 
-def _add_query_filters(
-    stmt: Select[tuple[Opportunity]], query: str | None
-) -> Select[tuple[Opportunity]]:
+def _join_stmt_to_current_summary(stmt: Select[tuple[Any]]) -> Select[tuple[Any]]:
+    # Utility method to add this join to a select statement, as we do this in a few places
+    #
+    # We need to add joins so that the where/order_by clauses
+    # can query against the tables that are relevant for these filters
+    return stmt.join(CurrentOpportunitySummary).join(
+        OpportunitySummary,
+        CurrentOpportunitySummary.opportunity_summary_id
+        == OpportunitySummary.opportunity_summary_id,
+    )
+
+
+def _add_query_filters(stmt: Select[tuple[Any]], query: str | None) -> Select[tuple[Any]]:
     if query is None or len(query) == 0:
         return stmt
 
@@ -47,20 +57,11 @@ def _add_query_filters(
 
 
 def _add_filters(
-    stmt: Select[tuple[Opportunity]], filters: SearchOpportunityFilters | None
-) -> Select[tuple[Opportunity]]:
+    stmt: Select[tuple[Any]], filters: SearchOpportunityFilters | None
+) -> Select[tuple[Any]]:
     if filters is None:
         return stmt
 
-    # We need to add joins so that the where clauses
-    # can query against the tables that are relevant for these filters
-    # Current + Opportunity Summary are always needed so just add them here
-    stmt = stmt.join(CurrentOpportunitySummary).join(
-        OpportunitySummary,
-        CurrentOpportunitySummary.opportunity_summary_id
-        == OpportunitySummary.opportunity_summary_id,
-    )
-
     if filters.opportunity_status is not None:
         one_of_opportunity_statuses = filters.opportunity_status.get("one_of")
 
@@ -113,20 +114,107 @@ def _add_filters(
     return stmt
 
 
+def _add_order_by(
+    stmt: Select[tuple[Opportunity]], pagination: PaginationParams
+) -> Select[tuple[Opportunity]]:
+    # This generates an order by clause like:
+    #
+    #   ORDER BY opportunity.opportunity_id DESC NULLS LAST
+
+    # This determines whether we use ascending or descending when building the query
+    sort_fn = asc if pagination.is_ascending else desc
+
+    match pagination.order_by:
+        case "opportunity_id":
+            field: InstrumentedAttribute = Opportunity.opportunity_id
+        case "opportunity_number":
+            field = Opportunity.opportunity_number
+        case "opportunity_title":
+            field = Opportunity.opportunity_title
+        case "post_date":
+            field = OpportunitySummary.post_date
+            # Need to add joins to the query stmt to order by a field from opportunity summary
+            stmt = _join_stmt_to_current_summary(stmt)
+        case "close_date":
+            field = OpportunitySummary.close_date
+            # Need to add joins to the query stmt to order by a field from opportunity summary
+            stmt = _join_stmt_to_current_summary(stmt)
+        case "agency_code":
+            field = Opportunity.agency
+        case _:
+            # If this exception happens, it means our API schema
+            # allows values we haven't implemented. That means we
+            # can't determine how to sort and need to correct the mismatch.
+            msg = f"Unconfigured sort_by parameter {pagination.order_by} provided, cannot determine how to sort."
+            raise Exception(msg)
+
+    # Any values that are null will automatically be sorted to the end
+    return stmt.order_by(nulls_last(sort_fn(field)))
+
+
 def search_opportunities(
     db_session: db.Session, raw_search_params: dict
 ) -> Tuple[Sequence[Opportunity], PaginationInfo]:
     search_params = SearchOpportunityParams.model_validate(raw_search_params)
 
-    sort_fn = asc if search_params.pagination.is_ascending else desc
-
-    stmt = (
-        select(Opportunity)
-        # TODO - when we want to sort by non-opportunity table fields we'll need to change this
-        .order_by(sort_fn(getattr(Opportunity, search_params.pagination.order_by))).where(
+    """
+    We create an inner query which handles all of the filtering and returns
+    a set of opportunity IDs for the outer query to filter against. This query
+    ends up looking like (varying based on exact filters):
+
+        SELECT DISTINCT
+            opportunity.opportunity_id
+        FROM opportunity
+        JOIN current_opportunity_summary ON opportunity.opportunity_id = current_opportunity_summary.opportunity_id
+        JOIN opportunity_summary ON current_opportunity_summary.opportunity_summary_id = opportunity_summary.opportunity_summary_id
+        JOIN link_opportunity_summary_funding_instrument ON opportunity_summary.opportunity_summary_id = link_opportunity_summary_funding_instrument.opportunity_summary_id
+        JOIN link_opportunity_summary_funding_category ON opportunity_summary.opportunity_summary_id = link_opportunity_summary_funding_category.opportunity_summary_id
+        JOIN link_opportunity_summary_applicant_type ON opportunity_summary.opportunity_summary_id = link_opportunity_summary_applicant_type.opportunity_summary_id
+        WHERE
+            opportunity.is_draft IS FALSE
+            AND (EXISTS (
+                SELECT
+                    1 FROM current_opportunity_summary
+                WHERE
+                    opportunity.opportunity_id = current_opportunity_summary.opportunity_id))
+            AND link_opportunity_summary_funding_instrument.funding_instrument_id IN (1, 2, 3, 4)
+    """
+    inner_stmt = (
+        select(Opportunity.opportunity_id).where(
             Opportunity.is_draft.is_(False)
         )  # Only ever return non-drafts
         # Filter anything without a current opportunity summary
         .where(Opportunity.current_opportunity_summary != None)  # noqa: E711
+        # Distinct the opportunity IDs returned so that the outer query
+        # has fewer results to query against
+        .distinct()
+    )
+
+    # Current + Opportunity Summary are always needed so just add them here
+    inner_stmt = _join_stmt_to_current_summary(inner_stmt)
+    inner_stmt = _add_query_filters(inner_stmt, search_params.query)
+    inner_stmt = _add_filters(inner_stmt, search_params.filters)
+
+    """
+    The outer query handles sorting and filters against the inner query described above.
+    This ends up looking like (it joins to the current opportunity summary if ordering by one of those fields):
+
+        SELECT
+            opportunity.opportunity_id,
+            opportunity.opportunity_title,
+            -- and so on for the opportunity table fields
+        FROM opportunity
+        WHERE
+            opportunity.opportunity_id IN ( /* the above subquery */ )
+        ORDER BY
+            opportunity.opportunity_id DESC NULLS LAST
+        LIMIT 25 OFFSET 100
+    """
+    stmt = (
+        select(Opportunity).where(Opportunity.opportunity_id.in_(inner_stmt))
         # selectinload makes it so all relationships are loaded and attached to the Opportunity
         # records that we end up fetching. It emits a separate "select * from table where opportunity_id in (x, y ,z)"
         # for each relationship. This is used instead of joinedload as it ends up more performant for complex models
@@ -134,16 +222,9 @@ def search_opportunities(
         #
         # See: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#what-kind-of-loading-to-use
         .options(selectinload("*"), noload(Opportunity.all_opportunity_summaries))
-        # Distinct is necessary as the joins may add duplicate rows when multiple one-to-many relationships match
-        # While SQLAlchemy will unique those rows for us, the SQL query still ends up with far less than the limit
-        # we specify as that is done outside of the DB.
-        # By having distinct, we do that ourselves in the query so that the limit we specify will be the actual amount
-        # of records returned (assuming there at least that number to return)
-        .distinct()
     )
 
-    stmt = _add_query_filters(stmt, search_params.query)
-    stmt = _add_filters(stmt, search_params.filters)
+    stmt = _add_order_by(stmt, search_params.pagination)
 
     paginator: Paginator[Opportunity] = Paginator(
         Opportunity, stmt, db_session, page_size=search_params.pagination.page_size
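Note (illustration, not part of the patch): the service change above splits the search into an ID-only inner query and an outer query that applies ORDER BY ... NULLS LAST. The standalone sketch below shows how those SQLAlchemy pieces compose and what SQL they render to; the Opp model and its columns are invented stand-ins, not the real models.

from sqlalchemy import Column, Date, Integer, asc, desc, nulls_last, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Opp(Base):
    # Simplified stand-in for the real Opportunity/OpportunitySummary models.
    __tablename__ = "opp"
    opp_id = Column(Integer, primary_key=True)
    close_date = Column(Date)


# Inner query: just the IDs, deduplicated, with whatever filters apply.
inner = select(Opp.opp_id).where(Opp.close_date != None).distinct()  # noqa: E711

# Outer query: full rows restricted to those IDs, sorted with NULLS LAST.
is_ascending = False
sort_fn = asc if is_ascending else desc
stmt = (
    select(Opp)
    .where(Opp.opp_id.in_(inner))
    .order_by(nulls_last(sort_fn(Opp.close_date)))
)

# Renders roughly as:
#   SELECT ... FROM opp WHERE opp.opp_id IN (SELECT DISTINCT opp.opp_id FROM opp WHERE ...)
#   ORDER BY opp.close_date DESC NULLS LAST
print(stmt)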
diff --git a/api/tests/conftest.py b/api/tests/conftest.py
index 1163a4316..91ad2705d 100644
--- a/api/tests/conftest.py
+++ b/api/tests/conftest.py
@@ -12,6 +12,7 @@
 import tests.src.db.models.factories as factories
 from src.db import models
 from src.db.models.lookup.sync_lookup_values import sync_lookup_values
+from src.db.models.opportunity_models import Opportunity
 from src.util.local import load_local_env_vars
 from tests.lib import db_testing
 
@@ -190,3 +191,67 @@ def mock_s3_bucket_resource(mock_s3):
 @pytest.fixture
 def mock_s3_bucket(mock_s3_bucket_resource):
     yield mock_s3_bucket_resource.name
+
+
+####################
+# Class-based testing
+####################
+
+
+class BaseTestClass:
+    """
+    A base class to derive a test class from. This lets us have a set of
+    fixtures with a scope greater than an individual test, but more granular
+    than session scoping.
+
+    Useful for avoiding repeated setup across tests, which can make tests
+    clearer and faster.
+
+    See: https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#fixture-scopes
+
+    For example:
+
+    class TestExampleClass(BaseTestClass):
+
+        @pytest.fixture(scope="class")
+        def setup_data(self, db_session):
+            # note that the db_session here is the one created in this class,
+            # as it pulls from the class scope instead
+            examples = ExampleFactory.create_batch(size=100)
+    """
+
+    @pytest.fixture(scope="class")
+    def db_session(self, db_client, monkeypatch_class):
+        # Note this shadows the db_session fixture for tests in this class
+        with db_client.get_session() as db_session:
+            yield db_session
+
+    @pytest.fixture(scope="class")
+    def enable_factory_create(self, monkeypatch_class, db_session):
+        """
+        Allows the create method of factories to be called. By default, create
+        throws an exception to prevent accidental creation of database objects
+        for tests that do not need persistence. This fixture only allows the
+        create method to be called for the current class of tests. Each test
+        that needs to call Factory.create should pull in this fixture.
+        """
+        monkeypatch_class.setattr(factories, "_db_session", db_session)
+
+    @pytest.fixture(scope="class")
+    def truncate_opportunities(self, db_session):
+        """
+        Use this fixture when you want to truncate the opportunity table
+        and handle deleting all related records.
+
+        As this is at the class scope, it will only run once for a given
+        test class.
+        """
+        opportunities = db_session.query(Opportunity).all()
+        for opp in opportunities:
+            db_session.delete(opp)
+
+        # Force the deletes to the DB
+        db_session.commit()
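Note (illustration, not part of the patch): a hypothetical test class showing how BaseTestClass is meant to be used. The class and fixture names (TestOpportunityExample, setup_data) are made up for this sketch; OpportunityFactory, truncate_opportunities, enable_factory_create, and the class-scoped db_session all come from this patch.

import pytest

from src.db.models.opportunity_models import Opportunity
from tests.conftest import BaseTestClass
from tests.src.db.models.factories import OpportunityFactory


class TestOpportunityExample(BaseTestClass):
    @pytest.fixture(scope="class")
    def setup_data(self, truncate_opportunities, enable_factory_create):
        # Class-scoped: clears existing opportunities once, then persists a
        # shared batch that every test method in this class can reuse.
        return OpportunityFactory.create_batch(size=10)

    def test_batch_created(self, setup_data):
        assert len(setup_data) == 10

    def test_sees_same_records(self, setup_data, db_session):
        # The class-scoped db_session from BaseTestClass sees the shared batch.
        assert db_session.query(Opportunity).count() == 10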
diff --git a/api/tests/src/api/opportunities_v0_1/test_opportunity_route.py b/api/tests/src/api/opportunities_v0_1/test_opportunity_route.py
index 32290acf9..28d47cbf6 100644
--- a/api/tests/src/api/opportunities_v0_1/test_opportunity_route.py
+++ b/api/tests/src/api/opportunities_v0_1/test_opportunity_route.py
@@ -1,4 +1,5 @@
 import dataclasses
+from datetime import date
 from enum import IntEnum
 
 import pytest
@@ -14,7 +15,7 @@
     OpportunityAssistanceListing,
     OpportunitySummary,
 )
-from tests.src.db.models import factories
+from tests.conftest import BaseTestClass
 from tests.src.db.models.factories import (
     CurrentOpportunitySummaryFactory,
     LinkOpportunitySummaryApplicantTypeFactory,
@@ -250,18 +251,9 @@ def validate_assistance_listings(
             get_search_request(page_offset=100, page_size=5),
             SearchExpectedValues(total_pages=2, total_records=10, response_record_count=0),
         ),
-        # Sorting
-        (
-            get_search_request(order_by="opportunity_id", sort_direction="ascending"),
-            SearchExpectedValues(total_pages=2, total_records=10, response_record_count=5),
-        ),
-        (
-            get_search_request(order_by="opportunity_number", sort_direction="descending"),
-            SearchExpectedValues(total_pages=2, total_records=10, response_record_count=5),
-        ),
     ],
 )
-def test_opportunity_search_paging_and_sorting_200(
+def test_opportunity_search_paging_200(
     client,
     api_auth_token,
     enable_factory_create,
@@ -269,7 +261,7 @@ def test_opportunity_search_paging_and_sorting_200(
     expected_values,
     truncate_opportunities,
 ):
-    # This test is just focused on testing the sorting and pagination
+    # This test is just focused on testing the pagination
     OpportunityFactory.create_batch(size=10)
 
     resp = client.post(
@@ -282,6 +274,150 @@ def test_opportunity_search_paging_and_sorting_200(
     validate_search_pagination(search_response, search_request, expected_values)
 
 
+def setup_pagination_scenario(
+    opportunity_id: int,
+    opportunity_number: str,
+    opportunity_title: str,
+    post_date: date,
+    close_date: date | None,
+    agency: str,
+) -> None:
+    opportunity = OpportunityFactory.create(
+        opportunity_id=opportunity_id,
+        opportunity_number=opportunity_number,
+        opportunity_title=opportunity_title,
+        agency=agency,
+        no_current_summary=True,
+    )
+
+    opportunity_summary = OpportunitySummaryFactory.create(
+        opportunity=opportunity, post_date=post_date, close_date=close_date
+    )
+
+    CurrentOpportunitySummaryFactory.create(
+        opportunity=opportunity,
+        opportunity_summary=opportunity_summary,
+    )
+
+
+class TestSearchPagination(BaseTestClass):
+    @pytest.fixture(scope="class")
+    def setup_scenarios(self, truncate_opportunities, enable_factory_create):
+        setup_pagination_scenario(
+            opportunity_id=1,
+            opportunity_number="dddd",
+            opportunity_title="zzzz",
+            post_date=date(2024, 3, 1),
+            close_date=None,
+            agency="mmmm",
+        )
+        setup_pagination_scenario(
+            opportunity_id=2,
+            opportunity_number="eeee",
+            opportunity_title="yyyy",
+            post_date=date(2024, 2, 1),
+            close_date=date(2024, 12, 1),
+            agency="nnnn",
+        )
+        setup_pagination_scenario(
+            opportunity_id=3,
+            opportunity_number="aaaa",
+            opportunity_title="wwww",
+            post_date=date(2024, 5, 1),
+            close_date=date(2024, 11, 1),
+            agency="llll",
+        )
+        setup_pagination_scenario(
+            opportunity_id=4,
+            opportunity_number="bbbb",
+            opportunity_title="uuuu",
+            post_date=date(2024, 4, 1),
+            close_date=date(2024, 10, 1),
+            agency="kkkk",
+        )
+        setup_pagination_scenario(
+            opportunity_id=5,
+            opportunity_number="cccc",
+            opportunity_title="xxxx",
+            post_date=date(2024, 1, 1),
+            close_date=date(2024, 9, 1),
+            agency="oooo",
+        )
+
+    @pytest.mark.parametrize(
+        "search_request,expected_order",
+        [
+            ### These scenarios are set up so that the order differs depending on the sort field.
+            ### See the values in the setup method above.
+            # Opportunity ID
+            (
+                get_search_request(order_by="opportunity_id", sort_direction="ascending"),
+                [1, 2, 3, 4, 5],
+            ),
+            (
+                get_search_request(order_by="opportunity_id", sort_direction="descending"),
+                [5, 4, 3, 2, 1],
+            ),
+            # Opportunity number
+            (
+                get_search_request(order_by="opportunity_number", sort_direction="ascending"),
+                [3, 4, 5, 1, 2],
+            ),
+            (
+                get_search_request(order_by="opportunity_number", sort_direction="descending"),
+                [2, 1, 5, 4, 3],
+            ),
+            # Opportunity title
+            (
+                get_search_request(order_by="opportunity_title", sort_direction="ascending"),
+                [4, 3, 5, 2, 1],
+            ),
+            (
+                get_search_request(order_by="opportunity_title", sort_direction="descending"),
+                [1, 2, 5, 3, 4],
+            ),
+            # Post date
+            (
+                get_search_request(order_by="post_date", sort_direction="ascending"),
+                [5, 2, 1, 4, 3],
+            ),
+            (
+                get_search_request(order_by="post_date", sort_direction="descending"),
+                [3, 4, 1, 2, 5],
+            ),
+            # Close date
+            # note that opportunity id 1's value is null, which always goes to the end regardless of direction
+            (
+                get_search_request(order_by="close_date", sort_direction="ascending"),
+                [5, 4, 3, 2, 1],
+            ),
+            (
+                get_search_request(order_by="close_date", sort_direction="descending"),
+                [2, 3, 4, 5, 1],
+            ),
+            # Agency
+            (
+                get_search_request(order_by="agency_code", sort_direction="ascending"),
+                [4, 3, 1, 2, 5],
+            ),
+            (
+                get_search_request(order_by="agency_code", sort_direction="descending"),
+                [5, 2, 1, 3, 4],
+            ),
+        ],
+    )
+    def test_opportunity_sorting_200(
+        self, client, api_auth_token, search_request, expected_order, setup_scenarios
+    ):
+        resp = client.post(
+            "/v0.1/opportunities/search", json=search_request, headers={"X-Auth": api_auth_token}
+        )
+
+        search_response = resp.get_json()
+        assert resp.status_code == 200
+
+        returned_opportunity_ids = [record["opportunity_id"] for record in search_response["data"]]
+
+        assert returned_opportunity_ids == expected_order
+
+
 #####################################
 # Search opportunities tests (Scenarios)
 #####################################
@@ -406,7 +542,7 @@ def setup_opportunity(
         [
             {
                 "field": "pagination.order_by",
-                "message": "Value must be one of: opportunity_id, opportunity_number",
+                "message": "Value must be one of: opportunity_id, opportunity_number, opportunity_title, post_date, close_date, agency_code",
                 "type": "invalid_choice",
             },
             {
@@ -478,7 +614,7 @@ def test_opportunity_search_invalid_request_422(
     assert response_data == expected_response_data
 
 
-class TestSearchScenarios:
+class TestSearchScenarios(BaseTestClass):
     """
     Group the scenario tests in a class for performance. As the setup for these
     tests is slow, but can be shared across all of them, initialize them once
@@ -486,26 +622,7 @@ class TestSearchScenarios:
     """
 
     @pytest.fixture(scope="class")
-    def db_session(self, db_client, monkeypatch_class):
-        # Note this shadows the db_session fixture for tests in this class
-        with db_client.get_session() as db_session:
-            # Set the factories DB session. This is what would normally be done with
-            # the "enable_factory_create" fixture, but for the class level
-            monkeypatch_class.setattr(factories, "_db_session", db_session)
-
-            yield db_session
-
-    @pytest.fixture(scope="class")
-    def truncate_db(self, db_session):
-        opportunities = db_session.query(Opportunity).all()
-        for opp in opportunities:
-            db_session.delete(opp)
-
-        # Force the deletes to the DB
-        db_session.commit()
-
-    @pytest.fixture(scope="class")
-    def setup_scenarios(self, truncate_db):
+    def setup_scenarios(self, truncate_opportunities, enable_factory_create):
         # Won't be returned ever because it's a draft opportunity
         setup_opportunity(Scenario.DRAFT_OPPORTUNITY, is_draft=True)
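Note (illustration, not part of the patch): a quick pure-Python check of the close_date expectations in the parametrized cases above. Sorting on an (is-null, value) key mimics NULLS LAST; the close_dates mapping mirrors the scenario data set up earlier in this file, and order_ids is just a throwaway helper for this sketch.

from datetime import date

# Opportunity ID -> close_date from the setup_scenarios fixture above.
close_dates = {
    1: None,
    2: date(2024, 12, 1),
    3: date(2024, 11, 1),
    4: date(2024, 10, 1),
    5: date(2024, 9, 1),
}


def order_ids(ascending: bool) -> list[int]:
    # Null values sort to the end regardless of direction, matching nulls_last().
    non_null = sorted(
        (i for i, d in close_dates.items() if d is not None),
        key=lambda i: close_dates[i],
        reverse=not ascending,
    )
    nulls = [i for i, d in close_dates.items() if d is None]
    return non_null + nulls


assert order_ids(ascending=True) == [5, 4, 3, 2, 1]
assert order_ids(ascending=False) == [2, 3, 4, 5, 1]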