diff --git a/.bumpversion.cfg b/.bumpversion.cfg index cf78fce08..0d98c440a 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.0.0rc1 +current_version = 3.1.1 commit = False tag = False parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(rc(?P<rc>\d+))? diff --git a/.github/workflows/run-codspeed-tests.yml b/.github/workflows/run-codspeed-tests.yml index fd8e88b7f..29d7c1cb5 100644 --- a/.github/workflows/run-codspeed-tests.yml +++ b/.github/workflows/run-codspeed-tests.yml @@ -11,7 +11,7 @@ jobs: name: Run benchmarks runs-on: ubuntu-latest container: - image: python:3.11 + image: python:3.13 options: --privileged services: postgres: @@ -61,7 +61,7 @@ jobs: - uses: CodSpeedHQ/action@v3 with: - run: CACHE_URI=redis://redis DATABASE_URI=postgresql://$POSTGRES_USER:$POSTGRES_PASSWORD@$POSTGRES_HOST/$POSTGRES_DB pytest test/unit_tests --codspeed + run: CACHE_URI=redis://redis DATABASE_URI=postgresql+psycopg://$POSTGRES_USER:$POSTGRES_PASSWORD@$POSTGRES_HOST/$POSTGRES_DB pytest test/unit_tests --codspeed token: ${{ secrets.CODSPEED_TOKEN }} env: POSTGRES_DB: orchestrator-core-test diff --git a/.github/workflows/run-unit-tests.yml b/.github/workflows/run-unit-tests.yml index 8f44f2d32..54b7ddad8 100644 --- a/.github/workflows/run-unit-tests.yml +++ b/.github/workflows/run-unit-tests.yml @@ -51,7 +51,7 @@ jobs: env: FLIT_ROOT_INSTALL: 1 - name: Run Unit tests - run: CACHE_URI=redis://redis DATABASE_URI=postgresql://$POSTGRES_USER:$POSTGRES_PASSWORD@$POSTGRES_HOST/$POSTGRES_DB pytest --cov-branch --cov=orchestrator --cov-report=xml --ignore=test --ignore=orchestrator/devtools --ignore=examples --ignore=docs --ignore=orchestrator/vendor + run: CACHE_URI=redis://redis DATABASE_URI=postgresql+psycopg://$POSTGRES_USER:$POSTGRES_PASSWORD@$POSTGRES_HOST/$POSTGRES_DB pytest --cov-branch --cov=orchestrator --cov-report=xml --ignore=test --ignore=orchestrator/devtools --ignore=examples --ignore=docs --ignore=orchestrator/vendor env: POSTGRES_DB: orchestrator-core-test POSTGRES_USER: nwa diff --git a/mkdocs.yml b/mkdocs.yml index 4706dc638..34d812203 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -190,8 +190,8 @@ nav: - Callbacks: reference-docs/workflows/callbacks.md - Websockets: reference-docs/websockets.md - Migration guides: - - 2.0: migration-guide/2.0.md - - 3.0: migration-guide/3.0.md + - 2.x: migration-guide/2.0.md + - 3.x: migration-guide/3.0.md - Workshops: # - Beginner: diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py index 31b8255bc..7a8b5a08e 100644 --- a/orchestrator/__init__.py +++ b/orchestrator/__init__.py @@ -13,7 +13,7 @@ """This is the orchestrator workflow engine.""" -__version__ = "3.0.0rc1" +__version__ = "3.1.1" from orchestrator.app import OrchestratorCore from orchestrator.settings import app_settings diff --git a/orchestrator/api/api_v1/endpoints/settings.py b/orchestrator/api/api_v1/endpoints/settings.py index d0bc59b6e..6a89b0625 100644 --- a/orchestrator/api/api_v1/endpoints/settings.py +++ b/orchestrator/api/api_v1/endpoints/settings.py @@ -28,6 +28,7 @@ from orchestrator.settings import ExecutorType, app_settings from orchestrator.utils.json import json_dumps from orchestrator.utils.redis import delete_keys_matching_pattern +from orchestrator.utils.redis_client import create_redis_asyncio_client from orchestrator.websocket import WS_CHANNELS, broadcast_invalidate_cache, websocket_manager router = APIRouter() @@ -41,7 +42,7 @@ @router.delete("/cache/{name}") async def clear_cache(name: str) -> int | None: - cache: AIORedis = 
AIORedis.from_url(str(app_settings.CACHE_URI)) + cache: AIORedis = create_redis_asyncio_client(app_settings.CACHE_URI) if name not in CACHE_FLUSH_OPTIONS: raise_status(HTTPStatus.BAD_REQUEST, "Invalid cache name") diff --git a/orchestrator/cli/generator/generator/migration.py b/orchestrator/cli/generator/generator/migration.py index 17da8705a..55a8739c5 100644 --- a/orchestrator/cli/generator/generator/migration.py +++ b/orchestrator/cli/generator/generator/migration.py @@ -31,13 +31,16 @@ sort_product_blocks_by_dependencies, ) from orchestrator.cli.generator.generator.settings import product_generator_settings as settings +from orchestrator.settings import convert_database_uri logger = structlog.getLogger(__name__) def create_migration_file(message: str, head: str) -> Path | None: - if not environ.get("DATABASE_URI"): - environ.update({"DATABASE_URI": "postgresql://nwa:nwa@localhost/orchestrator-core"}) + if environ.get("DATABASE_URI"): + environ.update({"DATABASE_URI": convert_database_uri(environ["DATABASE_URI"])}) + else: + environ.update({"DATABASE_URI": "postgresql+psycopg://nwa:nwa@localhost/orchestrator-core"}) if not environ.get("PYTHONPATH"): environ.update({"PYTHONPATH": "."}) logger.info( diff --git a/orchestrator/distlock/managers/redis_distlock_manager.py b/orchestrator/distlock/managers/redis_distlock_manager.py index bbaad5586..6f050a212 100644 --- a/orchestrator/distlock/managers/redis_distlock_manager.py +++ b/orchestrator/distlock/managers/redis_distlock_manager.py @@ -20,6 +20,7 @@ from structlog import get_logger from orchestrator.settings import app_settings +from orchestrator.utils.redis_client import create_redis_asyncio_client, create_redis_client logger = get_logger(__name__) @@ -37,7 +38,7 @@ def __init__(self, redis_address: RedisDsn): self.redis_address = redis_address async def connect_redis(self) -> None: - self.redis_conn = AIORedis.from_url(str(self.redis_address)) + self.redis_conn = create_redis_asyncio_client(self.redis_address) async def disconnect_redis(self) -> None: if self.redis_conn: @@ -78,7 +79,7 @@ async def release_lock(self, lock: Lock) -> None: def release_sync(self, lock: Lock) -> None: redis_conn: Redis | None = None try: - redis_conn = Redis.from_url(str(app_settings.CACHE_URI)) + redis_conn = create_redis_client(app_settings.CACHE_URI) sync_lock: SyncLock = SyncLock( redis=redis_conn, name=lock.name, # type: ignore diff --git a/orchestrator/graphql/resolvers/settings.py b/orchestrator/graphql/resolvers/settings.py index 26df0f380..629755f23 100644 --- a/orchestrator/graphql/resolvers/settings.py +++ b/orchestrator/graphql/resolvers/settings.py @@ -21,6 +21,7 @@ from orchestrator.services.settings import get_engine_settings, get_engine_settings_for_update, post_update_to_slack from orchestrator.settings import ExecutorType, app_settings from orchestrator.utils.redis import delete_keys_matching_pattern +from orchestrator.utils.redis_client import create_redis_asyncio_client logger = structlog.get_logger(__name__) @@ -57,7 +58,7 @@ def resolve_settings(info: OrchestratorInfo) -> StatusType: # Mutations async def clear_cache(info: OrchestratorInfo, name: str) -> CacheClearSuccess | Error: - cache: AIORedis = AIORedis.from_url(str(app_settings.CACHE_URI)) + cache: AIORedis = create_redis_asyncio_client(app_settings.CACHE_URI) if name not in CACHE_FLUSH_OPTIONS: return Error(message="Invalid cache name") diff --git a/orchestrator/graphql/schemas/product.py b/orchestrator/graphql/schemas/product.py index 6679b5e59..0a679536f 100644 --- 
a/orchestrator/graphql/schemas/product.py +++ b/orchestrator/graphql/schemas/product.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING, Annotated +from typing import TYPE_CHECKING, Annotated, Iterable import strawberry from strawberry import UNSET from strawberry.federation.schema_directives import Key from oauth2_lib.strawberry import authenticated_field -from orchestrator.db import ProductTable +from orchestrator.db import ProductBlockTable, ProductTable from orchestrator.domain.base import ProductModel from orchestrator.graphql.pagination import Connection from orchestrator.graphql.schemas.fixed_input import FixedInput @@ -51,6 +51,23 @@ async def subscriptions( filter_by_with_related_subscriptions = (filter_by or []) + [GraphqlFilter(field="product", value=self.name)] return await resolve_subscriptions(info, filter_by_with_related_subscriptions, sort_by, first, after) + @strawberry.field(description="Returns a list of all nested product block names") # type: ignore + async def all_pb_names(self) -> list[str]: + + model = get_original_model(self, ProductTable) + + def get_all_pb_names(product_blocks: list[ProductBlockTable]) -> Iterable[str]: + for product_block in product_blocks: + yield product_block.name + + if product_block.depends_on: + yield from get_all_pb_names(product_block.depends_on) + + names: list[str] = list(get_all_pb_names(model.product_blocks)) + names.sort() + + return names + @strawberry.field(description="Return product blocks") # type: ignore async def product_blocks(self) -> list[Annotated["ProductBlock", strawberry.lazy(".product_block")]]: from orchestrator.graphql.schemas.product_block import ProductBlock diff --git a/orchestrator/migrations/helpers.py b/orchestrator/migrations/helpers.py index cf20f54fb..a3f366d5b 100644 --- a/orchestrator/migrations/helpers.py +++ b/orchestrator/migrations/helpers.py @@ -880,10 +880,10 @@ def delete_product(conn: sa.engine.Connection, name: str) -> None: RETURNING product_id ), deleted_p_pb AS ( - DELETE FROM product_product_blocks WHERE product_id IN (SELECT product_id FROM deleted_p) + DELETE FROM product_product_blocks WHERE product_id = ANY(SELECT product_id FROM deleted_p) ), deleted_pb_rt AS ( - DELETE FROM products_workflows WHERE product_id IN (SELECT product_id FROM deleted_p) + DELETE FROM products_workflows WHERE product_id = ANY(SELECT product_id FROM deleted_p) ) SELECT * from deleted_p; """ @@ -911,10 +911,10 @@ def delete_product_block(conn: sa.engine.Connection, name: str) -> None: RETURNING product_block_id ), deleted_p_pb AS ( - DELETE FROM product_product_blocks WHERE product_block_id IN (SELECT product_block_id FROM deleted_pb) + DELETE FROM product_product_blocks WHERE product_block_id = ANY(SELECT product_block_id FROM deleted_pb) ), deleted_pb_rt AS ( - DELETE FROM product_block_resource_types WHERE product_block_id IN (SELECT product_block_id FROM deleted_pb) + DELETE FROM product_block_resource_types WHERE product_block_id = ANY(SELECT product_block_id FROM deleted_pb) ) SELECT * from deleted_pb; """ @@ -968,7 +968,7 @@ def delete_resource_type(conn: sa.engine.Connection, resource_type: str) -> None RETURNING resource_type_id ), deleted_pb_rt AS ( - DELETE FROM product_block_resource_types WHERE resource_type_id IN (SELECT resource_type_id FROM deleted_pb) + DELETE FROM product_block_resource_types WHERE resource_type_id = ANY(SELECT resource_type_id FROM deleted_pb) ) SELECT * from deleted_pb; """ diff --git a/orchestrator/settings.py b/orchestrator/settings.py index 27e54ca15..2e0a4ef73 100644 --- 
a/orchestrator/settings.py +++ b/orchestrator/settings.py @@ -13,16 +13,21 @@ import secrets import string +import warnings from pathlib import Path from typing import Literal -from pydantic import PostgresDsn, RedisDsn +from pydantic import Field, NonNegativeInt, PostgresDsn, RedisDsn from pydantic_settings import BaseSettings from oauth2_lib.settings import oauth2lib_settings from pydantic_forms.types import strEnum +class OrchestratorDeprecationWarning(DeprecationWarning): + pass + + class ExecutorType(strEnum): WORKER = "celery" THREADPOOL = "threadpool" @@ -49,7 +54,7 @@ class AppSettings(BaseSettings): EXECUTOR: str = ExecutorType.THREADPOOL WORKFLOWS_SWAGGER_HOST: str = "localhost" WORKFLOWS_GUI_URI: str = "http://localhost:3000" - DATABASE_URI: PostgresDsn = "postgresql://nwa:nwa@localhost/orchestrator-core" # type: ignore + DATABASE_URI: PostgresDsn = "postgresql+psycopg://nwa:nwa@localhost/orchestrator-core" # type: ignore MAX_WORKERS: int = 5 MAIL_SERVER: str = "localhost" MAIL_PORT: int = 25 @@ -57,6 +62,9 @@ CACHE_URI: RedisDsn = "redis://localhost:6379/0" # type: ignore CACHE_DOMAIN_MODELS: bool = False CACHE_HMAC_SECRET: str | None = None # HMAC signing key, used when pickling results in the cache + REDIS_RETRY_COUNT: NonNegativeInt = Field( + 2, description="Number of retries for redis connection errors/timeouts, 0 to disable" + ) # More info: https://redis-py.readthedocs.io/en/stable/retry.html ENABLE_DISTLOCK_MANAGER: bool = True DISTLOCK_BACKEND: str = "memory" CC_NOC: int = 0 @@ -85,6 +93,22 @@ VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False FILTER_BY_MODE: Literal["partial", "exact"] = "exact" + def __init__(self) -> None: + super(AppSettings, self).__init__() + self.DATABASE_URI = PostgresDsn(convert_database_uri(str(self.DATABASE_URI))) + + +def convert_database_uri(db_uri: str) -> str: + if db_uri.startswith(("postgresql://", "postgresql+psycopg2://")): + db_uri = "postgresql+psycopg" + db_uri[db_uri.find("://") :] + warnings.filterwarnings("always", category=OrchestratorDeprecationWarning) + warnings.warn( + "DATABASE_URI converted to postgresql+psycopg:// format, please update your environment variable", + OrchestratorDeprecationWarning, + stacklevel=2, + ) + return db_uri + app_settings = AppSettings() diff --git a/orchestrator/utils/redis.py b/orchestrator/utils/redis.py index ca66d17f2..31740b038 100644 --- a/orchestrator/utils/redis.py +++ b/orchestrator/utils/redis.py @@ -17,22 +17,22 @@ from typing import Any, Callable from uuid import UUID -import redis.exceptions from anyio import CancelScope, get_cancelled_exc_class -from redis import Redis from redis.asyncio import Redis as AIORedis from redis.asyncio.client import Pipeline, PubSub -from redis.asyncio.retry import Retry -from redis.backoff import EqualJitterBackoff from structlog import get_logger from orchestrator.services.subscriptions import _generate_etag from orchestrator.settings import app_settings from orchestrator.utils.json import PY_JSON_TYPES, json_dumps, json_loads +from orchestrator.utils.redis_client import ( + create_redis_asyncio_client, + create_redis_client, +) logger = get_logger(__name__) -cache = Redis.from_url(str(app_settings.CACHE_URI)) +cache = create_redis_client(app_settings.CACHE_URI) ONE_WEEK = 3600 * 24 * 7 @@ -136,12 +136,7 @@ class RedisBroadcast: client: AIORedis def __init__(self, redis_url: str): - self.client = AIORedis.from_url( - redis_url, - retry_on_error=[redis.exceptions.ConnectionError], 
retry_on_timeout=True, - retry=Retry(EqualJitterBackoff(base=0.05), 2), - ) + self.client = create_redis_asyncio_client(redis_url) self.redis_url = redis_url @asynccontextmanager diff --git a/orchestrator/utils/redis_client.py b/orchestrator/utils/redis_client.py new file mode 100644 index 000000000..a797ae9ac --- /dev/null +++ b/orchestrator/utils/redis_client.py @@ -0,0 +1,35 @@ +import redis.asyncio +import redis.client +import redis.exceptions +from pydantic import RedisDsn +from redis import Redis +from redis.asyncio import Redis as AIORedis +from redis.asyncio.retry import Retry as AIORetry +from redis.backoff import EqualJitterBackoff +from redis.retry import Retry + +from orchestrator.settings import app_settings + +REDIS_RETRY_ON_ERROR = [redis.exceptions.ConnectionError] +REDIS_RETRY_ON_TIMEOUT = True +REDIS_RETRY_BACKOFF = EqualJitterBackoff(base=0.05) + + +def create_redis_client(redis_url: str | RedisDsn) -> redis.client.Redis: + """Create sync Redis client for the given Redis DSN with retry handling for connection errors and timeouts.""" + return Redis.from_url( + str(redis_url), + retry_on_error=REDIS_RETRY_ON_ERROR, # type: ignore[arg-type] + retry_on_timeout=REDIS_RETRY_ON_TIMEOUT, + retry=Retry(REDIS_RETRY_BACKOFF, app_settings.REDIS_RETRY_COUNT), + ) + + +def create_redis_asyncio_client(redis_url: str | RedisDsn) -> redis.asyncio.client.Redis: + """Create async Redis client for the given Redis DSN with retry handling for connection errors and timeouts.""" + return AIORedis.from_url( + str(redis_url), + retry_on_error=REDIS_RETRY_ON_ERROR, # type: ignore[arg-type] + retry_on_timeout=REDIS_RETRY_ON_TIMEOUT, + retry=AIORetry(REDIS_RETRY_BACKOFF, app_settings.REDIS_RETRY_COUNT), + ) diff --git a/pyproject.toml b/pyproject.toml index f21bfaafb..b649bd09e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,9 @@ dependencies = [ "itsdangerous", "Jinja2==3.1.5", "orjson==3.10.15", - "psycopg2-binary==2.9.10", - "pydantic[email]~=2.8.2", - "pydantic-settings~=2.7.1", + "psycopg[binary]==3.2.5", + "pydantic[email]~=2.10.6", + "pydantic-settings~=2.8.0", "python-dateutil==2.8.2", "python-rapidjson>=1.18,<1.21", "pytz==2025.1", @@ -59,13 +59,13 @@ dependencies = [ "SQLAlchemy==2.0.38", "SQLAlchemy-Utils==0.41.2", "structlog", - "typer==0.15.1", - "uvicorn[standard]~=0.32.0", + "typer==0.15.2", + "uvicorn[standard]~=0.34.0", "nwa-stdlib~=1.9.0", "oauth2-lib~=2.4.0", "tabulate==0.9.0", "strawberry-graphql>=0.246.2", - "pydantic-forms~=1.3.0", + "pydantic-forms~=1.4.0", ] description-file = "README.md" @@ -89,7 +89,7 @@ test = [ "jsonref", "mypy==1.9", "pyinstrument", - "pytest==8.3.4", + "pytest==8.3.5", "pytest-asyncio==0.21.2", "pytest-codspeed", "pytest-cov", diff --git a/test/unit_tests/api/test_subscriptions.py b/test/unit_tests/api/test_subscriptions.py index 54a03f281..157b5dfcb 100644 --- a/test/unit_tests/api/test_subscriptions.py +++ b/test/unit_tests/api/test_subscriptions.py @@ -5,7 +5,6 @@ from uuid import uuid4 import pytest -from redis.client import Redis from nwastdlib.url import URL from orchestrator import app_settings @@ -34,6 +33,7 @@ from orchestrator.targets import Target from orchestrator.utils.json import json_dumps, json_loads from orchestrator.utils.redis import to_redis +from orchestrator.utils.redis_client import create_redis_client from orchestrator.workflow import ProcessStatus from test.unit_tests.config import ( IMS_CIRCUIT_ID, @@ -734,17 +734,23 @@ def search(keyword): assert not failed, f"Could not find '{subscription_id}' by all keywords; 
{succeeded=} {failed=}" -def test_subscription_detail_with_domain_model(test_client, generic_subscription_1): +def test_subscription_detail_with_domain_model(test_client, generic_subscription_1, benchmark): # test with a subscription that has domain model and without - response = test_client.get(URL("api/subscriptions/domain-model") / generic_subscription_1) + @benchmark + def response(): + return test_client.get(URL("api/subscriptions/domain-model") / generic_subscription_1) + assert response.status_code == HTTPStatus.OK # Check hierarchy assert response.json()["pb_1"]["rt_1"] == "Value1" -def test_subscription_detail_with_domain_model_does_not_exist(test_client, generic_subscription_1): +def test_subscription_detail_with_domain_model_does_not_exist(test_client, generic_subscription_1, benchmark): # test with a subscription that has domain model and without - response = test_client.get(URL("api/subscriptions/domain-model") / uuid4()) + @benchmark + def response(): + return test_client.get(URL("api/subscriptions/domain-model") / uuid4()) + assert response.status_code == HTTPStatus.NOT_FOUND @@ -774,7 +780,7 @@ def test_subscription_detail_with_domain_model_if_none_match(test_client, generi @pytest.mark.skipif( not getenv("AIOCACHE_DISABLE", "0") == "0", reason="AIOCACHE must be enabled for this test to do anything" ) -def test_subscription_detail_with_domain_model_cache(test_client, generic_subscription_1): +def test_subscription_detail_with_domain_model_cache(test_client, generic_subscription_1, benchmark): # test with a subscription that has domain model and without subscription = SubscriptionModel.from_subscription(generic_subscription_1) extended_model = build_extended_domain_model(subscription) @@ -784,9 +790,11 @@ def test_subscription_detail_with_domain_model_cache(test_client, generic_subscr to_redis(extended_model) - response = test_client.get(URL("api/subscriptions/domain-model") / generic_subscription_1) + @benchmark + def response(): + return test_client.get(URL("api/subscriptions/domain-model") / generic_subscription_1) - cache = Redis.from_url(str(app_settings.CACHE_URI)) + cache = create_redis_client(app_settings.CACHE_URI) result = cache.get(f"orchestrator:domain:{generic_subscription_1}") cached_model = json_dumps(json_loads(result)) cached_etag = cache.get(f"orchestrator:domain:etag:{generic_subscription_1}") diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 2473625d4..beeaf88e1 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -11,7 +11,6 @@ from alembic import command from alembic.config import Config from pydantic import BaseModel as PydanticBaseModel -from redis import Redis from sqlalchemy import create_engine, select, text from sqlalchemy.engine.url import make_url from sqlalchemy.orm.scoping import scoped_session @@ -36,6 +35,7 @@ from orchestrator.settings import app_settings from orchestrator.types import SubscriptionLifecycle from orchestrator.utils.json import json_dumps +from orchestrator.utils.redis_client import create_redis_client from pydantic_forms.core import FormPage from test.unit_tests.fixtures.processes import mocked_processes, mocked_processes_resumeall, test_workflow # noqa: F401 from test.unit_tests.fixtures.products.product_blocks.product_block_list_nested import ( # noqa: F401 @@ -134,6 +134,21 @@ CUSTOMER_ID: str = "2f47f65a-0911-e511-80d0-005056956c1a" +CLI_OPT_MONITOR_SQLALCHEMY = "--monitor-sqlalchemy" + + +def pytest_addoption(parser): + """Define custom pytest commandline options.""" + 
parser.addoption( + CLI_OPT_MONITOR_SQLALCHEMY, + action="store_true", + default=False, + help=( + "When set, activate query monitoring for tests instrumented with monitor_sqlalchemy. " + "Note that this adds some overhead to test execution time." + ), + ) + def run_migrations(db_uri: str) -> None: """Configure the alembic context and run the migrations. @@ -174,7 +189,7 @@ def db_uri(worker_id): Database uri to be used in the test thread """ - database_uri = os.environ.get("DATABASE_URI", "postgresql://nwa:nwa@localhost/orchestrator-core-test") + database_uri = os.environ.get("DATABASE_URI", "postgresql+psycopg://nwa:nwa@localhost/orchestrator-core-test") if worker_id == "master": # pytest is being run without any workers return database_uri @@ -205,9 +220,9 @@ def database(db_uri): url.database = "postgres" engine = create_engine(url) with closing(engine.connect()) as conn: - conn.execute(text("COMMIT;")) - conn.execute(text(f'DROP DATABASE IF EXISTS "{db_to_create}";')) - conn.execute(text("COMMIT;")) + conn.commit() + conn.execution_options(isolation_level="AUTOCOMMIT").execute(text(f'DROP DATABASE IF EXISTS "{db_to_create}";')) + conn.commit() conn.execute(text(f'CREATE DATABASE "{db_to_create}";')) run_migrations(db_uri) @@ -218,8 +233,10 @@ def database(db_uri): finally: db.wrapped_database.engine.dispose() with closing(engine.connect()) as conn: - conn.execute(text("COMMIT;")) - conn.execute(text(f'DROP DATABASE IF EXISTS "{db_to_create}";')) + conn.commit() + conn.execution_options(isolation_level="AUTOCOMMIT").execute( + text(f'DROP DATABASE IF EXISTS "{db_to_create}";') + ) @pytest.fixture(autouse=True) @@ -430,6 +447,65 @@ def generic_product_block_3(generic_resource_type_2): return pb +@pytest.fixture +def generic_referencing_product_block_1(generic_resource_type_1, generic_root_product_block_1): + pb = ProductBlockTable( + name="PB_1", + description="Generic Referencing Product Block 1", + tag="PB1", + status="active", + resource_types=[generic_resource_type_1], + created_at=datetime.datetime.fromisoformat("2023-05-24T00:00:00+00:00"), + depends_on_block_relations=[generic_root_product_block_1], + in_use_by_block_relations=[], + ) + db.session.add(pb) + db.session.commit() + return pb + + +@pytest.fixture +def generic_root_product_block_1(generic_resource_type_3): + pb = ProductBlockTable( + name="PB_Root_1", + description="Generic Root Product Block 1", + tag="PBR1", + status="active", + resource_types=[generic_resource_type_3], + created_at=datetime.datetime.fromisoformat("2023-05-24T00:00:00+00:00"), + in_use_by_block_relations=[], + depends_on_block_relations=[], + ) + db.session.add(pb) + db.session.commit() + return pb + + +@pytest.fixture +def generic_product_block_chain(generic_resource_type_3): + + pb_2 = ProductBlockTable( + name="PB_Chained_2", + description="Generic Product Block 2", + tag="PB2", + status="active", + resource_types=[generic_resource_type_3], + created_at=datetime.datetime.fromisoformat("2023-05-24T00:00:00+00:00"), + ) + pb_1 = ProductBlockTable( + name="PB_Chained_1", + description="Generic Product Block 1", + tag="PB1", + status="active", + resource_types=[generic_resource_type_3], + created_at=datetime.datetime.fromisoformat("2023-05-24T00:00:00+00:00"), + depends_on=[pb_2], + ) + db.session.add_all([pb_1, pb_2]) + db.session.commit() + return pb_1, pb_2 + + +@pytest.fixture def generic_product_1(generic_product_block_1, generic_product_block_2): workflow = db.session.scalar(select(WorkflowTable).where(WorkflowTable.name == "modify_note")) @@ 
-480,6 +556,22 @@ def generic_product_3(generic_product_block_2): return p +@pytest.fixture +def generic_product_4(generic_product_block_chain): + pb_1, pb_2 = generic_product_block_chain + p = ProductTable( + name="Product 4", + description="Generic Product Four", + product_type="Generic", + status="active", + tag="GEN3", + product_blocks=[pb_1], + ) + db.session.add(p) + db.session.commit() + return p + + @pytest.fixture def generic_product_block_type_1(generic_product_block_1): class GenericProductBlockOneInactive(ProductBlockModel, product_block_name="PB_1"): @@ -644,7 +736,7 @@ def cache_fixture(monkeypatch): """Fixture to enable domain model caching and cleanup keys added to the list.""" with monkeypatch.context() as m: m.setattr(app_settings, "CACHE_DOMAIN_MODELS", True) - cache = Redis.from_url(str(app_settings.CACHE_URI)) + cache = create_redis_client(app_settings.CACHE_URI) # Clear cache before using this fixture cache.flushdb() @@ -669,10 +761,12 @@ def refresh_subscriptions_search_view(): @pytest.fixture -def monitor_sqlalchemy(): +def monitor_sqlalchemy(pytestconfig, request, capsys): """Can be used to inspect the number of sqlalchemy queries made by part of the code. - Usage: include as fixture, wrap code to measure in context manager, run pytest with option `-s` for stdout + Usage: include this fixture; it returns a context manager. Wrap it around the code you want to inspect. + Inspection is disabled unless you explicitly enable it. + To enable it, pass the CLI option --monitor-sqlalchemy (see CLI_OPT_MONITOR_SQLALCHEMY). Example: def mytest(monitor_sqlalchemy): @@ -685,20 +779,27 @@ def mytest(monitor_sqlalchemy): """ from orchestrator.db.listeners import disable_listeners, monitor_sqlalchemy_queries - monitor_sqlalchemy_queries() - @contextlib.contextmanager - def context(): + def monitor_queries(): + monitor_sqlalchemy_queries() before = db.session.connection().info.copy() yield after = db.session.connection().info.copy() + disable_listeners() estimated_queries = after["queries_completed"] - before.get("queries_completed", 0) estimated_query_time = after["query_time_spent"] - before.get("query_time_spent", 0.0) - print(f"{estimated_queries:3d} sqlalchemy queries in {estimated_query_time:.2f}s") - yield context + with capsys.disabled(): + print(f"\n{request.node.nodeid} performed {estimated_queries} queries in {estimated_query_time:.2f}s\n") - disable_listeners() + @contextlib.contextmanager + def noop(): + yield + + if pytestconfig.getoption(CLI_OPT_MONITOR_SQLALCHEMY): + yield monitor_queries + else: + yield noop diff --git a/test/unit_tests/domain/test_base_performance.py b/test/unit_tests/domain/test_base_performance.py index 633831e00..b90cc7cf0 100644 --- a/test/unit_tests/domain/test_base_performance.py +++ b/test/unit_tests/domain/test_base_performance.py @@ -1,8 +1,9 @@ from uuid import UUID, uuid4 import pytest +from sqlalchemy import func, select -from orchestrator.db import db +from orchestrator.db import SubscriptionTable, db from orchestrator.domain import SubscriptionModel from orchestrator.types import SubscriptionLifecycle from test.unit_tests.fixtures.products.product_blocks.product_block_one import DummyEnum @@ -79,7 +80,9 @@ def subscription_with_100_horizontal_blocks(create_horizontal_subscription): @pytest.mark.benchmark -def test_subscription_model_horizontal_references(subscription_with_100_horizontal_blocks, test_product_type_one): +def test_subscription_model_horizontal_references( + subscription_with_100_horizontal_blocks, 
test_product_type_one, monitor_sqlalchemy +): # Note: fixtures only execute once per benchmark and are excluded from the measurement # given @@ -90,8 +93,8 @@ def test_subscription_model_horizontal_references(subscription_with_100_horizont # when - # Include the `monitor_sqlalchemy` fixture and use it as a context manager to see the number of real queries - subscription = ProductTypeOneForTest.from_subscription(subscription_id) + with monitor_sqlalchemy(): # Context does nothing unless you set CLI_OPT_MONITOR_SQLALCHEMY + subscription = ProductTypeOneForTest.from_subscription(subscription_id) # then assert len(subscription.block.sub_block_list) == 100 @@ -103,7 +106,9 @@ def subscription_with_10_vertical_blocks(create_vertical_subscription): @pytest.mark.benchmark -def test_subscription_model_vertical_references(subscription_with_10_vertical_blocks, test_product_type_one_nested): +def test_subscription_model_vertical_references( + subscription_with_10_vertical_blocks, test_product_type_one_nested, monitor_sqlalchemy +): # Note: fixtures only execute once per benchmark and are excluded from the measurement # given @@ -114,8 +119,8 @@ def test_subscription_model_vertical_references(subscription_with_10_vertical_bl # when - # Include the `monitor_sqlalchemy` fixture and use it as a context manager to see the number of real queries - subscription = ProductTypeOneNestedForTest.from_subscription(subscription_id) + with monitor_sqlalchemy(): # Context does nothing unless you set CLI_OPT_MONITOR_SQLALCHEMY + subscription = ProductTypeOneNestedForTest.from_subscription(subscription_id) # then assert subscription.block is not None @@ -123,3 +128,33 @@ def test_subscription_model_vertical_references(subscription_with_10_vertical_bl assert subscription.block.sub_block.sub_block is not None assert subscription.block.sub_block.sub_block.sub_block is not None # no need to check all x levels + + +@pytest.mark.benchmark +def test_subscription_model_vertical_references_save(create_vertical_subscription, monitor_sqlalchemy): + # when + with monitor_sqlalchemy(): + subscription_id = create_vertical_subscription(size=5) + + # then + + # Checks that the subscription was created, without too much overhead + query_check_created = ( + select(func.count()).select_from(SubscriptionTable).where(SubscriptionTable.subscription_id == subscription_id) + ) + assert db.session.scalar(query_check_created) == 1 + + +@pytest.mark.benchmark +def test_subscription_model_horizontal_references_save(create_horizontal_subscription, monitor_sqlalchemy): + # when + with monitor_sqlalchemy(): + subscription_id = create_horizontal_subscription(size=10) + + # then + + # Checks that the subscription was created, without too much overhead + query_check_created = ( + select(func.count()).select_from(SubscriptionTable).where(SubscriptionTable.subscription_id == subscription_id) + ) + assert db.session.scalar(query_check_created) == 1 diff --git a/test/unit_tests/forms/test_display_subscription.py b/test/unit_tests/forms/test_display_subscription.py index f1299e34e..866b2f82a 100644 --- a/test/unit_tests/forms/test_display_subscription.py +++ b/test/unit_tests/forms/test_display_subscription.py @@ -59,11 +59,7 @@ class Form(FormPage): "type": "string", }, "summary": { - "allOf": [ - { - "$ref": "#/$defs/MigrationSummaryValue", - }, - ], + "$ref": "#/$defs/MigrationSummaryValue", "format": "summary", "default": None, "type": "string", diff --git a/test/unit_tests/graphql/test_product.py b/test/unit_tests/graphql/test_product.py index 
d984c0b6a..1ea1b27ab 100644 --- a/test/unit_tests/graphql/test_product.py +++ b/test/unit_tests/graphql/test_product.py @@ -78,6 +78,36 @@ def get_product_query( ).encode("utf-8") +def get_all_product_names_query( + filter_by: list[str] | None = None, +) -> bytes: + query = """ +query ProductQuery($filterBy: [GraphqlFilter!]) { + products(filterBy: $filterBy) { + page { + allPbNames + } + pageInfo { + endCursor + hasNextPage + hasPreviousPage + startCursor + totalItems + } + } +} + """ + return json.dumps( + { + "operationName": "ProductQuery", + "query": query, + "variables": { + "filterBy": filter_by if filter_by else [], + }, + } + ).encode("utf-8") + + def get_products_with_related_subscriptions_query( first: int = 10, after: int = 0, @@ -196,6 +226,20 @@ def test_product_query( } +def test_all_product_block_names(test_client, generic_product_4): + filter_by = {"filter_by": {"field": "name", "value": "Product 4"}} + data = get_all_product_names_query(**filter_by) + response: Response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + + assert HTTPStatus.OK == response.status_code + result = response.json() + products_data = result["data"]["products"] + products = products_data["page"] + names = products[0]["allPbNames"] + + assert len(names) == 2 + + def test_product_has_previous_page(test_client, generic_product_1, generic_product_2, generic_product_3): data = get_product_query(after=1, sort_by=[{"field": "name", "order": "ASC"}]) response: Response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) diff --git a/test/unit_tests/graphql/test_settings.py b/test/unit_tests/graphql/test_settings.py index 2082f9fdf..2158eb1d2 100644 --- a/test/unit_tests/graphql/test_settings.py +++ b/test/unit_tests/graphql/test_settings.py @@ -2,10 +2,9 @@ from hashlib import md5 from http import HTTPStatus -from redis import Redis - from orchestrator import app_settings from orchestrator.utils.redis import ONE_WEEK +from orchestrator.utils.redis_client import create_redis_client from test.unit_tests.config import GRAPHQL_ENDPOINT, GRAPHQL_HEADERS @@ -106,7 +105,7 @@ def test_clear_cache_mutation_fails_auth(test_client, monkeypatch): def test_success_clear_cache(test_client, cache_fixture): - cache = Redis.from_url(str(app_settings.CACHE_URI)) + cache = create_redis_client(app_settings.CACHE_URI) key = "some_model_uuid" test_data = {key: {"data": [1, 2, 3]}} diff --git a/test/unit_tests/graphql/test_subscription.py b/test/unit_tests/graphql/test_subscription.py index a88e34c23..7f299cfc1 100644 --- a/test/unit_tests/graphql/test_subscription.py +++ b/test/unit_tests/graphql/test_subscription.py @@ -60,19 +60,27 @@ def build_complex_query(subscription_id): ).encode("utf-8") -def test_single_simple_subscription(fastapi_app_graphql, test_client, product_sub_list_union_subscription_1): +def test_single_simple_subscription(fastapi_app_graphql, test_client, product_sub_list_union_subscription_1, benchmark): test_query = build_simple_query(subscription_id=product_sub_list_union_subscription_1) - response = test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + assert response.status_code == HTTPStatus.OK assert response.json() == {"data": {"subscription": {"insync": True, "status": "ACTIVE"}}} def test_single_complex_subscription( - 
fastapi_app_graphql, test_client, product_sub_list_union_subscription_1, test_product_type_sub_list_union + fastapi_app_graphql, test_client, product_sub_list_union_subscription_1, test_product_type_sub_list_union, benchmark ): _, _, ProductSubListUnion = test_product_type_sub_list_union test_query = build_complex_query(subscription_id=product_sub_list_union_subscription_1) - response = test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + assert response.status_code == HTTPStatus.OK assert response.json() == { "data": { @@ -86,8 +94,12 @@ def test_single_complex_subscription( } -def test_subscription_does_not_exist(fastapi_app_graphql, test_client): +def test_subscription_does_not_exist(fastapi_app_graphql, test_client, benchmark): test_query = build_simple_query(uuid4()) - response = test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=test_query, headers={"Content-Type": "application/json"}) + assert response.status_code == HTTPStatus.OK assert response.json() == {"data": {"subscription": None}} diff --git a/test/unit_tests/graphql/test_subscriptions.py b/test/unit_tests/graphql/test_subscriptions.py index 6d1d0ed86..878b58528 100644 --- a/test/unit_tests/graphql/test_subscriptions.py +++ b/test/unit_tests/graphql/test_subscriptions.py @@ -413,12 +413,15 @@ def get_subscriptions_with_metadata_and_schema_query( ).encode("utf-8") -def test_subscriptions_single_page(test_client, product_type_1_subscriptions_factory): +def test_subscriptions_single_page(test_client, product_type_1_subscriptions_factory, benchmark): # when product_type_1_subscriptions_factory(4) data = get_subscriptions_query() - response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) # then @@ -770,6 +773,34 @@ def test_subscriptions_range_filtering_on_start_date(test_client, product_type_1 assert higher_than_date <= subscription["startDate"] <= lower_than_date +def test_subscriptions_with_exact_filter_by(test_client, product_type_1_subscriptions_factory): + # when + + product_type_1_subscriptions_factory(20) + + with patch.object(app_settings, "FILTER_BY_MODE", "exact"): + data = get_subscriptions_query(filter_by=[{"field": "description", "value": "Subscription 1"}]) + response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + + # then + + assert HTTPStatus.OK == response.status_code + result = response.json() + subscriptions_data = result["data"]["subscriptions"] + subscriptions = subscriptions_data["page"] + pageinfo = subscriptions_data["pageInfo"] + + assert len(subscriptions) == 1 + assert "errors" not in result + assert pageinfo == { + "hasPreviousPage": False, + "hasNextPage": False, + "startCursor": 0, + "endCursor": 0, + "totalItems": 1, + } + + def test_subscriptions_range_filtering_on_type(test_client, product_type_1_subscriptions_factory): # when @@ -1048,6 +1079,7 @@ def test_single_subscription_with_depends_on_subscriptions( sub_one_subscription_1, sub_two_subscription_1, product_sub_list_union_subscription_1, + benchmark, ): # when @@ -1057,7 +1089,10 @@ def 
test_single_subscription_with_depends_on_subscriptions( subscription_id = str(product_sub_list_union_subscription_1) data = get_subscriptions_query_with_relations(query_string=subscription_id) - response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) expected_depends_on_ids = { str(subscription.subscription_id) for subscription in [sub_one_subscription_1, sub_two_subscription_1] @@ -1155,6 +1190,7 @@ def test_single_subscription_schema( sub_one_subscription_1, sub_two_subscription_1, product_sub_list_union_subscription_1, + benchmark, ): # when @@ -1163,7 +1199,11 @@ def test_single_subscription_schema( data = get_subscriptions_product_block_json_schema_query( filter_by=[{"field": "subscriptionId", "value": subscription_id}] ) - response = test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + + @benchmark + def response(): + return test_client.post("/api/graphql", content=data, headers={"Content-Type": "application/json"}) + # then assert HTTPStatus.OK == response.status_code @@ -1295,7 +1335,7 @@ def test_single_subscription_schema( "customer_id": {"title": "Customer Id", "type": "string"}, "subscription_id": {"format": "uuid", "title": "Subscription Id", "type": "string"}, "description": {"default": "Initial subscription", "title": "Description", "type": "string"}, - "status": {"allOf": [{"$ref": "#/$defs/SubscriptionLifecycle"}], "default": "initial"}, + "status": {"$ref": "#/$defs/SubscriptionLifecycle", "default": "initial"}, "insync": {"default": False, "title": "Insync", "type": "boolean"}, "start_date": { "anyOf": [{"format": "date-time", "type": "string"}, {"type": "null"}],