Skip to content

Commit

Permalink
old tests fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Jan 11, 2025
1 parent f12ba55 commit 2f7fedd
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 39 deletions.
11 changes: 4 additions & 7 deletions validity/scripts/keeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class JobKeeper:
"""

job: Job
error_callbacks: list[Callable[[Job], None]] = field(default_factory=list)
error_callback: Callable[["JobKeeper", Exception], None] = lambda *_: None # noqa: E731
logger: Logger = field(default_factory=lambda: di["Logger"])
auto_terminate: bool = True

def __enter__(self):
self.job.start()
Expand All @@ -26,23 +27,19 @@ def __enter__(self):
def __exit__(self, exc_type, exc, tb):
if exc_type:
self.terminate_errored_job(exc)
elif self.job.status == JobStatusChoices.STATUS_RUNNING:
elif self.job.status == JobStatusChoices.STATUS_RUNNING and self.auto_terminate:
self.terminate_job()
self.logger.flush()

def _exec_callbacks(self):
for callback in self.error_callbacks:
callback(self)

def terminate_errored_job(self, error: Exception) -> None:
self._exec_callbacks()
if isinstance(error, AbortScript):
self.logger.messages.extend(error.logs)
self.logger.failure(str(error))
status = error.status
else:
self.logger.log_exception(error)
status = JobStatusChoices.STATUS_ERRORED
self.error_callback(self, error)
self.terminate_job(status=status, error=repr(error))

def terminate_job(
Expand Down
4 changes: 2 additions & 2 deletions validity/scripts/runtests/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def abort_if_apply_errors(self, job_extractor: JobExtractor) -> None:
raise AbortScript("ApplyWorkerError", status=JobStatusChoices.STATUS_ERRORED, logs=error_logs)

def get_job_keeper(self, job: Job) -> JobKeeper:
def error_callback(keeper):
def error_callback(keeper, error):
keeper.logger.info("Database changes have been reverted")
self.testresult_queryset.filter(report_id=keeper.job.object_id).raw_delete()

keeper = self.jobkeeper_factory(job)
keeper.error_callbacks = [error_callback]
keeper.error_callback = error_callback
return keeper

def __call__(self, params: FullRunTestsParams) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion validity/scripts/runtests/split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from functools import partial
from itertools import chain, cycle, groupby, repeat
from typing import Any, Callable, Collection, Iterable, Protocol

Expand All @@ -25,7 +26,7 @@ def __call__(
@di.dependency(scope=Singleton)
@dataclass(repr=False)
class SplitWorker:
jobkeeper_factory: Callable[[Job], JobKeeper] = JobKeeper
jobkeeper_factory: Callable[[Job], JobKeeper] = partial(JobKeeper, auto_terminate=False)
datasource_sync_fn: Callable[[Iterable[VDataSource], Q], None] = datasource_sync
backup_fn: BackupFn = bulk_backup
device_batch_size: int = 2000
Expand Down
13 changes: 7 additions & 6 deletions validity/tests/test_scripts/runtests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from validity.scripts.data_models import TestResultRatio as ResultRatio
from validity.scripts.runtests.apply import ApplyWorker, DeviceTestIterator
from validity.scripts.runtests.apply import TestExecutor as TExecutor
from validity.utils.logger import Logger


NS_1 = """
Expand Down Expand Up @@ -48,7 +49,7 @@ def func3(): pass
)
@pytest.mark.django_db
def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls):
script = TExecutor(1, 2, 10)
script = TExecutor(Logger(), 2, 10)
namesets = [NameSetDBFactory(definitions=d) for d in nameset_texts]
functions = script.nameset_functions(namesets)
assert extracted_fn_names == functions.keys()
Expand All @@ -73,7 +74,7 @@ def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls):
)
@pytest.mark.django_db
def test_builtins_are_available_in_nameset(definitions):
script = TExecutor(10, 20, 30, extra_globals=DEFAULT_NAMESET)
script = TExecutor(Logger(), 20, 30, extra_globals=DEFAULT_NAMESET)
namesets = [NameSetDBFactory(definitions=definitions)]
functions = script.nameset_functions(namesets)
functions["func"]()
Expand All @@ -89,7 +90,7 @@ def test_run_tests_for_device():
tests[0].run.return_value = True, []
tests[1].run.return_value = False, [("some", "explanation")]
tests[2].run.side_effect = EvalError("some test error")
executor = TExecutor(10, explanation_verbosity=2, report_id=30)
executor = TExecutor(Logger(), explanation_verbosity=2, report_id=30)
results = [
{
"passed": r.passed,
Expand Down Expand Up @@ -132,6 +133,7 @@ def apply_worker():
job_extractor_factory = Mock()
job_extractor_factory.return_value.parent.job.result.slices = [None, {1: [1, 2, 3]}]
return ApplyWorker(
logger=Logger(),
testresult_queryset=Mock(),
test_executor_factory=Mock(return_value=executor),
result_batch_size=100,
Expand Down Expand Up @@ -160,8 +162,7 @@ def test_applyworker_success(full_runtests_params, apply_worker):


@dataclass
class MockLogger:
script_id: str
class MockLogger(Logger):
messages: list = field(default_factory=list, init=False)

def log_exception(self, m):
Expand All @@ -171,6 +172,6 @@ def log_exception(self, m):
@pytest.mark.django_db
def test_applyworker_exception(full_runtests_params, apply_worker):
apply_worker.test_executor_factory = Mock(side_effect=ValueError("some error"))
apply_worker.logger_factory = MockLogger
apply_worker.logger = MockLogger()
result = apply_worker(params=full_runtests_params, worker_id=1)
assert result == ExecutionResult(test_stat=ResultRatio(passed=0, total=0), log=["some error"], errored=True)
4 changes: 2 additions & 2 deletions validity/tests/test_scripts/runtests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from validity.scripts.data_models import TestResultRatio as ResultRatio
from validity.scripts.exceptions import AbortScript
from validity.scripts.runtests.combine import CombineWorker
from validity.utils.logger import Logger


@pytest.fixture
Expand Down Expand Up @@ -38,9 +39,8 @@ def job_extractor(messages):
# but according to netbox4.0 strange behaviour reverse() finally causes it
@pytest.mark.django_db
def test_compose_logs(worker, messages, job_extractor):
logger = worker.log_factory()
time = messages[0].time
logs = worker.compose_logs(logger, job_extractor, report_id=10)
logs = worker.compose_logs(Logger(), job_extractor, report_id=10)
assert len(logs) == 6
assert logs[:5] == messages
last_msg = replace(logs[-1], time=time)
Expand Down
7 changes: 2 additions & 5 deletions validity/tests/test_scripts/runtests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from validity.scripts.data_models import Message, SplitResult
from validity.scripts.runtests.split import SplitWorker
from validity.utils.logger import Logger


@pytest.fixture
Expand Down Expand Up @@ -50,9 +51,7 @@ def split_worker():
def test_distribute_work(split_worker, selectors, worker_num, runtests_params, expected_result, devices):
runtests_params.workers_num = worker_num
runtests_params.selectors = [s.pk for s in selectors]
result = split_worker.distribute_work(
runtests_params, split_worker.log_factory(), runtests_params.get_device_filter()
)
result = split_worker.distribute_work(runtests_params, Logger(), runtests_params.get_device_filter())
assert result == expected_result


Expand Down Expand Up @@ -103,13 +102,11 @@ def test_call(selectors, devices, runtests_params, monkeypatch):
status="info",
message="Running the tests for *2 devices*",
time=datetime.datetime(2000, 1, 1, 0, 0),
script_id=None,
),
Message(
status="info",
message="Distributing the work among 2 workers. Each worker handles 1 device(s) in average",
time=datetime.datetime(2000, 1, 1, 0, 0),
script_id=None,
),
],
slices=[{1: [1]}, {2: [2]}],
Expand Down
15 changes: 10 additions & 5 deletions validity/tests/test_scripts/test_launcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import uuid
from dataclasses import asdict
from unittest.mock import Mock

import pytest
Expand All @@ -13,8 +12,8 @@


class ConcreteScriptParams(ScriptParams):
def with_job_info(self, job: Job):
return FullParams(**asdict(self) | {"job": job})
def _full_cls(self):
return FullParams


class FullParams:
Expand All @@ -28,7 +27,13 @@ def __eq__(self, other):
@pytest.fixture
def launcher(db):
report = ComplianceReport.objects.create()
return Launcher(job_name="test_launcher", job_object_factory=lambda: report, rq_queue=Mock(), tasks=[])
return Launcher(
job_name="test_launcher",
job_object_factory=lambda _: report,
rq_queue=Mock(),
tasks=[],
worker_count_fn=lambda _: 3,
)


@pytest.fixture
Expand All @@ -45,7 +50,7 @@ def task_func(): ...
params.schedule_at = schedule_at
launcher.tasks = [Task(task_func, job_timeout=60)]
job = launcher(params)
assert isinstance(job, Job) and job.object == launcher.job_object_factory()
assert isinstance(job, Job) and job.object == launcher.job_object_factory(None)
enqueue_fn = getattr(launcher.rq_queue, "enqueue_at" if schedule_at else "enqueue")
enqueue_fn.assert_called_once()
enqueue_kwargs = enqueue_fn.call_args.kwargs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from validity.scripts.logger import Logger
from validity.utils.logger import Logger


@pytest.fixture
Expand Down Expand Up @@ -29,7 +29,7 @@ def test_logger(error_with_traceback):
"status": "failure",
"message": (
"Unhandled error occured: `<class 'ValueError'>: error`\n```\n "
'File "/plugin/validity/validity/tests/test_scripts/test_logger.py", '
f'File "{__file__}", '
"""line 10, in error_with_traceback\n raise ValueError("error")\n\n```"""
),
},
Expand Down
16 changes: 7 additions & 9 deletions validity/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,16 @@ def test_get(self, admin_client):
assert resp.status_code == HTTPStatus.OK

@pytest.mark.parametrize(
"form_data, status_code, worker_count",
"form_data, status_code, has_workers",
[
({}, HTTPStatus.FOUND, 1),
({}, HTTPStatus.OK, 0),
({"devices": [1, 2]}, HTTPStatus.OK, 1), # devices do not exist
({}, HTTPStatus.FOUND, True),
({}, HTTPStatus.OK, False),
({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist
],
)
def test_post(self, admin_client, di, form_data, status_code, worker_count):
launcher = Mock(**{"rq_queue.name": "queue_1", "return_value.pk": 1})
with di.override(
{dependencies.runtests_launcher: lambda: launcher, dependencies.runtests_worker_count: lambda: worker_count}
):
def test_post(self, admin_client, di, form_data, status_code, has_workers):
launcher = Mock(**{"has_workers": has_workers, "return_value.pk": 1})
with di.override({dependencies.runtests_launcher: lambda: launcher}):
result = admin_client.post(self.url, form_data)
assert result.status_code == status_code
if status_code == HTTPStatus.FOUND: # if form is valid
Expand Down

0 comments on commit 2f7fedd

Please sign in to comment.