Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New restapi v1 entrypoint workflow yaml validation #673

Open
wants to merge 17 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
fadfea6
feat(restapi): Add errors for entrypoint workflow yaml validation
andrewhand Oct 24, 2024
cb27d9d
feat(restapi): Added controller for entrypoint workflow yaml validation
andrewhand Oct 31, 2024
0bc7de8
feat(restapi): Added schema for entrypoint workflow yaml validation
andrewhand Oct 31, 2024
d8c11c2
feat(restapi): Added the service layer for entrypoint workflow yaml v…
andrewhand Oct 31, 2024
5a24b92
chore(restapi): Moved building task engine dict from worfkflow servic…
andrewhand Oct 31, 2024
f55950d
chore(restapi): Added doc strings
andrewhand Oct 31, 2024
8086726
test(restapi): Added tests and fixes
andrewhand Nov 14, 2024
23174f8
feat(restapi): Adding entrypoint worflow yaml validation to entrypoin…
andrewhand Nov 14, 2024
35c81d1
fix(restapi): Fixed a bug with plugin id aquistion
andrewhand Nov 14, 2024
25b6b73
feat(restapi): Minor fixes to entrypoint validation in entrypoint end…
andrewhand Nov 15, 2024
a65b650
chore(restapi): removing unneeded line of code
andrewhand Nov 15, 2024
bdbe844
feat(restapi): pushing for debug sesh
andrewhand Nov 20, 2024
1622ded
chore(restapi): removed unneed service from endpoint
andrewhand Nov 20, 2024
5128579
feat(restapi): Moved validate service and helpers to shared folder
andrewhand Nov 22, 2024
d7c3758
Merge branch 'dev' of https://github.com/usnistgov/dioptra into resta…
jtsextonMITRE Feb 6, 2025
5ab64aa
feat(restapi): add client piece and update tests
jtsextonMITRE Feb 6, 2025
3c88fbe
chore: fix spelling error
jtsextonMITRE Feb 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/dioptra/client/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
from pathlib import Path
from typing import ClassVar, Final, TypeVar
from typing import ClassVar, Final, TypeVar, Any

from .base import CollectionClient, IllegalArgumentError

T = TypeVar("T")

JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload"
ENTRYPOINT_VALIDATION: Final[str] = "entrypointValidate"


class WorkflowsCollectionClient(CollectionClient[T]):
Expand Down Expand Up @@ -86,3 +87,19 @@ def download_job_files(
return self._session.download(
self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params
)

def validate_entrypoint(
    self,
    task_graph: str,
    plugins: list[int],
    entrypoint_parameters: list[dict[str, Any]],
):
    """Post an entrypoint definition to the entrypoint validation workflow.

    Args:
        task_graph: The task graph text, sent as the "taskGraph" field.
        plugins: The plugin ids, sent as the "plugins" field.
        entrypoint_parameters: The entrypoint parameter definitions, sent as
            the "parameters" field.

    Returns:
        The value returned by the session's post() call.
    """
    request_body = dict(
        taskGraph=task_graph,
        plugins=plugins,
        parameters=entrypoint_parameters,
    )
    return self._session.post(self.url, ENTRYPOINT_VALIDATION, json_=request_body)
19 changes: 19 additions & 0 deletions src/dioptra/restapi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from flask_restx import Api
from structlog.stdlib import BoundLogger

from dioptra.task_engine.issues import ValidationIssue

LOGGER: BoundLogger = structlog.stdlib.get_logger()


Expand Down Expand Up @@ -324,6 +326,14 @@ def __init__(self):
)


class EntrpointWorkflowYamlValidationError(DioptraError):
    """The entrypoint workflow yaml is invalid.

    Args:
        issues: The validation issues found in the entrypoint workflow yaml.
    """

    def __init__(self, issues: list[ValidationIssue]):
        super().__init__("The entrypoint workflow yaml is invalid.")
        # Kept on the instance so callers can read the individual issues
        # without relying on the exception's positional args.
        self.issues = issues


# User Errors
class NoCurrentUserError(DioptraError):
"""There is no currently logged-in user."""
Expand Down Expand Up @@ -475,3 +485,12 @@ def handle_mlflow_error(error: MLFlowError):
def handle_base_error(error: DioptraError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.BAD_REQUEST, {})

# Registered for the specific error class: registering this handler under
# DioptraError would apply the issue-formatting logic below to every
# DioptraError, most of which carry no validation issues.
# NOTE(review): flask-restx appears to pick the first registered handler whose
# class matches via isinstance -- confirm this registration is placed before
# the generic DioptraError handler so it is not shadowed.
@api.errorhandler(EntrpointWorkflowYamlValidationError)
def handle_entrypoint_workflow_yaml_validation_error(
    error: EntrpointWorkflowYamlValidationError,
):
    """Turn an entrypoint workflow yaml validation error into a 422 response.

    Args:
        error: The raised validation error; its ``issues`` attribute holds the
            individual validation issues.

    Returns:
        The result of ``error_result`` with status UNPROCESSABLE_ENTITY and an
        "issues" list containing the type, severity and message of each issue.
    """
    log.debug(error.to_message())
    # Read the issues from the explicit attribute set in the exception's
    # __init__ rather than from the positional ``args`` tuple.
    issues = [
        {
            "type": str(issue.type),
            "severity": str(issue.severity),
            "message": issue.message,
        }
        for issue in error.issues
    ]
    return error_result(
        error,
        http.HTTPStatus.UNPROCESSABLE_ENTITY,
        {"issues": issues},
    )
