diff --git a/validity/scripts/keeper.py b/validity/scripts/keeper.py index a5b7405..3b7d410 100644 --- a/validity/scripts/keeper.py +++ b/validity/scripts/keeper.py @@ -16,8 +16,9 @@ class JobKeeper: """ job: Job - error_callbacks: list[Callable[[Job], None]] = field(default_factory=list) + error_callback: Callable[["JobKeeper", Exception], None] = lambda *_: None # noqa: E731 logger: Logger = field(default_factory=lambda: di["Logger"]) + auto_terminate: bool = True def __enter__(self): self.job.start() @@ -26,16 +27,11 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): if exc_type: self.terminate_errored_job(exc) - elif self.job.status == JobStatusChoices.STATUS_RUNNING: + elif self.job.status == JobStatusChoices.STATUS_RUNNING and self.auto_terminate: self.terminate_job() self.logger.flush() - def _exec_callbacks(self): - for callback in self.error_callbacks: - callback(self) - def terminate_errored_job(self, error: Exception) -> None: - self._exec_callbacks() if isinstance(error, AbortScript): self.logger.messages.extend(error.logs) self.logger.failure(str(error)) @@ -43,6 +39,7 @@ def terminate_errored_job(self, error: Exception) -> None: else: self.logger.log_exception(error) status = JobStatusChoices.STATUS_ERRORED + self.error_callback(self, error) self.terminate_job(status=status, error=repr(error)) def terminate_job( diff --git a/validity/scripts/runtests/combine.py b/validity/scripts/runtests/combine.py index 0066caa..278639e 100644 --- a/validity/scripts/runtests/combine.py +++ b/validity/scripts/runtests/combine.py @@ -80,12 +80,12 @@ def abort_if_apply_errors(self, job_extractor: JobExtractor) -> None: raise AbortScript("ApplyWorkerError", status=JobStatusChoices.STATUS_ERRORED, logs=error_logs) def get_job_keeper(self, job: Job) -> JobKeeper: - def error_callback(keeper): + def error_callback(keeper, error): keeper.logger.info("Database changes have been reverted") self.testresult_queryset.filter(report_id=keeper.job.object_id).raw_delete() keeper = self.jobkeeper_factory(job) - keeper.error_callbacks = [error_callback] + keeper.error_callback = error_callback return keeper def __call__(self, params: FullRunTestsParams) -> Any: diff --git a/validity/scripts/runtests/split.py b/validity/scripts/runtests/split.py index 14d3c91..7e5bb02 100644 --- a/validity/scripts/runtests/split.py +++ b/validity/scripts/runtests/split.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from functools import partial from itertools import chain, cycle, groupby, repeat from typing import Any, Callable, Collection, Iterable, Protocol @@ -25,7 +26,7 @@ def __call__( @di.dependency(scope=Singleton) @dataclass(repr=False) class SplitWorker: - jobkeeper_factory: Callable[[Job], JobKeeper] = JobKeeper + jobkeeper_factory: Callable[[Job], JobKeeper] = partial(JobKeeper, auto_terminate=False) datasource_sync_fn: Callable[[Iterable[VDataSource], Q], None] = datasource_sync backup_fn: BackupFn = bulk_backup device_batch_size: int = 2000 diff --git a/validity/tests/test_scripts/runtests/test_apply.py b/validity/tests/test_scripts/runtests/test_apply.py index b6b4612..a83d405 100644 --- a/validity/tests/test_scripts/runtests/test_apply.py +++ b/validity/tests/test_scripts/runtests/test_apply.py @@ -11,6 +11,7 @@ from validity.scripts.data_models import TestResultRatio as ResultRatio from validity.scripts.runtests.apply import ApplyWorker, DeviceTestIterator from validity.scripts.runtests.apply import TestExecutor as TExecutor +from validity.utils.logger import Logger NS_1 = """ @@ -48,7 +49,7 @@ def func3(): pass ) @pytest.mark.django_db def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls): - script = TExecutor(1, 2, 10) + script = TExecutor(Logger(), 2, 10) namesets = [NameSetDBFactory(definitions=d) for d in nameset_texts] functions = script.nameset_functions(namesets) assert extracted_fn_names == functions.keys() @@ -73,7 +74,7 @@ def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls): ) @pytest.mark.django_db def test_builtins_are_available_in_nameset(definitions): - script = TExecutor(10, 20, 30, extra_globals=DEFAULT_NAMESET) + script = TExecutor(Logger(), 20, 30, extra_globals=DEFAULT_NAMESET) namesets = [NameSetDBFactory(definitions=definitions)] functions = script.nameset_functions(namesets) functions["func"]() @@ -89,7 +90,7 @@ def test_run_tests_for_device(): tests[0].run.return_value = True, [] tests[1].run.return_value = False, [("some", "explanation")] tests[2].run.side_effect = EvalError("some test error") - executor = TExecutor(10, explanation_verbosity=2, report_id=30) + executor = TExecutor(Logger(), explanation_verbosity=2, report_id=30) results = [ { "passed": r.passed, @@ -132,6 +133,7 @@ def apply_worker(): job_extractor_factory = Mock() job_extractor_factory.return_value.parent.job.result.slices = [None, {1: [1, 2, 3]}] return ApplyWorker( + logger=Logger(), testresult_queryset=Mock(), test_executor_factory=Mock(return_value=executor), result_batch_size=100, @@ -160,8 +162,7 @@ def test_applyworker_success(full_runtests_params, apply_worker): @dataclass -class MockLogger: - script_id: str +class MockLogger(Logger): messages: list = field(default_factory=list, init=False) def log_exception(self, m): @@ -171,6 +172,6 @@ def log_exception(self, m): @pytest.mark.django_db def test_applyworker_exception(full_runtests_params, apply_worker): apply_worker.test_executor_factory = Mock(side_effect=ValueError("some error")) - apply_worker.logger_factory = MockLogger + apply_worker.logger = MockLogger() result = apply_worker(params=full_runtests_params, worker_id=1) assert result == ExecutionResult(test_stat=ResultRatio(passed=0, total=0), log=["some error"], errored=True) diff --git a/validity/tests/test_scripts/runtests/test_combine.py b/validity/tests/test_scripts/runtests/test_combine.py index 09d20c8..77d3ff3 100644 --- a/validity/tests/test_scripts/runtests/test_combine.py +++ b/validity/tests/test_scripts/runtests/test_combine.py @@ -9,6 +9,7 @@ from validity.scripts.data_models import TestResultRatio as ResultRatio from validity.scripts.exceptions import AbortScript from validity.scripts.runtests.combine import CombineWorker +from validity.utils.logger import Logger @pytest.fixture @@ -38,9 +39,8 @@ def job_extractor(messages): # but according to netbox4.0 strange behaviour reverse() finally causes it @pytest.mark.django_db def test_compose_logs(worker, messages, job_extractor): - logger = worker.log_factory() time = messages[0].time - logs = worker.compose_logs(logger, job_extractor, report_id=10) + logs = worker.compose_logs(Logger(), job_extractor, report_id=10) assert len(logs) == 6 assert logs[:5] == messages last_msg = replace(logs[-1], time=time) diff --git a/validity/tests/test_scripts/runtests/test_split.py b/validity/tests/test_scripts/runtests/test_split.py index 359d071..0a56b5c 100644 --- a/validity/tests/test_scripts/runtests/test_split.py +++ b/validity/tests/test_scripts/runtests/test_split.py @@ -10,6 +10,7 @@ from validity.scripts.data_models import Message, SplitResult from validity.scripts.runtests.split import SplitWorker +from validity.utils.logger import Logger @pytest.fixture @@ -50,9 +51,7 @@ def split_worker(): def test_distribute_work(split_worker, selectors, worker_num, runtests_params, expected_result, devices): runtests_params.workers_num = worker_num runtests_params.selectors = [s.pk for s in selectors] - result = split_worker.distribute_work( - runtests_params, split_worker.log_factory(), runtests_params.get_device_filter() - ) + result = split_worker.distribute_work(runtests_params, Logger(), runtests_params.get_device_filter()) assert result == expected_result @@ -103,13 +102,11 @@ def test_call(selectors, devices, runtests_params, monkeypatch): status="info", message="Running the tests for *2 devices*", time=datetime.datetime(2000, 1, 1, 0, 0), - script_id=None, ), Message( status="info", message="Distributing the work among 2 workers. Each worker handles 1 device(s) in average", time=datetime.datetime(2000, 1, 1, 0, 0), - script_id=None, ), ], slices=[{1: [1]}, {2: [2]}], diff --git a/validity/tests/test_scripts/test_launcher.py b/validity/tests/test_scripts/test_launcher.py index dddf39d..a52c1e4 100644 --- a/validity/tests/test_scripts/test_launcher.py +++ b/validity/tests/test_scripts/test_launcher.py @@ -1,5 +1,4 @@ import uuid -from dataclasses import asdict from unittest.mock import Mock import pytest @@ -13,8 +12,8 @@ class ConcreteScriptParams(ScriptParams): - def with_job_info(self, job: Job): - return FullParams(**asdict(self) | {"job": job}) + def _full_cls(self): + return FullParams class FullParams: @@ -28,7 +27,13 @@ def __eq__(self, other): @pytest.fixture def launcher(db): report = ComplianceReport.objects.create() - return Launcher(job_name="test_launcher", job_object_factory=lambda: report, rq_queue=Mock(), tasks=[]) + return Launcher( + job_name="test_launcher", + job_object_factory=lambda _: report, + rq_queue=Mock(), + tasks=[], + worker_count_fn=lambda _: 3, + ) @pytest.fixture @@ -45,7 +50,7 @@ def task_func(): ... params.schedule_at = schedule_at launcher.tasks = [Task(task_func, job_timeout=60)] job = launcher(params) - assert isinstance(job, Job) and job.object == launcher.job_object_factory() + assert isinstance(job, Job) and job.object == launcher.job_object_factory(None) enqueue_fn = getattr(launcher.rq_queue, "enqueue_at" if schedule_at else "enqueue") enqueue_fn.assert_called_once() enqueue_kwargs = enqueue_fn.call_args.kwargs diff --git a/validity/tests/test_scripts/test_logger.py b/validity/tests/test_utils/test_logger.py similarity index 88% rename from validity/tests/test_scripts/test_logger.py rename to validity/tests/test_utils/test_logger.py index 4c6bb5e..41ebd00 100644 --- a/validity/tests/test_scripts/test_logger.py +++ b/validity/tests/test_utils/test_logger.py @@ -1,6 +1,6 @@ import pytest -from validity.scripts.logger import Logger +from validity.utils.logger import Logger @pytest.fixture @@ -29,7 +29,7 @@ def test_logger(error_with_traceback): "status": "failure", "message": ( "Unhandled error occured: `: error`\n```\n " - 'File "/plugin/validity/validity/tests/test_scripts/test_logger.py", ' + f'File "{__file__}", ' """line 10, in error_with_traceback\n raise ValueError("error")\n\n```""" ), }, diff --git a/validity/tests/test_views.py b/validity/tests/test_views.py index aa981a0..faef344 100644 --- a/validity/tests/test_views.py +++ b/validity/tests/test_views.py @@ -218,18 +218,16 @@ def test_get(self, admin_client): assert resp.status_code == HTTPStatus.OK @pytest.mark.parametrize( - "form_data, status_code, worker_count", + "form_data, status_code, has_workers", [ - ({}, HTTPStatus.FOUND, 1), - ({}, HTTPStatus.OK, 0), - ({"devices": [1, 2]}, HTTPStatus.OK, 1), # devices do not exist + ({}, HTTPStatus.FOUND, True), + ({}, HTTPStatus.OK, False), + ({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist ], ) - def test_post(self, admin_client, di, form_data, status_code, worker_count): - launcher = Mock(**{"rq_queue.name": "queue_1", "return_value.pk": 1}) - with di.override( - {dependencies.runtests_launcher: lambda: launcher, dependencies.runtests_worker_count: lambda: worker_count} - ): + def test_post(self, admin_client, di, form_data, status_code, has_workers): + launcher = Mock(**{"has_workers": has_workers, "return_value.pk": 1}) + with di.override({dependencies.runtests_launcher: lambda: launcher}): result = admin_client.post(self.url, form_data) assert result.status_code == status_code if status_code == HTTPStatus.FOUND: # if form is valid