From cb179921d0be1318d348969c163f2bd826bc27f6 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 29 Jan 2025 10:34:24 -0500 Subject: [PATCH 01/12] build(deps): upgrade isort version in tox environment to 6.0.0 --- src/dioptra/restapi/v1/experiments/service.py | 4 +++- src/dioptra/restapi/v1/jobs/service.py | 8 ++++++-- .../architectures/tensorflow_layers/backbones.py | 4 +++- tox.ini | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/dioptra/restapi/v1/experiments/service.py b/src/dioptra/restapi/v1/experiments/service.py index df1d1f2ee..c85f4b6a3 100644 --- a/src/dioptra/restapi/v1/experiments/service.py +++ b/src/dioptra/restapi/v1/experiments/service.py @@ -37,7 +37,9 @@ from dioptra.restapi.v1.entrypoints.service import ( RESOURCE_TYPE as ENTRYPOINT_RESOURCE_TYPE, ) -from dioptra.restapi.v1.entrypoints.service import EntrypointIdsService +from dioptra.restapi.v1.entrypoints.service import ( + EntrypointIdsService, +) from dioptra.restapi.v1.groups.service import GroupIdService from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters diff --git a/src/dioptra/restapi/v1/jobs/service.py b/src/dioptra/restapi/v1/jobs/service.py index be90f77a5..ae1105336 100644 --- a/src/dioptra/restapi/v1/jobs/service.py +++ b/src/dioptra/restapi/v1/jobs/service.py @@ -43,11 +43,15 @@ from dioptra.restapi.v1.entrypoints.service import ( RESOURCE_TYPE as ENTRYPOINT_RESOURCE_TYPE, ) -from dioptra.restapi.v1.entrypoints.service import EntrypointIdService +from dioptra.restapi.v1.entrypoints.service import ( + EntrypointIdService, +) from dioptra.restapi.v1.experiments.service import ( RESOURCE_TYPE as EXPERIMENT_RESOURCE_TYPE, ) -from dioptra.restapi.v1.experiments.service import ExperimentIdService +from dioptra.restapi.v1.experiments.service import ( + ExperimentIdService, +) from dioptra.restapi.v1.groups.service import GroupIdService from dioptra.restapi.v1.queues.service import RESOURCE_TYPE as QUEUE_RESOURCE_TYPE from dioptra.restapi.v1.queues.service import QueueIdService diff --git a/src/dioptra/sdk/object_detection/architectures/tensorflow_layers/backbones.py b/src/dioptra/sdk/object_detection/architectures/tensorflow_layers/backbones.py index 9957d9734..69b7683a7 100644 --- a/src/dioptra/sdk/object_detection/architectures/tensorflow_layers/backbones.py +++ b/src/dioptra/sdk/object_detection/architectures/tensorflow_layers/backbones.py @@ -39,7 +39,9 @@ from tensorflow.keras.applications.efficientnet import ( preprocess_input as efficient_net_preprocess_input, ) - from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 + from tensorflow.keras.applications.mobilenet_v2 import ( + MobileNetV2, + ) from tensorflow.keras.applications.mobilenet_v2 import ( preprocess_input as mobilenet_v2_preprocess_input, ) diff --git a/tox.ini b/tox.ini index 6df5e06f7..d2aa8498a 100644 --- a/tox.ini +++ b/tox.ini @@ -170,7 +170,7 @@ commands = gitlint {posargs:--config "{tox_root}{/}.gitlint"} [testenv:isort] deps = - isort>=5.6.0 + isort>=6.0.0 skip_install = true commands = isort {posargs:-c -v "{tox_root}{/}src{/}dioptra" "{tox_root}{/}task-plugins{/}dioptra_builtins"} From 3e3f1e504a708cf4963f6d80422125b6dae4442c Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 29 Jan 2025 10:50:46 -0500 Subject: [PATCH 02/12] chore: resolve black formatting errors --- .../db/alembic/versions/6a343ff9bc2f_.py | 2 +- src/dioptra/task_engine/type_registry.py | 4 ++-- src/dioptra/task_engine/util.py | 2 +- src/dioptra/task_engine/validation.py | 22 +++++++++---------- .../dioptra_builtins/metrics/distance.py | 2 +- .../dioptra_builtins/metrics/performance.py | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/dioptra/restapi/db/alembic/versions/6a343ff9bc2f_.py b/src/dioptra/restapi/db/alembic/versions/6a343ff9bc2f_.py index 64377d285..04fd7f6b1 100644 --- a/src/dioptra/restapi/db/alembic/versions/6a343ff9bc2f_.py +++ b/src/dioptra/restapi/db/alembic/versions/6a343ff9bc2f_.py @@ -1,7 +1,7 @@ """empty message Revision ID: 6a343ff9bc2f -Revises: +Revises: Create Date: 2020-09-11 17:22:38.129031 """ diff --git a/src/dioptra/task_engine/type_registry.py b/src/dioptra/task_engine/type_registry.py index 6702d4159..ccbcc94fe 100644 --- a/src/dioptra/task_engine/type_registry.py +++ b/src/dioptra/task_engine/type_registry.py @@ -308,7 +308,7 @@ def build_type( def get_dependency_types( # noqa: C901 - type_def: Optional[Union[_TypeDefinition, str]] + type_def: Optional[Union[_TypeDefinition, str]], ) -> Iterator[str]: """ Search the given type definition and generate all references to other @@ -395,7 +395,7 @@ def get_sorted_types(type_defs: Mapping[str, _TypeDefinition]) -> list[str]: def build_type_registry( - type_defs: Mapping[str, _TypeDefinition] + type_defs: Mapping[str, _TypeDefinition], ) -> Mapping[str, types.Type]: """ Create a type registry from a set of type definitions. diff --git a/src/dioptra/task_engine/util.py b/src/dioptra/task_engine/util.py index 64a195f4e..33d3117f2 100644 --- a/src/dioptra/task_engine/util.py +++ b/src/dioptra/task_engine/util.py @@ -108,7 +108,7 @@ def step_get_plugin_short_name(step: Mapping[str, Any]) -> Optional[str]: def step_get_invocation_arg_specs( - step_def: Mapping[str, Any] + step_def: Mapping[str, Any], ) -> Union[tuple[list[Any], Mapping[str, Any]], tuple[None, None]]: """ Get invocation positional and keyword arg specs from the given step. This diff --git a/src/dioptra/task_engine/validation.py b/src/dioptra/task_engine/validation.py index d0fad69e2..a47435bc0 100644 --- a/src/dioptra/task_engine/validation.py +++ b/src/dioptra/task_engine/validation.py @@ -28,7 +28,7 @@ def _instance_path_to_description( # noqa: C901 - instance_path: Sequence[Union[int, str]] + instance_path: Sequence[Union[int, str]], ) -> str: """ Create a nice description of the location in an experiment description @@ -238,7 +238,7 @@ def _check_name_collisions(experiment_desc: Mapping[str, Any]) -> list[Validatio def _check_global_parameter_types( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check whether all global parameter types are valid. @@ -296,7 +296,7 @@ def _check_global_parameter_types( def _check_type_definition_type_references( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check for references to undefined types in all type definitions. @@ -332,7 +332,7 @@ def _check_type_definition_type_references( def _check_type_reference_cycle( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check for a reference cycle among type definitions. @@ -367,7 +367,7 @@ def _check_type_reference_cycle( def _check_union_member_duplicates( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check for union type definitions for which there is duplication in the @@ -422,7 +422,7 @@ def _check_union_member_duplicates( def _check_task_plugin_references( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check whether all task plugin short names refer to known task plugins. @@ -453,7 +453,7 @@ def _check_task_plugin_references( def _check_task_plugin_pyplugs_coords( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check task plugin IDs for validity. They must at minimum include a module @@ -492,7 +492,7 @@ def _check_task_plugin_pyplugs_coords( def _check_task_plugin_io_names( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check task definitions for duplicate input and output names. @@ -557,7 +557,7 @@ def _check_task_plugin_io_names( def _check_task_plugin_io_types( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check task definition input and output type names for validity: whether @@ -616,7 +616,7 @@ def _check_task_plugin_io_types( def _check_graph_references( # noqa: C901 - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Scan for references within task invocations, check whether they are legal, @@ -702,7 +702,7 @@ def _check_graph_references( # noqa: C901 def _check_graph_dependencies( - experiment_desc: Mapping[str, Any] + experiment_desc: Mapping[str, Any], ) -> list[ValidationIssue]: """ Check explicitly declared dependencies for each step and ensure they refer diff --git a/task-plugins/dioptra_builtins/metrics/distance.py b/task-plugins/dioptra_builtins/metrics/distance.py index a737ba2df..906f326f2 100644 --- a/task-plugins/dioptra_builtins/metrics/distance.py +++ b/task-plugins/dioptra_builtins/metrics/distance.py @@ -40,7 +40,7 @@ @pyplugs.register def get_distance_metric_list( - request: List[Dict[str, str]] + request: List[Dict[str, str]], ) -> List[Tuple[str, Callable[..., np.ndarray]]]: """Gets multiple distance metric functions from the registry. diff --git a/task-plugins/dioptra_builtins/metrics/performance.py b/task-plugins/dioptra_builtins/metrics/performance.py index 29d6147b4..e12efd62d 100644 --- a/task-plugins/dioptra_builtins/metrics/performance.py +++ b/task-plugins/dioptra_builtins/metrics/performance.py @@ -41,7 +41,7 @@ @pyplugs.register def get_performance_metric_list( - request: List[Dict[str, str]] + request: List[Dict[str, str]], ) -> List[Tuple[str, Callable[..., float]]]: """Gets multiple performance metric functions from the registry. From dbd1f3e9b57b00151aa337dcd884e3b73bb84669 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 28 Jan 2025 15:40:31 -0500 Subject: [PATCH 03/12] feat(restapi): add snapshot id to draft modifications PUT schema This commit updates the schema to allow for changing a draft modifications base snapshot ID via the PUT endpoint. This feature is designed to enable the reconciliation of drafts by allowing the user to signal that their changes are based on a more recent snapshot. This is a breaking change to the REST API. The UI and Python client have both been updated to conform to the change. The ability to change the base snapshot ID is currently disabled via a check in drafts/service.py. It will be enabled with the implementation of the draft commit workflow feature. For now, the only valid value is the latest snapshot ID when the draft was created. Co-authored-by: henrychoy --- src/dioptra/client/drafts.py | 10 ++++- src/dioptra/client/entrypoints.py | 5 ++- src/dioptra/client/experiments.py | 5 ++- src/dioptra/client/models.py | 5 ++- src/dioptra/client/plugin_parameter_types.py | 6 ++- src/dioptra/client/plugins.py | 6 ++- src/dioptra/client/queues.py | 5 ++- src/dioptra/restapi/errors.py | 28 ++++++++++++++ .../restapi/v1/shared/drafts/controller.py | 31 ++++++++++++--- .../restapi/v1/shared/drafts/schema.py | 7 ++++ .../restapi/v1/shared/drafts/service.py | 38 ++++++++++++++++++- src/frontend/src/dialogs/QueueDraftDialog.vue | 2 +- src/frontend/src/services/dataApi.ts | 10 ++++- src/frontend/src/views/QueuesView.vue | 4 +- tests/unit/restapi/lib/asserts.py | 1 - tests/unit/restapi/lib/routines.py | 4 +- 16 files changed, 145 insertions(+), 22 deletions(-) diff --git a/src/dioptra/client/drafts.py b/src/dioptra/client/drafts.py index d87c5052f..e6dd5b3c1 100644 --- a/src/dioptra/client/drafts.py +++ b/src/dioptra/client/drafts.py @@ -373,11 +373,14 @@ def create(self, *resource_ids: str | int, **kwargs) -> T: json_=self._validate_fields(kwargs), ) - def modify(self, *resource_ids: str | int, **kwargs) -> T: + def modify( + self, *resource_ids: str | int, resource_snapshot_id: str | int, **kwargs + ) -> T: """Modify a resource modification draft. Args: *resource_ids: The parent resource ids that own the sub-collection. + resource_snapshot_id: The resource snapshot id the draft is based on. **kwargs: The draft fields to modify. Returns: @@ -389,7 +392,10 @@ def modify(self, *resource_ids: str | int, **kwargs) -> T: """ return self._session.put( self.build_sub_collection_url(*resource_ids), - json_=self._validate_fields(kwargs), + json_={ + "resourceSnapshot": int(resource_snapshot_id), + "resourceData": self._validate_fields(kwargs), + }, ) def delete(self, *resource_ids: str | int) -> T: diff --git a/src/dioptra/client/entrypoints.py b/src/dioptra/client/entrypoints.py index 23459fb78..3be19f8b3 100644 --- a/src/dioptra/client/entrypoints.py +++ b/src/dioptra/client/entrypoints.py @@ -350,7 +350,10 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/entrypoints/1/draft client.entrypoints.modify_resource_drafts.modify( - 1, name="new-name", description="new-description" + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description" ) # POST /api/v1/entrypoints/1/draft diff --git a/src/dioptra/client/experiments.py b/src/dioptra/client/experiments.py index bbc0d3b31..387996b71 100644 --- a/src/dioptra/client/experiments.py +++ b/src/dioptra/client/experiments.py @@ -509,7 +509,10 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/experiments/1/draft client.experiments.modify_resource_drafts.modify( - 1, name="new-name", description="new-description" + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description" ) # POST /api/v1/experiments/1/draft diff --git a/src/dioptra/client/models.py b/src/dioptra/client/models.py index 05d8dedcb..0b157ceec 100644 --- a/src/dioptra/client/models.py +++ b/src/dioptra/client/models.py @@ -255,7 +255,10 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/models/1/draft client.models.modify_resource_drafts.modify( - 1, name="new-name", description="new-description" + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description" ) # POST /api/v1/models/1/draft diff --git a/src/dioptra/client/plugin_parameter_types.py b/src/dioptra/client/plugin_parameter_types.py index 2ed83abfa..6190c8ac5 100644 --- a/src/dioptra/client/plugin_parameter_types.py +++ b/src/dioptra/client/plugin_parameter_types.py @@ -118,7 +118,11 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/pluginParameterTypes/1/draft client.plugin_parameter_types.modify_resource_drafts.modify( - 1, name="new-name", description="new-description", structure=None + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description", + structure=None ) # POST /api/v1/pluginParameterTypes/1/draft diff --git a/src/dioptra/client/plugins.py b/src/dioptra/client/plugins.py index 2960e7f99..3659ae8ca 100644 --- a/src/dioptra/client/plugins.py +++ b/src/dioptra/client/plugins.py @@ -160,6 +160,7 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: client.plugins.files.modify_resource_drafts.modify( 1, 2, + resource_snapshot_id=1, filename="new_name.py", contents="", tasks=[], @@ -465,7 +466,10 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/plugins/1/draft client.plugins.modify_resource_drafts.modify( - 1, name="new-name", description="new-description" + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description" ) # POST /api/v1/plugins/1/draft diff --git a/src/dioptra/client/queues.py b/src/dioptra/client/queues.py index d7ece3af1..21b566606 100644 --- a/src/dioptra/client/queues.py +++ b/src/dioptra/client/queues.py @@ -113,7 +113,10 @@ def modify_resource_drafts(self) -> ModifyResourceDraftsSubCollectionClient[T]: # PUT /api/v1/queues/1/draft client.queues.modify_resource_drafts.modify( - 1, name="new-name", description="new-description" + 1, + resource_snapshot_id=1, + name="new-name", + description="new-description" ) # POST /api/v1/queues/1/draft diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index 1bb547ec9..ea6fff6cb 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -206,6 +206,20 @@ def __init__(self, type: str, id: int): self.resource_id = id +class InvalidDraftBaseResourceSnapshotError(DioptraError): + """The draft's base snapshot identifier is invalid.""" + + def __init__( + self, + message: str, + base_resource_snapshot_id: int, + provided_resource_snapshot_id: int, + ): + super().__init__(message) + self.base_resource_snapshot_id = base_resource_snapshot_id + self.provided_resource_snapshot_id = provided_resource_snapshot_id + + class SortParameterValidationError(DioptraError): """The sort parameters are not valid.""" @@ -418,6 +432,20 @@ def handle_draft_already_exists(error: DraftAlreadyExistsError): log.debug(error.to_message()) return error_result(error, http.HTTPStatus.BAD_REQUEST, {}) + @api.errorhandler(InvalidDraftBaseResourceSnapshotError) + def handle_invalid_draft_base_resource_snapshot( + error: InvalidDraftBaseResourceSnapshotError, + ): + log.debug(error.to_message()) + return error_result( + error, + http.HTTPStatus.BAD_REQUEST, + { + "base_resource_snapshot_id": error.base_resource_snapshot_id, + "provided_resource_snapshot_id": error.provided_resource_snapshot_id, + }, + ) + @api.errorhandler(LockError) def handle_lock_error(error: LockError): log.debug(error.to_message()) diff --git a/src/dioptra/restapi/v1/shared/drafts/controller.py b/src/dioptra/restapi/v1/shared/drafts/controller.py index 710560ac0..3b9ac9120 100644 --- a/src/dioptra/restapi/v1/shared/drafts/controller.py +++ b/src/dioptra/restapi/v1/shared/drafts/controller.py @@ -25,14 +25,19 @@ from flask_login import login_required from flask_restx import Namespace, Resource from injector import ClassAssistedBuilder, inject -from marshmallow import Schema +from marshmallow import Schema, fields from structlog.stdlib import BoundLogger from dioptra.restapi.db import models from dioptra.restapi.v1 import utils from dioptra.restapi.v1.schemas import IdStatusResponseSchema -from .schema import DraftGetQueryParameters, DraftPageSchema, DraftSchema +from .schema import ( + DraftGetQueryParameters, + DraftPageSchema, + DraftSchema, + ModifyDraftBaseSchema, +) from .service import ( ResourceDraftsIdService, ResourceDraftsService, @@ -233,6 +238,13 @@ def generate_resource_id_draft_endpoint( else: model_name = "Draft" + "".join(request_schema.__name__.rsplit("Schema", 1)) + class ModifyDraftSchema(ModifyDraftBaseSchema): + resourceData = fields.Nested( + request_schema, + attribute="resource_data", + metadata=dict(description="Draft resource data."), + ) + @api.route("//draft") @api.param("id", "ID for the resource.") class ResourcesIdDraftEndpoint(Resource): @@ -261,7 +273,7 @@ def get(self, id: int): ) @login_required - @accepts(schema=request_schema, model_name=model_name, api=api) + @accepts(schema=request_schema, model_name="Create" + model_name, api=api) @responds(schema=DraftSchema, api=api) def post(self, id: int): """Creates a Draft for this resource.""" @@ -277,7 +289,7 @@ def post(self, id: int): ) @login_required - @accepts(schema=request_schema, model_name=model_name, api=api) + @accepts(schema=ModifyDraftSchema, model_name="Modify" + model_name, api=api) @responds(schema=DraftSchema, api=api) def put(self, id: int): """Modifies the Draft for this resource.""" @@ -513,6 +525,13 @@ def generate_nested_resource_id_draft_endpoint( request_schema.__name__.rsplit("Schema", 1) ) + class ModifyDraftSchema(ModifyDraftBaseSchema): + resourceData = fields.Nested( + request_schema, + attribute="resource_data", + metadata=dict(description="Draft resource data."), + ) + @api.route(f"//{resource_route}//draft") @api.param("id", "ID for the resource.") class ResourcesIdDraftEndpoint(Resource): @@ -551,7 +570,7 @@ def get(self, id: int, **kwargs): ) @login_required - @accepts(schema=request_schema, model_name=model_name, api=api) + @accepts(schema=request_schema, model_name="Create" + model_name, api=api) @responds(schema=DraftSchema, api=api) def post(self, id: int, **kwargs): """Creates a Draft for this resource.""" @@ -577,7 +596,7 @@ def post(self, id: int, **kwargs): ) @login_required - @accepts(schema=request_schema, model_name=model_name, api=api) + @accepts(schema=ModifyDraftSchema, model_name="Modify" + model_name, api=api) @responds(schema=DraftSchema, api=api) def put(self, id: int, **kwargs): """Modifies the Draft for this resource.""" diff --git a/src/dioptra/restapi/v1/shared/drafts/schema.py b/src/dioptra/restapi/v1/shared/drafts/schema.py index e0fbae202..1cd62b5fa 100644 --- a/src/dioptra/restapi/v1/shared/drafts/schema.py +++ b/src/dioptra/restapi/v1/shared/drafts/schema.py @@ -108,6 +108,13 @@ class DraftSchema(Schema): ) +class ModifyDraftBaseSchema(Schema): + resourceSnapshot = fields.Integer( + attribute="resource_snapshot_id", + metadata=dict(description="ID of the resource snapshot this draft modifies."), + ) + + class DraftPageSchema(BasePageSchema): """The paged schema for the data stored in a Draft.""" diff --git a/src/dioptra/restapi/v1/shared/drafts/service.py b/src/dioptra/restapi/v1/shared/drafts/service.py index 099d11234..55f90e2b3 100644 --- a/src/dioptra/restapi/v1/shared/drafts/service.py +++ b/src/dioptra/restapi/v1/shared/drafts/service.py @@ -32,6 +32,7 @@ DraftAlreadyExistsError, DraftDoesNotExistError, EntityDoesNotExistError, + InvalidDraftBaseResourceSnapshotError, ) from dioptra.restapi.v1.groups.service import GroupIdService @@ -466,8 +467,43 @@ def modify( if draft is None: return None, num_other_drafts + # NOTE: This check disables the ability to change the base snapshot ID. + # It is scheduled to be removed as part of the draft commit workflow feature. + if draft.payload["resource_snapshot_id"] != payload["resource_snapshot_id"]: + raise InvalidDraftBaseResourceSnapshotError( + "The provided resource snapshot must match the base resource snapshot", + base_resource_snapshot_id=draft.payload["resource_snapshot_id"], + provided_resource_snapshot_id=payload["resource_snapshot_id"], + ) + + if draft.payload["resource_snapshot_id"] > payload["resource_snapshot_id"]: + raise InvalidDraftBaseResourceSnapshotError( + "The provided resource snapshot must be greater than or equal to " + "the base resource snapshot.", + base_resource_snapshot_id=draft.payload["resource_snapshot_id"], + provided_resource_snapshot_id=payload["resource_snapshot_id"], + ) + + snapshot_exists_stmt = ( + select(models.ResourceSnapshot) + .where( + models.ResourceSnapshot.resource_snapshot_id + == payload["resource_snapshot_id"] + and models.ResourceSnapshot.resource_type == draft.resource_type + ) + .exists() + .select() + ) + snapshot_exists = db.session.scalar(snapshot_exists_stmt) + if not snapshot_exists: + raise EntityDoesNotExistError( + draft.resource_type, + resource_snapshot_id=payload["resource_snapshot_id"], + ) + current_timestamp = datetime.datetime.now(tz=datetime.timezone.utc) - draft.payload["resource_data"] = payload + draft.payload["resource_snapshot_id"] = payload["resource_snapshot_id"] + draft.payload["resource_data"] = payload["resource_data"] draft.last_modified_on = current_timestamp if commit: diff --git a/src/frontend/src/dialogs/QueueDraftDialog.vue b/src/frontend/src/dialogs/QueueDraftDialog.vue index f59f24ac0..516b822a6 100644 --- a/src/frontend/src/dialogs/QueueDraftDialog.vue +++ b/src/frontend/src/dialogs/QueueDraftDialog.vue @@ -107,7 +107,7 @@ function emitAddOrEdit() { if(props.queueToDraft.hasDraft) { - emit('updateDraftLinkedToQueue', props.queueToDraft.id, name.value, description.value) + emit('updateDraftLinkedToQueue', props.queueToDraft.id, name.value, description.value, props.queueToDraft.snapshot) } else { emit('addQueue', name.value, description.value, props.queueToDraft.id) } diff --git a/src/frontend/src/services/dataApi.ts b/src/frontend/src/services/dataApi.ts index aad979a5f..1566748f0 100644 --- a/src/frontend/src/services/dataApi.ts +++ b/src/frontend/src/services/dataApi.ts @@ -257,8 +257,14 @@ export async function updateDraft(type: T, draftId: string, return await axios.put(`/api/${type}/drafts/${draftId}`, params) } -export async function updateDraftLinkedtoQueue(queueId: number, name: string, description: string) { - return await axios.put(`/api/queues/${queueId}/draft`, { name, description }) +export async function updateDraftLinkedtoQueue(queueId: number, name: string, description: string, snapshotId: number) { + return await axios.put(`/api/queues/${queueId}/draft`, { + resourceSnapshot: snapshotId, + resourceData: { + name, + description, + } + }) } export async function deleteItem(type: T, id: number) { diff --git a/src/frontend/src/views/QueuesView.vue b/src/frontend/src/views/QueuesView.vue index 752b223a2..ea4fe8ab5 100644 --- a/src/frontend/src/views/QueuesView.vue +++ b/src/frontend/src/views/QueuesView.vue @@ -167,9 +167,9 @@ } } - async function updateDraftLinkedToQueue(queueId, name, description) { + async function updateDraftLinkedToQueue(queueId, name, description, snapshotId) { try { - await api.updateDraftLinkedtoQueue(queueId, name, description) + await api.updateDraftLinkedtoQueue(queueId, name, description, snapshotId) notify.success(`Successfully updated '${name}'`) showDraftDialog.value = false } catch(err) { diff --git a/tests/unit/restapi/lib/asserts.py b/tests/unit/restapi/lib/asserts.py index 28839cc1c..2911eb45e 100644 --- a/tests/unit/restapi/lib/asserts.py +++ b/tests/unit/restapi/lib/asserts.py @@ -121,7 +121,6 @@ def assert_draft_response_contents_matches_expectations( Args: response: The actual response from the API. expected_contents: The expected response from the API. - existing_draft: If the draft is of an existing resource or not. Raises: AssertionError: If the API response does not match the expected response diff --git a/tests/unit/restapi/lib/routines.py b/tests/unit/restapi/lib/routines.py index 74cd0998f..40d537895 100644 --- a/tests/unit/restapi/lib/routines.py +++ b/tests/unit/restapi/lib/routines.py @@ -105,7 +105,9 @@ def run_existing_resource_drafts_tests( ) # Modify operation tests - response = client.modify(*resource_ids, **draft_mod).json() + response = client.modify( + *resource_ids, resource_snapshot_id=response["resourceSnapshot"], **draft_mod + ).json() asserts.assert_draft_response_contents_matches_expectations( response, draft_mod_expected ) From 693317a77ce9f2676bf7fd1e40e356fe2f4648dd Mon Sep 17 00:00:00 2001 From: henrychoy Date: Mon, 23 Dec 2024 11:24:51 -0500 Subject: [PATCH 04/12] feat(frontend): add panel to browse experiment history This commit adds the ability for user to browse Experiment history in the view/edit form. Clicking the toggle will bring out the right drawer containing the list of snapshots. Selecting a row loads the values into the form. Up/down keyboard navigation works once focus is on the table. The table can be sorted. The Queue dialog was also updated to use this SnapshotList component. --- src/frontend/src/App.vue | 46 +++- src/frontend/src/assets/main.css | 3 +- src/frontend/src/components/SnapshotList.vue | 106 +++++++++ .../src/components/TableComponent.vue | 87 +++++-- src/frontend/src/dialogs/DialogComponent.vue | 6 +- src/frontend/src/dialogs/LeaveFormDialog.vue | 4 +- src/frontend/src/dialogs/QueueDialog.vue | 221 ++++++------------ src/frontend/src/router/index.ts | 3 +- src/frontend/src/stores/LoginStore.ts | 4 +- src/frontend/src/views/CreateExperiment.vue | 66 +++++- src/frontend/src/views/CreateJob.vue | 2 +- 11 files changed, 350 insertions(+), 198 deletions(-) create mode 100644 src/frontend/src/components/SnapshotList.vue diff --git a/src/frontend/src/App.vue b/src/frontend/src/App.vue index ca1e8a32f..9f8e804b2 100644 --- a/src/frontend/src/App.vue +++ b/src/frontend/src/App.vue @@ -1,9 +1,28 @@ \ No newline at end of file diff --git a/src/frontend/src/assets/main.css b/src/frontend/src/assets/main.css index d33396368..50bfefca1 100644 --- a/src/frontend/src/assets/main.css +++ b/src/frontend/src/assets/main.css @@ -110,6 +110,7 @@ h1 { font-size: clamp(1rem, 10vw, 4rem) !important; line-height: 1 !important; font-weight: 400 !important; + margin-top: 0; } h2 { @@ -120,7 +121,7 @@ h2 { .field-label { width: 100px; - font-size: 1rem; + font-size: .6em; color: black; } diff --git a/src/frontend/src/components/SnapshotList.vue b/src/frontend/src/components/SnapshotList.vue new file mode 100644 index 000000000..3f9b5ff02 --- /dev/null +++ b/src/frontend/src/components/SnapshotList.vue @@ -0,0 +1,106 @@ + + + \ No newline at end of file diff --git a/src/frontend/src/components/TableComponent.vue b/src/frontend/src/components/TableComponent.vue index 737b89cfa..5c9cbd721 100644 --- a/src/frontend/src/components/TableComponent.vue +++ b/src/frontend/src/components/TableComponent.vue @@ -7,14 +7,16 @@ :filter="filter" selection="single" v-model:selected="selected" - row-key="id" - :class="`q-mt-lg ${isMobile ? '' : '' }`" + :row-key="props.rowKey" + :class="'q-mt-lg'" flat bordered dense v-model:pagination="pagination" @request="onRequest" - :rows-per-page-options="[5,10,15,20,25,50,0]" + :tabindex="props.disableSelect ? '' : '0'" + @keydown="keydown" + :rows-per-page-options="props.showAll ? [0] : [5,10,15,20,25,50,0]" > -
+
-
-
- - -
-
- - -
-
- - -
- + + + + + + + + +
- - - - - - - {{ - new Intl.DateTimeFormat('en-US', { - year: '2-digit', - month: '2-digit', - day: '2-digit', - hour: 'numeric', - minute: 'numeric', - hour12: true - }).format(new Date(snapshot.snapshotCreatedOn)) - }} - - - - - - +
\ No newline at end of file diff --git a/src/frontend/src/router/index.ts b/src/frontend/src/router/index.ts index b7d16b2bd..7a3b15ff9 100644 --- a/src/frontend/src/router/index.ts +++ b/src/frontend/src/router/index.ts @@ -54,7 +54,8 @@ const router = createRouter({ }, { path: '/experiments/:id', - component: () => import('../views/CreateExperiment.vue') + component: () => import('../views/CreateExperiment.vue'), + meta: { type: 'experiments' } }, { path: '/groups', diff --git a/src/frontend/src/stores/LoginStore.ts b/src/frontend/src/stores/LoginStore.ts index 8e0bcd6f7..42e6ec350 100644 --- a/src/frontend/src/stores/LoginStore.ts +++ b/src/frontend/src/stores/LoginStore.ts @@ -44,11 +44,13 @@ export const useLoginStore = defineStore('login', () => { files: {}, }) + const showRightDrawer = ref(false) + const selectedSnapshot = ref() // computed()'s are getters // function()'s are actions - return { loggedInUser, loggedInGroup, groups, users, savedForms }; + return { loggedInUser, loggedInGroup, groups, users, savedForms, showRightDrawer, selectedSnapshot }; }) \ No newline at end of file diff --git a/src/frontend/src/views/CreateExperiment.vue b/src/frontend/src/views/CreateExperiment.vue index 7a959594c..1bb1c231c 100644 --- a/src/frontend/src/views/CreateExperiment.vue +++ b/src/frontend/src/views/CreateExperiment.vue @@ -1,8 +1,25 @@ \ No newline at end of file diff --git a/src/frontend/src/views/PluginFiles.vue b/src/frontend/src/views/PluginFiles.vue index 7b1c1ba8f..48720af72 100644 --- a/src/frontend/src/views/PluginFiles.vue +++ b/src/frontend/src/views/PluginFiles.vue @@ -16,19 +16,6 @@ {{ props.row.tasks.length }} - - Create New Plugin File - - Create New Plugin File - - {{ props.row.group.name }}
- - Register a new Plugin Param Type - - Register a new Plugin Param Type - - + store.triggerPopup, (newVal) => { + if(newVal) { + showAddDialog.value = true + store.triggerPopup = false + } + }) + \ No newline at end of file diff --git a/src/frontend/src/views/PluginsView.vue b/src/frontend/src/views/PluginsView.vue index 0b9445215..03b3c2331 100644 --- a/src/frontend/src/views/PluginsView.vue +++ b/src/frontend/src/views/PluginsView.vue @@ -38,19 +38,7 @@ /> - - Register a new Plugin - - Register a new Plugin - - + store.triggerPopup, (newVal) => { + if(newVal) { + showPluginDialog.value = true + store.triggerPopup = false + } + }) + const plugins = ref([]) const editing = ref(false) diff --git a/src/frontend/src/views/QueuesView.vue b/src/frontend/src/views/QueuesView.vue index ea4fe8ab5..cdd0dd805 100644 --- a/src/frontend/src/views/QueuesView.vue +++ b/src/frontend/src/views/QueuesView.vue @@ -24,19 +24,7 @@ /> - - Register a new Queue - - Register a new Queue - - + store.triggerPopup, (newVal) => { + if(newVal) { + showQueueDialog.value = true + store.triggerPopup = false + } + }) + diff --git a/src/frontend/src/views/TagsView.vue b/src/frontend/src/views/TagsView.vue index 21125863f..f248cf542 100644 --- a/src/frontend/src/views/TagsView.vue +++ b/src/frontend/src/views/TagsView.vue @@ -17,19 +17,7 @@ - - Register a new Tag - - Register a new Tag - - + store.triggerPopup, (newVal) => { + if(newVal) { + showAddDialog.value = true + store.triggerPopup = false + } + }) + diff --git a/tests/unit/restapi/lib/mock_mlflow.py b/tests/unit/restapi/lib/mock_mlflow.py index f2e1835a4..a7fb7a851 100644 --- a/tests/unit/restapi/lib/mock_mlflow.py +++ b/tests/unit/restapi/lib/mock_mlflow.py @@ -49,9 +49,7 @@ def get_run(self, id: str) -> MockMlflowRun: # find the latest metric for each metric name if ( metric.key not in output_metrics - # use >= here since we append to the list in log_metric and want to make sure we - # return the latest metric available if they have the same timestamp - or metric.timestamp >= output_metrics[metric.key].timestamp + or metric.timestamp > output_metrics[metric.key].timestamp ): output_metrics[metric.key] = metric From c8a4c8d4e5e663fa8418a526bb0d18577a58fd44 Mon Sep 17 00:00:00 2001 From: "James K. Glasbrenner" Date: Thu, 30 Jan 2025 17:31:09 -0500 Subject: [PATCH 08/12] fix: add support for draft field name converter logic in DioptraClient This commit extends the `NewResourceDraftsSubCollectionClient` and `ModifyResourceDraftsSubCollectionClient` classes to accept a new `convert_field_names_fn` argument, which takes a callable that implements the `ConvertFieldNamesToCamelCaseProtocol` interface. A utility function `make_field_names_to_camel_case_converter` that builds a callable that uses a Python dictionary for the name conversions and checks for field name collisions during conversion is provided. A new conversion callable is added to the entrypoints collection client to handle converting task_graph to taskGraph. The existing integration tests for the entrypoints drafts functionality have been updated to use task_graph instead of taskGraph, and a new unit test has been added to validate that a user cannot pass task_graph and taskGraph as arguments to the entrypoint drafts clients at the same time. Closes #727 --- src/dioptra/client/base.py | 4 + src/dioptra/client/drafts.py | 70 ++++++++++++- src/dioptra/client/entrypoints.py | 10 ++ tests/unit/restapi/v1/test_entrypoint.py | 122 +++++++++++++++++++++-- 4 files changed, 190 insertions(+), 16 deletions(-) diff --git a/src/dioptra/client/base.py b/src/dioptra/client/base.py index d1141926d..2b511848b 100644 --- a/src/dioptra/client/base.py +++ b/src/dioptra/client/base.py @@ -32,6 +32,10 @@ class DioptraClientError(Exception): """Base class for client errors""" +class FieldNameCollisionError(DioptraClientError): + """Raised when two field names will collide after conversion to camel case.""" + + class FieldsValidationError(DioptraClientError): """Raised when one or more fields are invalid.""" diff --git a/src/dioptra/client/drafts.py b/src/dioptra/client/drafts.py index e6dd5b3c1..32b6541db 100644 --- a/src/dioptra/client/drafts.py +++ b/src/dioptra/client/drafts.py @@ -21,6 +21,7 @@ CollectionClient, DioptraClientError, DioptraSession, + FieldNameCollisionError, FieldsValidationError, SubCollectionClient, SubCollectionUrlError, @@ -33,11 +34,61 @@ class DraftFieldsValidationError(DioptraClientError): """Raised when one or more draft fields are invalid.""" +class ConvertFieldNamesToCamelCaseProtocol(Protocol): + def __call__(self, json_: dict[str, Any]) -> dict[str, Any]: + ... # fmt: skip + + class ValidateDraftFieldsProtocol(Protocol): def __call__(self, json_: dict[str, Any]) -> dict[str, Any]: ... # fmt: skip +def make_field_names_to_camel_case_converter( + name_mapping: dict[str, str], +) -> ConvertFieldNamesToCamelCaseProtocol: + """ + Create a function that uses a dictionary to convert field names passed to the client + to camel case. + + Args: + name_mapping: A dictionary that maps the draft field names to their camel case + equivalents. + + Returns: + A function for converting the draft fields names to camel case. + """ + + def convert_field_names(json_: dict[str, Any]) -> dict[str, Any]: + """Convert the names of the provided draft fields to camel case. + + Args: + json_: The draft fields to convert. + + Returns: + The draft fields with names converted to camel case. + """ + if not name_mapping: + return json_ + + converted_json_: dict[str, Any] = {} + conversions_seen: dict[str, str] = {} + + for key, value in json_.items(): + if (converted_key := name_mapping.get(key, key)) in conversions_seen: + raise FieldNameCollisionError( + f"Camel case conversion failed (reason: duplicates " + f"{conversions_seen[converted_key]} after conversion): {key}" + ) + + converted_json_[converted_key] = value + conversions_seen[converted_key] = key + + return converted_json_ + + return convert_field_names + + def make_draft_fields_validator( draft_fields: set[str], resource_name: str ) -> ValidateDraftFieldsProtocol: @@ -102,6 +153,7 @@ def __init__( validate_fields_fn: ValidateDraftFieldsProtocol, root_collection: CollectionClient[T], parent_sub_collections: list[SubCollectionClient[T]] | None = None, + convert_field_names_fn: ConvertFieldNamesToCamelCaseProtocol | None = None, ) -> None: """Initialize the NewResourceDraftsSubCollectionClient instance. @@ -122,6 +174,9 @@ def __init__( self._parent_sub_collections: list[SubCollectionClient[T]] = ( parent_sub_collections or [] ) + self._convert_field_names = ( + convert_field_names_fn or make_field_names_to_camel_case_converter({}) + ) def get( self, @@ -198,8 +253,7 @@ def create( DraftFieldsValidationError: If one or more draft fields are invalid or missing. """ - - if "group" in kwargs: + if "group" in (kwargs := self._convert_field_names(kwargs)): raise FieldsValidationError( "Invalid argument (reason: keyword is reserved): group" ) @@ -228,7 +282,7 @@ def modify(self, *resource_ids: str | int, draft_id: int, **kwargs) -> T: DraftFieldsValidationError: If one or more draft fields are invalid or missing. """ - if "draftId" in kwargs: + if "draftId" in (kwargs := self._convert_field_names(kwargs)): raise FieldsValidationError( "Invalid argument (reason: keyword is reserved): draftId" ) @@ -322,6 +376,7 @@ def __init__( validate_fields_fn: ValidateDraftFieldsProtocol, root_collection: CollectionClient[T], parent_sub_collections: list[SubCollectionClient[T]] | None = None, + convert_field_names_fn: ConvertFieldNamesToCamelCaseProtocol | None = None, ) -> None: """Initialize the ModifyResourceDraftsSubCollectionClient instance. @@ -342,6 +397,9 @@ def __init__( parent_sub_collections=parent_sub_collections, ) self._validate_fields = validate_fields_fn + self._convert_field_names = ( + convert_field_names_fn or make_field_names_to_camel_case_converter({}) + ) def get_by_id(self, *resource_ids: str | int) -> T: """Get a resource modification draft. @@ -370,7 +428,7 @@ def create(self, *resource_ids: str | int, **kwargs) -> T: """ return self._session.post( self.build_sub_collection_url(*resource_ids), - json_=self._validate_fields(kwargs), + json_=self._validate_fields(self._convert_field_names(kwargs)), ) def modify( @@ -394,7 +452,9 @@ def modify( self.build_sub_collection_url(*resource_ids), json_={ "resourceSnapshot": int(resource_snapshot_id), - "resourceData": self._validate_fields(kwargs), + "resourceData": self._validate_fields( + self._convert_field_names(kwargs) + ), }, ) diff --git a/src/dioptra/client/entrypoints.py b/src/dioptra/client/entrypoints.py index 3be19f8b3..7fbf38bcf 100644 --- a/src/dioptra/client/entrypoints.py +++ b/src/dioptra/client/entrypoints.py @@ -26,6 +26,7 @@ ModifyResourceDraftsSubCollectionClient, NewResourceDraftsSubCollectionClient, make_draft_fields_validator, + make_field_names_to_camel_case_converter, ) from .snapshots import SnapshotsSubCollectionClient from .tags import TagsSubCollectionClient @@ -38,6 +39,9 @@ "queues", "plugins", } +FIELD_NAMES_TO_CAMEL_CASE: Final[dict[str, str]] = { + "task_graph": "taskGraph", +} T = TypeVar("T") @@ -280,6 +284,9 @@ def __init__(self, session: DioptraSession[T]) -> None: resource_name=self.name, ), root_collection=self, + convert_field_names_fn=make_field_names_to_camel_case_converter( + name_mapping=FIELD_NAMES_TO_CAMEL_CASE + ), ) self._modify_resource_drafts = ModifyResourceDraftsSubCollectionClient[T]( session=session, @@ -288,6 +295,9 @@ def __init__(self, session: DioptraSession[T]) -> None: resource_name=self.name, ), root_collection=self, + convert_field_names_fn=make_field_names_to_camel_case_converter( + name_mapping=FIELD_NAMES_TO_CAMEL_CASE + ), ) self._snapshots = SnapshotsSubCollectionClient[T]( session=session, root_collection=self diff --git a/tests/unit/restapi/v1/test_entrypoint.py b/tests/unit/restapi/v1/test_entrypoint.py index 4bf089aea..792244d74 100644 --- a/tests/unit/restapi/v1/test_entrypoint.py +++ b/tests/unit/restapi/v1/test_entrypoint.py @@ -27,7 +27,7 @@ import pytest from flask_sqlalchemy import SQLAlchemy -from dioptra.client.base import DioptraResponseProtocol +from dioptra.client.base import DioptraResponseProtocol, FieldNameCollisionError from dioptra.client.client import DioptraClient from ..lib import helpers, routines @@ -849,7 +849,7 @@ def test_manage_existing_entrypoint_draft( draft = { "name": name, "description": description, - "taskGraph": task_graph, + "task_graph": task_graph, "parameters": parameters, "plugins": plugin_ids, "queues": queue_ids, @@ -857,7 +857,7 @@ def test_manage_existing_entrypoint_draft( draft_mod = { "name": new_name, "description": description, - "taskGraph": task_graph, + "task_graph": task_graph, "parameters": parameters, "plugins": plugin_ids, "queues": queue_ids, @@ -870,7 +870,14 @@ def test_manage_existing_entrypoint_draft( "resource_id": entrypoint["id"], "resource_snapshot_id": entrypoint["snapshot"], "num_other_drafts": 0, - "payload": draft, + "payload": { + "name": name, + "description": description, + "taskGraph": task_graph, + "parameters": parameters, + "plugins": plugin_ids, + "queues": queue_ids, + }, } draft_mod_expected = { "user_id": auth_account["id"], @@ -878,7 +885,14 @@ def test_manage_existing_entrypoint_draft( "resource_id": entrypoint["id"], "resource_snapshot_id": entrypoint["snapshot"], "num_other_drafts": 0, - "payload": draft_mod, + "payload": { + "name": new_name, + "description": description, + "taskGraph": task_graph, + "parameters": parameters, + "plugins": plugin_ids, + "queues": queue_ids, + }, } # Run routine: existing resource drafts tests @@ -915,7 +929,7 @@ def test_manage_new_entrypoint_drafts( "draft1": { "name": "entrypoint1", "description": "new entrypoint", - "taskGraph": "graph", + "task_graph": "graph", "parameters": [], "plugins": [], "queues": [], @@ -923,7 +937,7 @@ def test_manage_new_entrypoint_drafts( "draft2": { "name": "entrypoint2", "description": "entrypoint", - "taskGraph": "graph", + "task_graph": "graph", "parameters": [], "queues": [1, 3], "plugins": [2], @@ -932,7 +946,7 @@ def test_manage_new_entrypoint_drafts( draft1_mod = { "name": "draft1", "description": "new description", - "taskGraph": "graph", + "task_graph": "graph", "parameters": [], "plugins": [], "queues": [], @@ -942,17 +956,38 @@ def test_manage_new_entrypoint_drafts( draft1_expected = { "user_id": auth_account["id"], "group_id": group_id, - "payload": drafts["draft1"], + "payload": { + "name": "entrypoint1", + "description": "new entrypoint", + "taskGraph": "graph", + "parameters": [], + "plugins": [], + "queues": [], + }, } draft2_expected = { "user_id": auth_account["id"], "group_id": group_id, - "payload": drafts["draft2"], + "payload": { + "name": "entrypoint2", + "description": "entrypoint", + "taskGraph": "graph", + "parameters": [], + "queues": [1, 3], + "plugins": [2], + }, } draft1_mod_expected = { "user_id": auth_account["id"], "group_id": group_id, - "payload": draft1_mod, + "payload": { + "name": "draft1", + "description": "new description", + "taskGraph": "graph", + "parameters": [], + "plugins": [], + "queues": [], + }, } # Run routine: existing resource drafts tests @@ -967,6 +1002,71 @@ def test_manage_new_entrypoint_drafts( ) +def test_client_raises_error_on_field_name_collision( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_entrypoints: dict[str, Any], +) -> None: + """ + Test that the client errors out if both task_graph and taskGraph are passed as + keyword arguments to either the create and modify draft resource sub-collection + clients. + + Given an authenticated user, this test validates the following sequence of actions: + + - The user prepares a payload for either the create or modify draft resource + sub-collection client. The payload contains both task_graph an taskGraph as keys. + - The user submits a create new resource request using the payload, which raises + a FieldNameCollisionError. + - The user submits a modify resource draft request using the payload, which raises + a FieldNameCollisionError. + """ + entrypoint = registered_entrypoints["entrypoint1"] + group_id = auth_account["groups"][0]["id"] + name = "draft" + description = "description" + task_graph = textwrap.dedent( + """# my entrypoint graph + graph: + message: + my_entrypoint: $name + """ + ) + parameters = [ + { + "name": "my_entrypoint_param", + "defaultValue": "my_value", + "parameterType": "string", + } + ] + plugin_ids = [plugin["id"] for plugin in entrypoint["plugins"]] + queue_ids = [queue["id"] for queue in entrypoint["queues"]] + + draft = { + "name": name, + "description": description, + "task_graph": task_graph, + "taskGraph": task_graph, + "parameters": parameters, + "plugins": plugin_ids, + "queues": queue_ids, + } + + with pytest.raises(FieldNameCollisionError): + dioptra_client.entrypoints.new_resource_drafts.create( + group_id=group_id, + **draft, + ) + + with pytest.raises(FieldNameCollisionError): + dioptra_client.entrypoints.modify_resource_drafts.modify( + entrypoint["id"], + resource_snapshot_id=entrypoint["snapshot"], + **draft, + ) + + def test_manage_entrypoint_snapshots( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, From 44db9db4fcba5d961153a1e885610523012f8d82 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Fri, 31 Jan 2025 11:45:37 -0500 Subject: [PATCH 09/12] fix: revert undesired changes a26566b inadvertently undid changes from 17ded57 and e44c2f6. This commit reverts the changes back. --- .../config/nginx/http_restapi.conf | 1 + tests/unit/restapi/lib/mock_mlflow.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cookiecutter-templates/cookiecutter-dioptra-deployment/{{cookiecutter.__project_slug}}/config/nginx/http_restapi.conf b/cookiecutter-templates/cookiecutter-dioptra-deployment/{{cookiecutter.__project_slug}}/config/nginx/http_restapi.conf index 6b3bbed71..e4f760d59 100644 --- a/cookiecutter-templates/cookiecutter-dioptra-deployment/{{cookiecutter.__project_slug}}/config/nginx/http_restapi.conf +++ b/cookiecutter-templates/cookiecutter-dioptra-deployment/{{cookiecutter.__project_slug}}/config/nginx/http_restapi.conf @@ -11,6 +11,7 @@ server { location / { root /frontend; + try_files $uri $uri/ /index.html; } location /api/ { diff --git a/tests/unit/restapi/lib/mock_mlflow.py b/tests/unit/restapi/lib/mock_mlflow.py index a7fb7a851..ff788be58 100644 --- a/tests/unit/restapi/lib/mock_mlflow.py +++ b/tests/unit/restapi/lib/mock_mlflow.py @@ -49,7 +49,9 @@ def get_run(self, id: str) -> MockMlflowRun: # find the latest metric for each metric name if ( metric.key not in output_metrics - or metric.timestamp > output_metrics[metric.key].timestamp + # use >= here since we append to the list in log_metric and want to make + # sure we return the most recent metric if they have the same timestamp + or metric.timestamp >= output_metrics[metric.key].timestamp ): output_metrics[metric.key] = metric @@ -59,7 +61,12 @@ def get_run(self, id: str) -> MockMlflowRun: return run def log_metric( - self, id: str, key: str, value: float, step: Optional[int] = None, timestamp: Optional[int] = None + self, + id: str, + key: str, + value: float, + step: Optional[int] = None, + timestamp: Optional[int] = None, ): if id not in active_runs: active_runs[id] = [] From 2eca30a3ae543aa672dac79a3cfb577d64ea427d Mon Sep 17 00:00:00 2001 From: etrapnell-nist Date: Tue, 21 Jan 2025 09:34:30 -0500 Subject: [PATCH 10/12] docs: add docs and tests for promptless dioptra deployments This commit documents ways to configure dioptra deployments without prompts including: - using all default values - overriding one or more default values via the command line - overriding one or more default values via a json file It also adds new tests to verify the generated deployment matches the expected configuration. --- .../getting-started/running-dioptra.rst | 41 ++- .../test_cruft_no_prompt.py | 300 ++++++++++++++++++ tox.ini | 1 + 3 files changed, 335 insertions(+), 7 deletions(-) create mode 100644 tests/cookiecutter_dioptra_deployment/test_cruft_no_prompt.py diff --git a/docs/source/getting-started/running-dioptra.rst b/docs/source/getting-started/running-dioptra.rst index 5e252c950..8c640c6d2 100644 --- a/docs/source/getting-started/running-dioptra.rst +++ b/docs/source/getting-started/running-dioptra.rst @@ -58,13 +58,13 @@ This will generate a setup that is appropriate for testing Dioptra on your perso source venv-deploy/bin/activate python -m pip install --upgrade pip cruft jinja2 - # Run cruft and set the template's variables - cruft create https://github.com/usnistgov/dioptra --checkout $DIOPTRA_BRANCH \ - --directory cookiecutter-templates/cookiecutter-dioptra-deployment +Next, run cruft to begin the deployment process. The following command will run cruft and use all of the default template values. If you wish to configure the deployment other than using the default values, see the :ref:`Applying the template ` section for detailed description of the template values and how to configure them. + +.. code:: sh -Cruft will now run and prompt you to configure the deployment. See the :ref:`Applying the template ` section for detailed description of each prompt. + cruft create https://github.com/usnistgov/dioptra --checkout $DIOPTRA_BRANCH \ + --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input -We recommend identifying a location to store datasets you will want to use with Dioptra at this point and setting the ``datasets_directory`` variable accordingly. See the :ref:`Downloading the datasets ` section for more details. Once you have configured your deployment, continue following the instructions for initializing and starting your deployment below. @@ -96,7 +96,9 @@ Applying the template --------------------- Create the folder where you plan to keep the deployment folder and change into it so that it becomes your working directory. -Next, run cruft and apply the Dioptra Deployment template. +Next, run cruft and apply the Dioptra Deployment template. There are four different methods for configuring the deployment. + +1) Have cruft ask for each variable: .. code:: sh @@ -104,6 +106,31 @@ Next, run cruft and apply the Dioptra Deployment template. cruft create https://github.com/usnistgov/dioptra --checkout main \ --directory cookiecutter-templates/cookiecutter-dioptra-deployment +2) Have cruft automatically apply the default template values: + +.. code:: sh + + cruft create https://github.com/usnistgov/dioptra --checkout main \ + --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input + +3) Have cruft use default values except for those specified: + +.. code:: sh + + cruft create https://github.com/usnistgov/dioptra --checkout main \ + --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input \ + --extra-context '{"datasets_directory": "~/datasets"}' + +4) Have cruft use default values except for those given in a file: + +.. code:: sh + + cruft create https://github.com/usnistgov/dioptra --checkout main \ + --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input \ + --extra-context-file overrides.json + +If you selected 1) above, you will now be asked to set the variables needed to customize the generated configuration files. + .. margin:: .. note:: @@ -113,7 +140,6 @@ Next, run cruft and apply the Dioptra Deployment template. If it has, remove it. To start over, re-run the ``cruft`` command. -You will now be asked to set the variables needed to customize the generated configuration files. In most cases you can just use the default value, but there are a few that you may need to customize. Below is a full list of the variables, their default values, and explanations of what they mean. @@ -173,6 +199,7 @@ Below is a full list of the variables, their default values, and explanations of If the provided path is not valid, the `init-deployment.sh` script will fail. More advanced configurations can be achieved by modifying the `docker-compose.override.yml` file. (default: ``""``) + See the :ref:`Downloading the datasets ` section for more details. Example ~~~~~~~ diff --git a/tests/cookiecutter_dioptra_deployment/test_cruft_no_prompt.py b/tests/cookiecutter_dioptra_deployment/test_cruft_no_prompt.py new file mode 100644 index 000000000..d0082d9a5 --- /dev/null +++ b/tests/cookiecutter_dioptra_deployment/test_cruft_no_prompt.py @@ -0,0 +1,300 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode + +import json +import os +import re +import shutil +import subprocess + +import pytest +import yaml + +TEMPLATE_REPO = "https://github.com/usnistgov/dioptra" +BRANCH = "dev" +TEST_DIR = "test_cruft_deploy" + +# Based on cookiecutter.json +EXPECTED_DEFAULTS = { + "deployment_name": "dioptra-deployment", # slugified deployment name + "container_tag": "dev", + "docker_compose_path": "docker compose", + "nginx_server_name": "dioptra.example.org", + "nginx_expose_ports_on_localhost_only": "True", + "postgres_container_tag": "15", + "pgadmin_default_email": "dioptra@example.com", + "num_tensorflow_cpu_workers": "1", + "num_tensorflow_gpu_workers": "0", + "num_pytorch_cpu_workers": "1", + "num_pytorch_gpu_workers": "0", +} + + +@pytest.fixture(scope="session") +def rendered_folder(): + """Set up the test dir and run cruft create""" + if os.path.exists(TEST_DIR): + shutil.rmtree(TEST_DIR) + os.makedirs(TEST_DIR) + + cruft_command = [ + "cruft", + "create", + TEMPLATE_REPO, + "--checkout", + BRANCH, + "--directory", + "cookiecutter-templates/cookiecutter-dioptra-deployment", + "--no-input", + ] + + try: + subprocess.run(cruft_command, cwd=TEST_DIR, check=True) + except subprocess.CalledProcessError as e: + print(f"Cruft command failed: {e}") + assert False + + rendered_folders = os.listdir(TEST_DIR) + assert ( + len(rendered_folders) == 1 + ), "The wrong number of deployment folders was generated." + + rendered_path = os.path.join(TEST_DIR, rendered_folders[0]) + yield rendered_path + + # Clean up after tests run + shutil.rmtree(TEST_DIR) + + +def test_deployment_name(rendered_folder): + deployment_folder_name = os.path.basename(rendered_folder) + assert deployment_folder_name == EXPECTED_DEFAULTS["deployment_name"], ( + f"Deployment folder name mismatch. " + f'Expected "{EXPECTED_DEFAULTS["deployment_name"]}", Got: "{deployment_folder_name}"' + ) + + +def test_dioptra_container_tags(rendered_folder): + """Test that the Dioptra containers have the right tag.""" + expected_tag = EXPECTED_DEFAULTS["container_tag"] + dioptra_prefix = "dioptra/" + docker_compose_file = os.path.join(rendered_folder, "docker-compose.yml") + + assert os.path.exists(docker_compose_file), f"File not found: {docker_compose_file}" + + with open(docker_compose_file, "r") as compose_file: + compose_data = yaml.safe_load(compose_file) + + for service_name, service_details in compose_data.get("services", {}).items(): + image = service_details.get("image", "") + if image.startswith(dioptra_prefix): + provider, service, tag = image.partition(":") + assert tag == expected_tag, ( + f"Service '{service_name}' has incorrect tag. " + f"Expected: '{expected_tag}', Found: '{tag}'" + ) + + +def test_docker_compose_path(rendered_folder): + """Test that the docker compose path was correctly set in init-deployment.sh""" + script_name = "init-deployment.sh" + script_path = os.path.join(rendered_folder, script_name) + expected_value = EXPECTED_DEFAULTS["docker_compose_path"] + + assert os.path.exists(script_path), f"{script_name} not found at: {script_path}" + + with open(script_path, "r") as script_file: + script_content = script_file.read() + + # Ensure cookiecutter placeholder value no longer exists + placeholder_value = " {{ cookiecutter.docker_compose_path }} " + assert ( + placeholder_value not in script_content + ), f"Placeholder {placeholder_value} should not be in {script_path}" + + # Get the docker_compose() function contents + compose_function_pattern = r"^docker_compose\(\) \{.*?^\}$" + match = re.search( + compose_function_pattern, script_content, re.DOTALL | re.MULTILINE + ) + assert match, f"docker_compose() function not found in {script_path}" + function_content = match.group() + + # Expected lines in the init-deployment.sh docker_compose() function + expected_usage_1 = f" if ! {expected_value} " + '"${@}"; then' + expected_usage_2 = f' "{expected_value}, exiting..."' + + assert ( + expected_usage_1 in function_content + ), f"Expected usage '{expected_usage_1} not found in docker_compose () function." + + assert ( + expected_usage_2 in function_content + ), f"Expected usage '{expected_usage_2} not found in docker_compose () function." + + +def test_nginx_expose_ports_on_localhost_only(rendered_folder): + """Test that the nginx service container's ports are assigned correctly.""" + # expected_prefix assuming nginx_expose_ports_on_localhost_only = True + expected_prefix = "127.0.0.1:" + compose_file_path = os.path.join(rendered_folder, "docker-compose.yml") + assert os.path.exists(compose_file_path), f"{compose_file_path} not found." + + with open(compose_file_path, "r") as compose_file: + compose_content = yaml.safe_load(compose_file) + + # Find the nginx service name in the docker compose services, otherwise None + nginx_service_name = next( + ( + service_name + for service_name in compose_content.get("services", {}) + if service_name.endswith("nginx") + ), + None, + ) + assert ( + nginx_service_name + ), f"No service found with a name ending in 'nginx' in {compose_file_path}." + + nginx_service = compose_content["services"][nginx_service_name] + assert nginx_service, f"Nginx service not found in {compose_file_path}" + + ports = nginx_service.get("ports", []) + assert ports, f"No ports defined for the Nginx service in {compose_file_path}" + + for port in ports: + assert port.startswith( + expected_prefix + ), f"Port '{port}' does not start with '{expected_prefix}'." + + +def test_nginx_server_name(rendered_folder): + """Test that the nginx server name propogates to the conf files""" + nginx_config_dir = os.path.join(rendered_folder, "config", "nginx") + expected_name = EXPECTED_DEFAULTS["nginx_server_name"] + conf_files = [ + "http_dbadmin.conf", + "http_minio.conf", + "http_mlflow.conf", + "http_restapi.conf", + "https_dbadmin.conf", + "https_minio.conf", + "https_mlflow.conf", + "https_restapi.conf", + ] + + assert os.path.exists( + nginx_config_dir + ), f"Nginx config directory not found: {nginx_config_dir}" + + for conf_file in conf_files: + conf_path = os.path.join(nginx_config_dir, conf_file) + assert os.path.exists(conf_path), f"'{conf_file}' does not exist" + + with open(conf_path, "r") as file: + conf_content = file.read() + + # Some files have multiple server blocks, get them all + server_blocks = re.findall(r"server\s*{.*?}", conf_content, re.DOTALL) + assert server_blocks, f"No 'server' blocks found in {conf_file}" + + for server_block in server_blocks: + server_name_search = re.search(r"server_name\s+(.+?);", server_block) + assert ( + server_name_search + ), f"Server block without 'server_name' found in {conf_file}" + + found_server_name = server_name_search.group(1).strip() + assert found_server_name == expected_name, ( + f"Expected server_name '{expected_name}' but found " + f"'{found_server_name}' in {conf_file}" + ) + + +def test_pgadmin_default_email(rendered_folder): + env_file_path = os.path.join( + rendered_folder, "envs", f"{EXPECTED_DEFAULTS['deployment_name']}-dbadmin.env" + ) + + assert os.path.exists(env_file_path), f"File not found: {env_file_path}" + with open(env_file_path, "r") as env_file: + content = env_file.read() + expected_line = ( + f"PGADMIN_DEFAULT_EMAIL={EXPECTED_DEFAULTS['pgadmin_default_email']}" + ) + assert expected_line in content, ( + f"Expected line not found in {env_file_path}. " + f"Expected: '{expected_line}" + ) + + +def test_postgres_container_tag(rendered_folder): + """Test that the postgres container has the right tag.""" + expected_tag = EXPECTED_DEFAULTS["postgres_container_tag"] + + docker_compose_file = os.path.join(rendered_folder, "docker-compose.yml") + assert os.path.exists(docker_compose_file), f"File not found: {docker_compose_file}" + + with open(docker_compose_file, "r") as compose_file: + compose_data = yaml.safe_load(compose_file) + + postgres_service = compose_data.get("services", {}).get("dioptra-deployment-db", {}) + image = postgres_service.get("image", "") + assert image.endswith(f":{expected_tag}"), ( + f"Service '{image}' has incorrect tag. " f"Expected: '{expected_tag}'" + ) + + +def test_worker_services(rendered_folder): + """Test that the number of worker containers matches the respective variables.""" + worker_names = { + "num_tensorflow_cpu_workers": "tfcpu", + "num_tensorflow_gpu_workers": "tfgpu", + "num_pytorch_cpu_workers": "pytorchcpu", + "num_pytorch_gpu_workers": "pytorchgpu", + } + + deployment_name = EXPECTED_DEFAULTS["deployment_name"] + docker_compose_path = os.path.join(rendered_folder, "docker-compose.yml") + + assert ( + docker_compose_path + ), f"Docker compose file not found at '{docker_compose_path}'" + + with open(docker_compose_path, "r") as file: + docker_compose = yaml.safe_load(file) + + services = docker_compose.get("services", {}) + + for worker_type, service_suffix in worker_names.items(): + expected_count = int(EXPECTED_DEFAULTS[worker_type]) + + for i in range(1, expected_count + 1): + # ':02d' to ensure that the number is expressed as 2 digits + service_name = f"{deployment_name}-{service_suffix}-{i:02d}" + assert ( + service_name in services + ), f"Missing service: {service_name} for {worker_type}" + + # Check if there is a worker index above what is expected + extra_service_name = ( + f"{deployment_name}-{service_suffix}-{expected_count +1:02d}" + ) + assert extra_service_name not in services, ( + f"Unexpected service: {extra_service_name} for {worker_type}; " + f"expected count of {expected_count}" + ) diff --git a/tox.ini b/tox.ini index d2aa8498a..d455c209d 100644 --- a/tox.ini +++ b/tox.ini @@ -87,6 +87,7 @@ deps = {[pytest]deps} binaryornot>=0.4.0 cookiecutter>=2.0.0,<2.2.0 + cruft>=2.16.0 pytest-cookies skip_install = true commands = python -m pytest {posargs:--template="{tox_root}{/}cookiecutter-templates{/}cookiecutter-dioptra-deployment" "{tox_root}{/}tests{/}cookiecutter_dioptra_deployment"} From 8e1688a037d8c0f6f1b775433a519be957a7ada0 Mon Sep 17 00:00:00 2001 From: Harold Booth Date: Mon, 10 Feb 2025 17:03:35 -0500 Subject: [PATCH 11/12] * docs: fixed missing footnote and expand risk definition --- README.md | 2 +- docs/source/overview/executive-summary.rst | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4819330dd..80751642b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Dioptra is a software test platform for assessing the trustworthy characteristics of artificial intelligence (AI). Trustworthy AI is: valid and reliable, safe, secure and resilient, accountable and transparent, explainable and interpretable, privacy-enhanced, and fair - with harmful bias managed[^1]. -Dioptra supports the Measure function of the [NIST AI Risk Management Framework](https://nist.gov/itl/ai-risk-management-framework/) by providing functionality to assess, analyze, and track identified AI risks. +Dioptra supports the Measure function of the [NIST AI Risk Management Framework](https://nist.gov/itl/ai-risk-management-framework/) by providing functionality to assess, analyze, and track identified AI potential benefits and negative consequences. Dioptra provides a REST API, which can be controlled via an intuitive web interface, a Python client, or any REST client library of the user's choice for designing, managing, executing, and tracking experiments. Details are available in the project documentation available at . diff --git a/docs/source/overview/executive-summary.rst b/docs/source/overview/executive-summary.rst index 9e511aab7..cef7a8a42 100644 --- a/docs/source/overview/executive-summary.rst +++ b/docs/source/overview/executive-summary.rst @@ -16,7 +16,9 @@ .. https://creativecommons.org/licenses/by/4.0/legalcode -Dioptra is a software test platform for assessing the trustworthy characteristics of artificial intelligence (AI). Trustworthy AI is: valid and reliable, safe, secure and resilient, accountable and transparent, explainable and interpretable, privacy-enhanced, and fair - with harmful bias managed1. Dioptra supports the Measure function of the NIST AI Risk Management Framework by providing functionality to assess, analyze, and track identified AI risks. +Dioptra is a software test platform for assessing the trustworthy characteristics of artificial intelligence (AI). Trustworthy AI is: valid and reliable, safe, secure and resilient, accountable and transparent, explainable and interpretable, privacy-enhanced, and fair - with harmful bias managed\ [#f1]_\ . Dioptra supports the Measure function of the NIST AI Risk Management Framework by providing functionality to assess, analyze, and track identified AI potential benefits and negative consequences. + +.. [#f1] https://doi.org/10.6028/NIST.AI.100-1 Use Cases --------- From fca6f4b3206f4f061d3ad8673bcce2720f75e13c Mon Sep 17 00:00:00 2001 From: etrapnell-nist Date: Wed, 12 Feb 2025 15:29:17 -0500 Subject: [PATCH 12/12] docs: set datasets_directory during quickstart --- docs/source/getting-started/running-dioptra.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/source/getting-started/running-dioptra.rst b/docs/source/getting-started/running-dioptra.rst index 8c640c6d2..866e2550a 100644 --- a/docs/source/getting-started/running-dioptra.rst +++ b/docs/source/getting-started/running-dioptra.rst @@ -58,12 +58,15 @@ This will generate a setup that is appropriate for testing Dioptra on your perso source venv-deploy/bin/activate python -m pip install --upgrade pip cruft jinja2 -Next, run cruft to begin the deployment process. The following command will run cruft and use all of the default template values. If you wish to configure the deployment other than using the default values, see the :ref:`Applying the template ` section for detailed description of the template values and how to configure them. +Next, run cruft to begin the deployment process. The following command will run cruft and use all of the default template values except for the `datasets_directory`. If you wish to configure the deployment in a different manner, see the :ref:`Applying the template ` section for detailed description of the template values and how to configure them. + +We recommend identifying a location to store datasets you will want to use with Dioptra at this point and setting the `datasets_directory` variable accordingly. See the :ref:`Downloading the datasets ` section for more details. .. code:: sh - cruft create https://github.com/usnistgov/dioptra --checkout $DIOPTRA_BRANCH \ - --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input + cruft create https://github.com/usnistgov/dioptra --checkout main \ + --directory cookiecutter-templates/cookiecutter-dioptra-deployment --no-input \ + --extra-context '{"datasets_directory": "/datasets"}' Once you have configured your deployment, continue following the instructions for initializing and starting your deployment below.