60 changes: 60 additions & 0 deletions src/dioptra/restapi/v1/entrypoints/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from dioptra.restapi.v1.plugins.service import PluginIdsService
from dioptra.restapi.v1.queues.service import RESOURCE_TYPE as QUEUE_RESOURCE_TYPE
from dioptra.restapi.v1.queues.service import QueueIdsService
from dioptra.restapi.v1.shared.entrypoint_validate_service import (
EntrypointValidateService,
)
from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand All @@ -66,6 +69,7 @@ class EntrypointService(object):
def __init__(
self,
entrypoint_name_service: EntrypointNameService,
entrypoint_validate_service: EntrypointValidateService,
plugin_ids_service: PluginIdsService,
queue_ids_service: QueueIdsService,
group_id_service: GroupIdService,
Expand All @@ -81,6 +85,7 @@ def __init__(
group_id_service: A GroupIdService object.
"""
self._entrypoint_name_service = entrypoint_name_service
self._entrypoint_validate_service = entrypoint_validate_service
self._plugin_ids_service = plugin_ids_service
self._queue_ids_service = queue_ids_service
self._group_id_service = group_id_service
Expand Down Expand Up @@ -112,6 +117,8 @@ def create(
Raises:
EntityExistsError: If a entrypoint with the given name already exists.
"""
# from dioptra.restapi.v1.workflows.service import EntrypointValidateService

log: BoundLogger = kwargs.get("log", LOGGER.new())

duplicate = self._entrypoint_name_service.get(name, group_id=group_id, log=log)
Expand All @@ -136,6 +143,12 @@ def create(
for i, param in enumerate(parameters)
]

self._entrypoint_validate_service.validate(
task_graph=task_graph,
plugin_ids=plugin_ids,
entrypoint_parameters=parameters,
)

new_entrypoint = models.EntryPoint(
name=name,
description=description,
Expand Down Expand Up @@ -311,6 +324,7 @@ class EntrypointIdService(object):
def __init__(
self,
entrypoint_name_service: EntrypointNameService,
entrypoint_validate_service: EntrypointValidateService,
queue_ids_service: QueueIdsService,
) -> None:
"""Initialize the entrypoint service.
Expand All @@ -322,6 +336,7 @@ def __init__(
queue_ids_service: A QueueIdsService object.
"""
self._entrypoint_name_service = entrypoint_name_service
self._entrypoint_validate_service = entrypoint_validate_service
self._queue_ids_service = queue_ids_service

def get(
Expand Down Expand Up @@ -459,6 +474,13 @@ def modify(
for i, param in enumerate(parameters)
]

plugin_ids = [entrypoint_plugin_file.plugin.resource_id for entrypoint_plugin_file in entrypoint.entry_point_plugin_files]
self._entrypoint_validate_service.validate(
task_graph=task_graph,
plugin_ids=plugin_ids,
entrypoint_parameters=parameters,
)

new_entrypoint = models.EntryPoint(
name=name,
description=description,
Expand Down Expand Up @@ -537,6 +559,7 @@ class EntrypointIdPluginsService(object):
def __init__(
self,
entrypoint_id_service: EntrypointIdService,
entrypoint_validate_service: EntrypointValidateService,
plugin_ids_service: PluginIdsService,
queue_ids_service: QueueIdsService,
) -> None:
Expand All @@ -550,6 +573,7 @@ def __init__(
queue_ids_service: A QueueIdsService object.
"""
self._entrypoint_id_service = entrypoint_id_service
self._entrypoint_validate_service = entrypoint_validate_service
self._plugin_ids_service = plugin_ids_service
self._queue_ids_service = queue_ids_service

Expand Down Expand Up @@ -621,6 +645,23 @@ def append(
)
for param in entrypoint.parameters
]

parameters = [
{
"name": param.name,
"default_value": param.default_value,
"parameter_type": param.parameter_type,
"parameter_number": param.parameter_number,
}
for param in entrypoint.parameters
]
plugin_ids = [entrypoint_plugin_file.plugin.resource_id for entrypoint_plugin_file in entrypoint.entry_point_plugin_files]
self._entrypoint_validate_service.validate(
task_graph=entrypoint.task_graph,
plugin_ids=plugin_ids,
entrypoint_parameters=parameters,
)

new_entrypoint = models.EntryPoint(
name=entrypoint.name,
description=entrypoint.description,
Expand Down Expand Up @@ -687,6 +728,7 @@ class EntrypointIdPluginsIdService(object):
def __init__(
self,
entrypoint_id_service: EntrypointIdService,
entrypoint_validate_service: EntrypointValidateService,
queue_ids_service: QueueIdsService,
) -> None:
"""Initialize the entrypoint service.
Expand All @@ -698,6 +740,7 @@ def __init__(
queue_ids_service: A QueueIdsService object.
"""
self._entrypoint_id_service = entrypoint_id_service
self._entrypoint_validate_service = entrypoint_validate_service
self._queue_ids_service = queue_ids_service

def get(
Expand Down Expand Up @@ -792,6 +835,23 @@ def delete(
)
for param in entrypoint.parameters
]

parameters = [
{
"name": param.name,
"default_value": param.default_value,
"parameter_type": param.parameter_type,
"parameter_number": param.parameter_number,
}
for param in entrypoint.parameters
]
plugin_ids = [entrypoint_plugin_file.plugin.resource_id for entrypoint_plugin_file in entrypoint.entry_point_plugin_files]
self._entrypoint_validate_service.validate(
task_graph=entrypoint.task_graph,
plugin_ids=plugin_ids,
entrypoint_parameters=parameters,
)

new_entrypoint = models.EntryPoint(
name=entrypoint.name,
description=entrypoint.description,
Expand Down
117 changes: 117 additions & 0 deletions src/dioptra/restapi/v1/shared/build_task_engine_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from pathlib import Path
from typing import Any, cast

import structlog
import yaml
from structlog.stdlib import BoundLogger

from dioptra.restapi.db import models
from dioptra.task_engine.type_registry import BUILTIN_TYPES

# from .type_coercions import (
# BOOLEAN_PARAM_TYPE,
# FLOAT_PARAM_TYPE,
# INTEGER_PARAM_TYPE,
# STRING_PARAM_TYPE,
# coerce_to_type,
# )

LOGGER: BoundLogger = structlog.stdlib.get_logger()

# EXPLICIT_GLOBAL_TYPES: Final[set[str]] = {
# STRING_PARAM_TYPE,
# BOOLEAN_PARAM_TYPE,
# INTEGER_PARAM_TYPE,
# FLOAT_PARAM_TYPE,
# }
# YAML_FILE_ENCODING: Final[str] = "utf-8"
# YAML_EXPORT_SETTINGS: Final[dict[str, Any]] = {
# "indent": 2,
# "sort_keys": False,
# "encoding": YAML_FILE_ENCODING,
# }


def build_task_engine_dict(
    plugins: list[Any],
    parameters: dict[str, Any],
    task_graph: str,
) -> dict[str, Any]:
    """Assemble the dictionary form of a task engine YAML document.

    Args:
        plugins: The entrypoint's plugins. Each item is indexed with the keys
            "plugin" and "plugin_files"; every plugin file exposes its tasks
            through a ``tasks`` attribute.
        parameters: The entrypoint parameters, stored unchanged under the
            "parameters" key.
        task_graph: The entrypoint task graph as YAML text.

    Returns:
        A dictionary with the keys "types", "parameters", "tasks" and "graph".
    """
    task_defs: dict[str, Any] = {}
    custom_types: dict[str, Any] = {}

    for plugin_entry in plugins:
        for plugin_file in plugin_entry["plugin_files"]:
            for task in plugin_file.tasks:
                task_inputs = task.input_parameters
                task_outputs = task.output_parameters

                task_def: dict[str, Any] = {
                    "plugin": _build_plugin_field(
                        plugin_entry["plugin"], plugin_file, task
                    ),
                }
                task_defs[task.plugin_task_name] = task_def

                # "inputs"/"outputs" are only emitted when the task has any.
                if task_inputs:
                    task_def["inputs"] = _build_task_inputs(task_inputs)
                if task_outputs:
                    task_def["outputs"] = _build_task_outputs(task_outputs)

                # Record the structure of every type that is not built in.
                for task_param in task_inputs + task_outputs:
                    param_type = task_param.parameter_type
                    if param_type.name not in BUILTIN_TYPES:
                        custom_types[param_type.name] = param_type.structure

    graph = cast(dict[str, Any], yaml.safe_load(task_graph))
    return {
        "types": custom_types,
        "parameters": parameters,
        "tasks": task_defs,
        "graph": graph,
    }


def _build_plugin_field(
plugin: models.Plugin, plugin_file: models.PluginFile, task: models.PluginTask
) -> str:
if plugin_file.filename == "__init__.py":
# Omit filename from plugin import path if it is an __init__.py file.
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts[:-1]]

else:
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts]

return ".".join([plugin.name, *module_parts, task.plugin_task_name])


def _build_task_inputs(
input_parameters: list[models.PluginTaskInputParameter],
) -> list[dict[str, Any]]:
return [
{
"name": input_param.name,
"type": input_param.parameter_type.name,
"required": input_param.required,
}
for input_param in input_parameters
]


def _build_task_outputs(
output_parameters: list[models.PluginTaskOutputParameter],
) -> list[dict[str, Any]] | dict[str, Any]:
if len(output_parameters) == 1:
return {output_parameters[0].name: output_parameters[0].parameter_type.name}

return [
{output_param.name: output_param.parameter_type.name}
for output_param in output_parameters
]
Loading
Loading