From 8af6b644b38c69429fcd0bec7456220be4f1572d Mon Sep 17 00:00:00 2001
From: Anton
Date: Sun, 8 Sep 2024 22:20:04 +0200
Subject: [PATCH] Remove NB scripts, introduce distributed RunTests (#102)

* all steps and launcher are ready
* introduce dimi, refactor dependencies
* wip introducing dimi
* introduce dimi
* everything except template is ready
* gui for running tests
* remove two-phase and rollback
* run tests button
* almost working
* fix work splitup
* fully working runtests
* report table and filterset
* remove old scripts
* delete scripts migration
* api for running tests
* fix log debug
* netbox version compatibility
* final tests
* disable fail fast
* add system levels
* backports to support 3.7
* stop running tests for 3.6
* page-header -> header
* remove unneeded template parts
* adjust runtests template
---
 .github/workflows/ci.yml | 6 +-
 development/.env.example | 1 +
 development/Dockerfile | 1 +
 development/docker-compose.yaml | 12 +-
 development/start.sh | 11 +
 pyproject.toml | 1 +
 requirements/base.txt | 1 +
 validity/__init__.py | 15 +-
 validity/api/helpers.py | 11 +
 validity/api/serializers.py | 43 +++-
 validity/api/views.py | 52 ++++-
 validity/dependencies.py | 64 ++++++
 validity/filtersets.py | 13 +-
 validity/forms/__init__.py | 2 +
 validity/forms/fields.py | 30 +++
 validity/forms/filterset.py | 18 +-
 validity/forms/general.py | 73 ++++++-
 validity/managers.py | 42 ++--
 validity/migrations/0010_squashed_initial.py | 2 +-
 validity/migrations/0011_delete_scripts.py | 24 +++
 validity/models/polling.py | 13 +-
 validity/models/report.py | 5 +
 validity/models/selector.py | 3 +
 validity/models/test.py | 3 +
 validity/netbox_changes/current.py | 1 +
 validity/netbox_changes/old.py | 1 +
 validity/netbox_changes/oldest.py | 2 +
 validity/pollers/__init__.py | 3 +-
 validity/pollers/factory.py | 27 +--
 validity/scripts/__init__.py | 4 +
 validity/scripts/data_models.py | 160 ++++++++++++++
 validity/scripts/exceptions.py | 12 ++
 validity/scripts/install/validity_scripts.py | 10 -
 validity/scripts/launch.py | 62 ++++++
 validity/scripts/logger.py | 47 +++++
 validity/scripts/parent_jobs.py | 38 ++++
 validity/scripts/run_tests.py | 193 -----------------
 validity/scripts/runtests/__init__.py | 3 +
 validity/scripts/runtests/apply.py | 169 +++++++++++++++
 validity/scripts/runtests/base.py | 43 ++++
 validity/scripts/runtests/combine.py | 93 ++++++++
 validity/scripts/runtests/split.py | 104 +++++++++
 validity/scripts/script_data.py | 128 -----------
 validity/scripts/variables.py | 13 --
 validity/settings.py | 30 +++
 validity/signals.py | 9 +
 validity/tables.py | 31 ++-
 validity/template_content.py | 12 +-
 .../templates/validity/compliancereport.html | 12 ++
 .../validity/compliancetestresult.html | 4 +
 validity/templates/validity/inc/fieldset.html | 11 +
 .../templates/validity/scripts/result.html | 43 ++++
 .../validity/scripts/result_htmx.html | 36 ++++
 validity/templates/validity/scripts/run.html | 44 ++++
 validity/templatetags/validity.py | 17 ++
 validity/tests/conftest.py | 28 +--
 validity/tests/factories.py | 29 +++
 validity/tests/test_api.py | 20 ++
 validity/tests/test_managers.py | 60 +++---
 validity/tests/test_scripts/conftest.py | 17 +-
 .../tests/test_scripts/runtests/test_apply.py | 175 +++++++++++++++
 .../test_scripts/runtests/test_combine.py | 96 +++++++++
 .../tests/test_scripts/runtests/test_split.py | 112 ++++++++++
 .../tests/test_scripts/test_data_models.py | 83 ++++++++
 validity/tests/test_scripts/test_launcher.py | 68 ++++++
 validity/tests/test_scripts/test_logger.py | 36 ++++
 .../tests/test_scripts/test_parent_jobs.py | 32 +++
 validity/tests/test_scripts/test_run_tests.py | 199 ------------------
 .../tests/test_scripts/test_script_data.py | 33 ---
 validity/tests/test_views.py | 38 +++-
 validity/urls.py | 5 +-
 validity/utils/orm.py | 4 +
 validity/views/__init__.py | 2 +-
 validity/views/report.py | 15 +-
 validity/views/script.py | 90 ++++++
 validity/views/test.py | 12 +-
 76 files changed, 2223 insertions(+), 739 deletions(-)
 create mode 100755 development/start.sh
 create mode 100644 validity/dependencies.py
 create mode 100644 validity/migrations/0011_delete_scripts.py
 create mode 100644 validity/scripts/data_models.py
 create mode 100644 validity/scripts/exceptions.py
 delete mode 100644 validity/scripts/install/validity_scripts.py
 create mode 100644 validity/scripts/launch.py
 create mode 100644 validity/scripts/logger.py
 create mode 100644 validity/scripts/parent_jobs.py
 delete mode 100644 validity/scripts/run_tests.py
 create mode 100644 validity/scripts/runtests/__init__.py
 create mode 100644 validity/scripts/runtests/apply.py
 create mode 100644 validity/scripts/runtests/base.py
 create mode 100644 validity/scripts/runtests/combine.py
 create mode 100644 validity/scripts/runtests/split.py
 delete mode 100644 validity/scripts/script_data.py
 delete mode 100644 validity/scripts/variables.py
 create mode 100644 validity/settings.py
 create mode 100644 validity/signals.py
 create mode 100644 validity/templates/validity/inc/fieldset.html
 create mode 100644 validity/templates/validity/scripts/result.html
 create mode 100644 validity/templates/validity/scripts/result_htmx.html
 create mode 100644 validity/templates/validity/scripts/run.html
 create mode 100644 validity/tests/test_scripts/runtests/test_apply.py
 create mode 100644 validity/tests/test_scripts/runtests/test_combine.py
 create mode 100644 validity/tests/test_scripts/runtests/test_split.py
 create mode 100644 validity/tests/test_scripts/test_data_models.py
 create mode 100644 validity/tests/test_scripts/test_launcher.py
 create mode 100644 validity/tests/test_scripts/test_logger.py
 create mode 100644 validity/tests/test_scripts/test_parent_jobs.py
 delete mode 100644 validity/tests/test_scripts/test_run_tests.py
 delete mode 100644 validity/tests/test_scripts/test_script_data.py
 create mode 100644 validity/views/script.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ee55909..53d3c1f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,8 +21,9 @@ jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
-        netbox_version: [v3.6.9, v3.7.8, v4.0.7]
+        netbox_version: [v3.7.8, v4.0.11]
     steps:
       - name: Checkout
         uses: actions/checkout@v3
@@ -58,8 +59,9 @@
   test_migrations:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
-        netbox_version: [v3.6.9, v3.7.8, v4.0.2]
+        netbox_version: [v3.7.8, v4.0.11]
     steps:
       - name: Checkout
         uses: actions/checkout@v3
diff --git a/development/.env.example b/development/.env.example
index 4e95107..d1a458b 100644
--- a/development/.env.example
+++ b/development/.env.example
@@ -6,3 +6,4 @@ REDIS_PASSWORD=redis
 SECRET_KEY=SOME_ARBITRARY_LONG_ENOUGH_DJANGO_SECRET_KEY_STRING
 COMPOSE_PROJECT_NAME=validity
 DEBUGWEB=0
+DEBUGWORKER=0
diff --git a/development/Dockerfile b/development/Dockerfile
index cc1b7ea..dfbfeb1 100644
--- a/development/Dockerfile
+++ b/development/Dockerfile
@@ -16,6 +16,7 @@ RUN mkdir -p /opt/netbox \
 # Install Validity
 COPY . 
/plugin/validity +COPY ./development/start.sh /opt/netbox/netbox/ RUN pip install --editable /plugin/validity[dev] WORKDIR /opt/netbox/netbox/ diff --git a/development/docker-compose.yaml b/development/docker-compose.yaml index fc6967c..a83b80a 100644 --- a/development/docker-compose.yaml +++ b/development/docker-compose.yaml @@ -5,7 +5,9 @@ services: dockerfile: ./development/Dockerfile args: NETBOX_VERSION: ${NETBOX_VERSION} - command: sh -c "python manage.py rqworker" + command: ./start.sh $DEBUGWORKER manage.py rqworker + ports: + - "5679:5678" depends_on: - postgres - redis @@ -24,13 +26,7 @@ services: netbox: <<: *worker - command: > - bash -c " - if [[ $DEBUGWEB == 1 ]]; then - python -m debugpy --listen 0.0.0.0:5678 manage.py runserver 0.0.0.0:8000; - else - python manage.py runserver 0.0.0.0:8000; - fi" + command: ./start.sh $DEBUGWEB manage.py runserver 0.0.0.0:8000 ports: - "8000:8000" - "5678:5678" diff --git a/development/start.sh b/development/start.sh new file mode 100755 index 0000000..dd2c2a3 --- /dev/null +++ b/development/start.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +DEBUG=$1 +shift + +if [[ $DEBUG == 1 ]]; then + echo "!!! DEBUGGING IS ENABLED !!!" + python -m debugpy --listen 0.0.0.0:5678 $@ +else + python $@ +fi diff --git a/pyproject.toml b/pyproject.toml index 4bbdb90..cba5d5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ branch = true omit = [ "validity/tests/*", "validity/migrations/*", + "validity/dependencies.py", ] source = ["validity"] diff --git a/requirements/base.txt b/requirements/base.txt index 6efa658..545240e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,5 @@ deepdiff>=6.2.0,<7 +dimi < 2 django-bootstrap5 >=24.2,<25 dulwich # Core NetBox "optional" requirement jq>=1.4.0,<2 diff --git a/validity/__init__.py b/validity/__init__.py index d29b10a..f0e9b0f 100644 --- a/validity/__init__.py +++ b/validity/__init__.py @@ -1,8 +1,7 @@ import logging -from django.conf import settings as django_settings +from dimi import Container from netbox.settings import VERSION -from pydantic import BaseModel, Field from validity.utils.version import NetboxVersion @@ -31,7 +30,7 @@ class NetBoxValidityConfig(PluginConfig): netbox_version = NetboxVersion(VERSION) def ready(self): - import validity.data_backends + from validity import data_backends, dependencies, signals return super().ready() @@ -39,12 +38,4 @@ def ready(self): config = NetBoxValidityConfig -class ValiditySettings(BaseModel): - store_last_results: int = Field(default=5, gt=0, lt=1001) - store_reports: int = Field(default=5, gt=0, lt=1001) - sleep_between_tests: float = 0 - result_batch_size: int = Field(default=500, ge=1) - polling_threads: int = Field(default=500, ge=1) - - -settings = ValiditySettings.model_validate(django_settings.PLUGINS_CONFIG.get("validity", {})) +di = Container() diff --git a/validity/api/helpers.py b/validity/api/helpers.py index 564931a..7c9e798 100644 --- a/validity/api/helpers.py +++ b/validity/api/helpers.py @@ -5,6 +5,7 @@ from django.core.exceptions import ValidationError from django.db.models import ManyToManyField from netbox.api.serializers import WritableNestedSerializer +from rest_framework.relations import PrimaryKeyRelatedField from rest_framework.serializers import JSONField, ModelSerializer from validity import NetboxVersion @@ -94,3 +95,13 @@ def validate(self, attrs): ] raise ValidationError({instance.subform_json_field: errors}) return attrs + + +class PrimaryKeyField(PrimaryKeyRelatedField): + """ + Returns primary key 
only instead of the whole model instance + """ + + def to_internal_value(self, data): + obj = super().to_internal_value(data) + return obj.pk diff --git a/validity/api/serializers.py b/validity/api/serializers.py index d56c664..bb3a26e 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -1,4 +1,6 @@ from core.api.nested_serializers import NestedDataFileSerializer, NestedDataSourceSerializer +from core.api.serializers import JobSerializer +from core.models import DataSource from dcim.api.nested_serializers import ( NestedDeviceSerializer, NestedDeviceTypeSerializer, @@ -7,7 +9,9 @@ NestedPlatformSerializer, NestedSiteSerializer, ) -from dcim.models import DeviceType, Location, Manufacturer, Platform, Site +from dcim.models import Device, DeviceType, Location, Manufacturer, Platform, Site +from django.utils import timezone +from django.utils.translation import gettext_lazy as _ from extras.api.nested_serializers import NestedTagSerializer from extras.models import Tag from netbox.api.fields import SerializedPKRelatedField @@ -18,7 +22,15 @@ from tenancy.models import Tenant from validity import config, models -from .helpers import EncryptedDictField, FieldsMixin, ListQPMixin, SubformValidationMixin, nested_factory +from validity.choices import ExplanationVerbosityChoices +from .helpers import ( + EncryptedDictField, + FieldsMixin, + ListQPMixin, + PrimaryKeyField, + SubformValidationMixin, + nested_factory, +) class ComplianceSelectorSerializer(NetBoxModelSerializer): @@ -370,3 +382,30 @@ def to_representation(self, instance): if name_filter := self.get_list_param("name"): instance = [item for item in instance if item.name in set(name_filter)] return super().to_representation(instance) + + +class RunTestsSerializer(serializers.Serializer): + sync_datasources = serializers.BooleanField(required=False) + selectors = PrimaryKeyField( + many=True, + required=False, + queryset=models.ComplianceSelector.objects.all(), + ) + devices = PrimaryKeyField(many=True, required=False, queryset=Device.objects.all()) + test_tags = PrimaryKeyField(many=True, required=False, queryset=Tag.objects.all()) + explanation_verbosity = serializers.ChoiceField( + choices=ExplanationVerbosityChoices.choices, required=False, default=ExplanationVerbosityChoices.maximum + ) + overriding_datasource = PrimaryKeyField(required=False, queryset=DataSource.objects.all()) + workers_num = serializers.IntegerField(min_value=1, default=1) + schedule_at = serializers.DateTimeField(required=False, allow_null=True) + schedule_interval = serializers.IntegerField(required=False, allow_null=True) + + def validate_schedule_at(self, value): + if value and value < timezone.now(): + raise serializers.ValidationError(_("Scheduled time must be in the future.")) + return value + + +class ScriptResultSerializer(serializers.Serializer): + result = JobSerializer(read_only=True) diff --git a/validity/api/views.py b/validity/api/views.py index 22c0c2d..5e2251c 100644 --- a/validity/api/views.py +++ b/validity/api/views.py @@ -1,17 +1,39 @@ -from drf_spectacular.utils import OpenApiParameter, extend_schema -from netbox.api.viewsets import NetBoxModelViewSet +from http import HTTPStatus +from typing import Annotated, Any + +from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view +from netbox.api.viewsets import NetBoxModelViewSet, NetBoxReadOnlyModelViewSet +from rest_framework.decorators import action from rest_framework.exceptions import NotFound from rest_framework.generics import ListAPIView 
from rest_framework.response import Response +from rest_framework.serializers import Serializer from rest_framework.views import APIView -from validity import filtersets, models +from validity import di, filtersets, models from validity.choices import SeverityChoices +from validity.scripts import Launcher, RunTestsParams, ScriptParams from . import serializers -class ReadOnlyNetboxViewSet(NetBoxModelViewSet): - http_method_names = ["get", "head", "options", "trace"] +class RunMixin: + run_serializer_class: type[Serializer] + params_class: type[ScriptParams] + launcher: Launcher + + def get_params(self, serializer, request): + return self.params_class(**serializer.validated_data, request=request) + + def get_result_data(self, job, request): + serializer = serializers.ScriptResultSerializer({"result": job}, context={"request": request}) + return serializer.data + + def run(self, request): + serializer = self.run_serializer_class(data=request.data) + if not serializer.is_valid(): + return Response(status=HTTPStatus.BAD_REQUEST, data=serializer.errors) + job = self.launcher(self.get_params(serializer, request)) + return Response(self.get_result_data(job, request)) class ComplianceSelectorViewSet(NetBoxModelViewSet): @@ -29,15 +51,27 @@ class ComplianceSelectorViewSet(NetBoxModelViewSet): filterset_class = filtersets.ComplianceSelectorFilterSet -class ComplianceTestViewSet(NetBoxModelViewSet): +@extend_schema_view(run=extend_schema(request=serializers.RunTestsSerializer)) +class ComplianceTestViewSet(RunMixin, NetBoxModelViewSet): queryset = models.ComplianceTest.objects.select_related("data_source", "data_file").prefetch_related( "selectors", "tags" ) serializer_class = serializers.ComplianceTestSerializer filterset_class = filtersets.ComplianceTestFilterSet + run_serializer_class = serializers.RunTestsSerializer + params_class = RunTestsParams + + @di.inject + def __init__(self, launcher: Annotated[Launcher, "runtests_launcher"], **kwargs: Any) -> None: + self.launcher = launcher + super().__init__(**kwargs) + + @action(detail=False, methods=["post"], url_path="run") + def run(self, request): + return super().run(request) -class ComplianceTestResultViewSet(ReadOnlyNetboxViewSet): +class ComplianceTestResultViewSet(NetBoxReadOnlyModelViewSet): queryset = models.ComplianceTestResult.objects.select_related("device", "test", "report") serializer_class = serializers.ComplianceTestResultSerializer filterset_class = filtersets.ComplianceTestResultFilterSet @@ -55,10 +89,10 @@ class NameSetViewSet(NetBoxModelViewSet): filterset_class = filtersets.NameSetFilterSet -class ComplianceReportViewSet(NetBoxModelViewSet): +class ComplianceReportViewSet(NetBoxReadOnlyModelViewSet): queryset = models.ComplianceReport.objects.annotate_result_stats().count_devices_and_tests() serializer_class = serializers.ComplianceReportSerializer - http_method_names = ["get", "head", "options", "trace", "delete"] + filterset_class = filtersets.ComplianceReportFilterSet class PollerViewSet(NetBoxModelViewSet): diff --git a/validity/dependencies.py b/validity/dependencies.py new file mode 100644 index 0000000..71f1552 --- /dev/null +++ b/validity/dependencies.py @@ -0,0 +1,64 @@ +from typing import Annotated + +import django_rq +from dimi.scopes import Singleton +from django.conf import LazySettings, settings +from utilities.rqworker import get_workers_for_queue + +from validity import di +from validity.choices import ConnectionTypeChoices +from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller +from 
validity.settings import ValiditySettings +from validity.utils.misc import null_request + + +@di.dependency +def django_settings(): + return settings + + +@di.dependency(scope=Singleton) +def validity_settings(django_settings: Annotated[LazySettings, django_settings]): + return ValiditySettings.model_validate(django_settings.PLUGINS_CONFIG.get("validity", {})) + + +@di.dependency(scope=Singleton) +def poller_map(): + return { + ConnectionTypeChoices.netmiko: NetmikoPoller, + ConnectionTypeChoices.requests: RequestsPoller, + ConnectionTypeChoices.scrapli_netconf: ScrapliNetconfPoller, + } + + +from validity.scripts import ApplyWorker, CombineWorker, Launcher, SplitWorker, Task # noqa + + +@di.dependency +def runtests_worker_count(vsettings: Annotated[ValiditySettings, validity_settings]) -> int: + return get_workers_for_queue(vsettings.runtests_queue) + + +@di.dependency(scope=Singleton) +def runtests_launcher( + vsettings: Annotated[ValiditySettings, validity_settings], + split_worker: Annotated[SplitWorker, ...], + apply_worker: Annotated[ApplyWorker, ...], + combine_worker: Annotated[CombineWorker, ...], +): + from validity.models import ComplianceReport + + return Launcher( + job_name="RunTests", + job_object_factory=null_request()(ComplianceReport.objects.create), + rq_queue=django_rq.get_queue(vsettings.runtests_queue), + tasks=[ + Task(split_worker, job_timeout=vsettings.script_timeouts.runtests_split), + Task( + apply_worker, + job_timeout=vsettings.script_timeouts.runtests_apply, + multi_workers=True, + ), + Task(combine_worker, job_timeout=vsettings.script_timeouts.runtests_combine), + ], + ) diff --git a/validity/filtersets.py b/validity/filtersets.py index 182eac7..01a02eb 100644 --- a/validity/filtersets.py +++ b/validity/filtersets.py @@ -2,12 +2,14 @@ from functools import reduce from typing import Sequence +from core.choices import JobStatusChoices +from core.models import Job from dcim.filtersets import DeviceFilterSet from dcim.models import Device, DeviceRole, DeviceType, Location, Manufacturer, Platform, Site from django.db.models import Q from django_filters import BooleanFilter, ChoiceFilter, ModelMultipleChoiceFilter from extras.models import Tag -from netbox.filtersets import NetBoxModelFilterSet +from netbox.filtersets import ChangeLoggedModelFilterSet, NetBoxModelFilterSet from tenancy.models import Tenant from validity import models @@ -112,6 +114,15 @@ class DeviceReportFilterSet(DeviceFilterSet): compliance_passed = BooleanFilter() +class ComplianceReportFilterSet(ChangeLoggedModelFilterSet): + job_status = ChoiceFilter(field_name="jobs__status", choices=JobStatusChoices) + job_id = ModelMultipleChoiceFilter(field_name="jobs", queryset=Job.objects.all()) + + class Meta: + model = models.ComplianceReport + fields = ("id", "job_id", "job_status", "created") + + class PollerFilterSet(SearchMixin, NetBoxModelFilterSet): class Meta: model = models.Poller diff --git a/validity/forms/__init__.py b/validity/forms/__init__.py index a6a1f53..f67024c 100644 --- a/validity/forms/__init__.py +++ b/validity/forms/__init__.py @@ -1,5 +1,6 @@ from .filterset import ( CommandFilterForm, + ComplianceReportFilerForm, ComplianceSelectorFilterForm, ComplianceTestFilterForm, ComplianceTestResultFilterForm, @@ -18,5 +19,6 @@ ComplianceTestForm, NameSetForm, PollerForm, + RunTestsForm, SerializerForm, ) diff --git a/validity/forms/fields.py b/validity/forms/fields.py index b4f9065..3b78170 100644 --- a/validity/forms/fields.py +++ b/validity/forms/fields.py @@ -1,6 +1,9 @@ +import 
operator +from abc import ABC, abstractmethod from typing import Any from django.forms import ChoiceField, JSONField +from utilities.forms.fields import DynamicModelChoiceField, DynamicModelMultipleChoiceField from validity.fields import EncryptedDict from .widgets import SelectWithPlaceholder @@ -27,3 +30,30 @@ def __init__(self, *, placeholder: str | None = None, **kwargs) -> None: kwargs["choices"] = (("", placeholder),) + tuple(kwargs["choices"]) kwargs["widget"] = SelectWithPlaceholder() super().__init__(**kwargs) + + +class ModelPropertyMixin(ABC): + """ + Supplies model's field (property) instead of model itself + """ + + def __init__(self, *args, property_name: str = "pk", **kwargs): + super().__init__(*args, **kwargs) + self.property_name = property_name + + def clean(self, value): + val = super().clean(value) + return self.extract_property(val) if val is not None else None + + @abstractmethod + def extract_property(self, value): ... + + +class DynamicModelChoicePropertyField(ModelPropertyMixin, DynamicModelChoiceField): + def extract_property(self, value): + return operator.attrgetter(self.property_name)(value) + + +class DynamicModelMultipleChoicePropertyField(ModelPropertyMixin, DynamicModelMultipleChoiceField): + def extract_property(self, value): + return [operator.attrgetter(self.property_name)(item) for item in value] diff --git a/validity/forms/filterset.py b/validity/forms/filterset.py index fc83b2f..1be8a17 100644 --- a/validity/forms/filterset.py +++ b/validity/forms/filterset.py @@ -1,12 +1,14 @@ -from core.models import DataSource +from core.choices import JobStatusChoices +from core.models import DataSource, Job from dcim.models import Device, DeviceRole, DeviceType, Location, Manufacturer, Platform, Site -from django.forms import CharField, Form, NullBooleanField, Select +from django.forms import CharField, DateTimeField, Form, NullBooleanField, Select from django.utils.translation import gettext_lazy as _ from extras.models import Tag from netbox.forms import NetBoxModelFilterSetForm from tenancy.models import Tenant -from utilities.forms import BOOLEAN_WITH_BLANK_CHOICES +from utilities.forms import BOOLEAN_WITH_BLANK_CHOICES, FilterForm from utilities.forms.fields import DynamicModelMultipleChoiceField +from utilities.forms.widgets import DateTimePicker from validity import models from validity.choices import ( @@ -18,7 +20,7 @@ ExtractionMethodChoices, SeverityChoices, ) -from validity.netbox_changes import FieldSet +from validity.netbox_changes import FieldSet, SavedFiltersMixin from .fields import PlaceholderChoiceField from .mixins import AddM2MPlaceholderFormMixin, ExcludeMixin @@ -160,6 +162,14 @@ class ComplianceTestFilterForm(NetBoxModelFilterSetForm): ) +class ComplianceReportFilerForm(SavedFiltersMixin, FilterForm): + model = models.ComplianceReport + job_id = DynamicModelMultipleChoiceField(required=False, label=_("Job ID"), queryset=Job.objects.all()) + job_status = PlaceholderChoiceField(required=False, label=_("Job Status"), choices=JobStatusChoices) + created__lte = DateTimeField(required=False, widget=DateTimePicker(), label=_("Created Before")) + created__gte = DateTimeField(required=False, widget=DateTimePicker(), label=_("Created After")) + + class PollerFilterForm(NetBoxModelFilterSetForm): model = models.Poller name = CharField(required=False) diff --git a/validity/forms/general.py b/validity/forms/general.py index cd6e372..6e715d4 100644 --- a/validity/forms/general.py +++ b/validity/forms/general.py @@ -1,7 +1,9 @@ from core.forms.mixins 
import SyncedDataMixin -from dcim.models import DeviceType, Location, Manufacturer, Platform, Site -from django.forms import CharField, ChoiceField, Select, Textarea, ValidationError +from core.models import DataSource +from dcim.models import Device, DeviceType, Location, Manufacturer, Platform, Site +from django.forms import BooleanField, CharField, ChoiceField, IntegerField, Select, Textarea, ValidationError from django.utils.translation import gettext_lazy as _ +from extras.forms import ScriptForm from extras.models import Tag from netbox.forms import NetBoxModelForm from tenancy.models import Tenant @@ -10,8 +12,9 @@ from utilities.forms.widgets import HTMXSelect from validity import models -from validity.choices import ConnectionTypeChoices +from validity.choices import ConnectionTypeChoices, ExplanationVerbosityChoices from validity.netbox_changes import FieldSet +from .fields import DynamicModelChoicePropertyField, DynamicModelMultipleChoicePropertyField from .mixins import SubformMixin from .widgets import PrettyJSONWidget @@ -161,3 +164,67 @@ class Meta: model = models.Command fields = ("name", "label", "type", "retrieves_config", "serializer", "tags") widgets = {"type": HTMXSelect()} + + +class RunTestsForm(ScriptForm): + sync_datasources = BooleanField( + required=False, + label=_("Sync Data Sources"), + help_text=_("Sync all referenced Data Sources"), + ) + selectors = DynamicModelMultipleChoicePropertyField( + queryset=models.ComplianceSelector.objects.all(), + required=False, + label=_("Specific Selectors"), + help_text=_("Run the tests only for specific selectors"), + ) + devices = DynamicModelMultipleChoicePropertyField( + queryset=Device.objects.all(), + required=False, + label=_("Specific Devices"), + help_text=_("Run the tests only for specific devices"), + ) + test_tags = DynamicModelMultipleChoicePropertyField( + queryset=Tag.objects.all(), + required=False, + label=_("Specific Test Tags"), + help_text=_("Run the tests which contain specific tags only"), + ) + explanation_verbosity = ChoiceField( + choices=ExplanationVerbosityChoices.choices, + initial=ExplanationVerbosityChoices.maximum, + help_text=_("Explanation Verbosity Level"), + required=False, + ) + overriding_datasource = DynamicModelChoicePropertyField( + queryset=DataSource.objects.all(), + required=False, + label=_("Override DataSource"), + help_text=_("Find all devices state/config data in this Data Source instead of bound ones"), + ) + workers_num = IntegerField( + min_value=1, + initial=1, + required=False, + label=_("Number of Workers"), + help_text=_("Speed up tests execution by splitting the work among multiple RQ workers"), + ) + _commit = None # remove this Field from ScriptForm + + fieldsets = ( + FieldSet( + "sync_datasources", + "selectors", + "devices", + "test_tags", + "explanation_verbosity", + "workers_num", + "overriding_datasource", + name=_("Main Parameters"), + ), + FieldSet("_schedule_at", "_interval", name=_("Postponed Execution")), + ) + + def clean(self): + schedule_at = self.cleaned_data.get("_schedule_at") + return super().clean() | {"_schedule_at": schedule_at} diff --git a/validity/managers.py b/validity/managers.py index 7994f36..5225c32 100644 --- a/validity/managers.py +++ b/validity/managers.py @@ -1,6 +1,8 @@ from functools import partialmethod from itertools import chain +from core.models import Job +from django.contrib.contenttypes.models import ContentType from django.contrib.postgres.aggregates import ArrayAgg from django.db.models import ( BigIntegerField, @@ -10,6 
+12,7 @@ ExpressionWrapper, F, FloatField, + ManyToManyField, Prefetch, Q, Value, @@ -19,8 +22,8 @@ from django.db.models.functions import Cast from netbox.models import RestrictedQuerySet -from validity import settings from validity.choices import DeviceGroupByChoices, SeverityChoices +from validity.settings import ValiditySettingsMixin from validity.utils.orm import CustomPrefetchMixin, SetAttributesMixin @@ -47,7 +50,7 @@ def annotate_latest_count(self): ) -class ComplianceTestResultQS(RestrictedQuerySet): +class ComplianceTestResultQS(ValiditySettingsMixin, RestrictedQuerySet): def only_latest(self, exclude: bool = False) -> "ComplianceTestResultQS": qs = self.order_by("test__pk", "device__pk", "-created").distinct("test__pk", "device__pk") if exclude: @@ -62,9 +65,8 @@ def last_more_than(self, than: int) -> "ComplianceTestResultQS": def count_devices_and_tests(self): return self.aggregate(device_count=Count("devices", distinct=True), test_count=Count("tests", distinct=True)) - def delete_old(self, _settings=settings): - del_count = self.filter(report=None).last_more_than(_settings.store_last_results)._raw_delete(self.db) - return (del_count, {"validity.ComplianceTestResult": del_count}) + def raw_delete(self): + return self._raw_delete(self.db) def percentage(field1: str, field2: str) -> Case: @@ -90,7 +92,7 @@ def annotate_paths(self): return self.annotate_config_path().annotate_command_path() -class ComplianceReportQS(RestrictedQuerySet): +class ComplianceReportQS(ValiditySettingsMixin, RestrictedQuerySet): def annotate_result_stats(self, groupby_field: DeviceGroupByChoices | None = None): qs = self if groupby_field: @@ -118,15 +120,21 @@ def count_devices_and_tests(self): device_count=Count("results__device", distinct=True), test_count=Count("results__test", distinct=True) ) - def delete_old(self, _settings=settings): - from validity.models import ComplianceTestResult + def delete_old(self): + from validity.models import ComplianceReport, ComplianceTestResult - old_reports = list(self.order_by("-created").values_list("pk", flat=True)[_settings.store_reports :]) - deleted_results = ComplianceTestResult.objects.filter(report__pk__in=old_reports)._raw_delete(self.db) + old_reports = list(self.order_by("-created").values_list("pk", flat=True)[self.v_settings.store_reports :]) + deleted_results = ComplianceTestResult.objects.filter(report__pk__in=old_reports).raw_delete() + report_content_type = ContentType.objects.get_for_model(ComplianceReport) + deleted_jobs = Job.objects.filter(object_id__in=old_reports, object_type=report_content_type).delete() deleted_reports, _ = self.filter(pk__in=old_reports).delete() return ( - deleted_results + deleted_reports, - {"validity.ComplianceTestResult": deleted_results, "validity.ComplianceReport": deleted_reports}, + deleted_results + deleted_reports + deleted_reports, + { + "validity.ComplianceTestResult": deleted_results, + "validity.ComplianceReport": deleted_reports, + "core.Job": deleted_jobs, + }, ) @@ -267,3 +275,13 @@ def bind_attributes(self, instance): instance.path = path super().bind_attributes(instance) self._aux_attributes = initial_attrs + + +class ComplianceSelectorQS(RestrictedQuerySet): + def prefetch_filters(self): + filter_fields = ( + field.name + for field in self.model._meta.get_fields() + if isinstance(field, ManyToManyField) and field.name.endswith("_filter") + ) + return self.prefetch_related(*filter_fields) diff --git a/validity/migrations/0010_squashed_initial.py b/validity/migrations/0010_squashed_initial.py index 
0e46680..9dd30aa 100644 --- a/validity/migrations/0010_squashed_initial.py +++ b/validity/migrations/0010_squashed_initial.py @@ -289,6 +289,7 @@ class Migration(migrations.Migration): options={ "abstract": False, "ordering": ("name",), + "permissions": [("run", "Can run compliance test")], }, bases=(validity.models.base.URLMixin, models.Model), ), @@ -512,5 +513,4 @@ class Migration(migrations.Migration): ), migrations.RunPython(create_cf, delete_cf), migrations.RunPython(create_polling_datasource, delete_polling_datasource), - migrations.RunPython(setup_scripts, delete_scripts), ] diff --git a/validity/migrations/0011_delete_scripts.py b/validity/migrations/0011_delete_scripts.py new file mode 100644 index 0000000..884f94e --- /dev/null +++ b/validity/migrations/0011_delete_scripts.py @@ -0,0 +1,24 @@ +from django.db import migrations + + +DATASOURCE_NAME = "validity_scripts" + + +def delete_scripts(apps, schema_editor): + """ + Delete DataSource and ScriptModule used for validity v1/v2 + """ + DataSource = apps.get_model("core", "DataSource") + ScriptModule = apps.get_model("extras", "ScriptModule") + db_alias = schema_editor.connection.alias + ScriptModule.objects.using(db_alias).filter(data_source__name=DATASOURCE_NAME).delete() + DataSource.objects.using(db_alias).filter(name=DATASOURCE_NAME).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ("validity", "0010_squashed_initial"), + ] + operations = [ + migrations.RunPython(delete_scripts, migrations.RunPython.noop), + ] diff --git a/validity/models/polling.py b/validity/models/polling.py index 2761dd9..d9a21e6 100644 --- a/validity/models/polling.py +++ b/validity/models/polling.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Collection +from typing import TYPE_CHECKING, Annotated, Collection from dcim.models import Device from django.core.exceptions import ValidationError @@ -7,15 +7,19 @@ from django.db import models from django.utils.translation import gettext_lazy as _ +from validity import di from validity.choices import CommandTypeChoices, ConnectionTypeChoices from validity.fields import EncryptedDictField from validity.managers import CommandQS, PollerQS -from validity.pollers import get_poller from validity.subforms import CLICommandForm, JSONAPICommandForm, NetconfCommandForm from .base import BaseModel, SubformMixin from .serializer import Serializer +if TYPE_CHECKING: + from validity.pollers.factory import PollerFactory + + class Command(SubformMixin, BaseModel): name = models.CharField(_("Name"), max_length=255, unique=True) label = models.CharField( @@ -110,8 +114,9 @@ def config_command(self) -> Command | None: """ return next((cmd for cmd in self.commands.all() if cmd.retrieves_config), None) - def get_backend(self): - return get_poller(self.connection_type, self.credentials, self.commands.all()) + @di.inject + def get_backend(self, poller_factory: Annotated["PollerFactory", ...]): + return poller_factory(self.connection_type, self.credentials, self.commands.all()) @staticmethod def validate_commands(connection_type: str, commands: Collection[Command]): diff --git a/validity/models/report.py b/validity/models/report.py index 0386383..ee9e611 100644 --- a/validity/models/report.py +++ b/validity/models/report.py @@ -1,3 +1,5 @@ +from core.models import Job +from django.contrib.contenttypes.fields import GenericRelation from netbox.models import ChangeLoggingMixin from validity.managers import ComplianceReportQS @@ -5,7 +7,10 @@ class 
ComplianceReport(ChangeLoggingMixin, BaseReadOnlyModel): + jobs = GenericRelation(Job, content_type_field="object_type") + objects = ComplianceReportQS.as_manager() + run_view = "plugins:validity:compliancetest_run" class Meta: ordering = ("-created",) diff --git a/validity/models/selector.py b/validity/models/selector.py index f49e602..b252b83 100644 --- a/validity/models/selector.py +++ b/validity/models/selector.py @@ -13,6 +13,7 @@ from validity.choices import BoolOperationChoices, DynamicPairsChoices from validity.compliance.dynamic_pairs import DynamicPairNameFilter, dpf_factory +from validity.managers import ComplianceSelectorQS from validity.utils.misc import reraise from .base import BaseModel from .device import VDevice @@ -42,6 +43,8 @@ class ComplianceSelector(BaseModel): ) dp_tag_prefix = models.CharField(_("Dynamic Pair Tag Prefix"), max_length=255, blank=True) + objects = ComplianceSelectorQS.as_manager() + clone_fields = ( "filter_operation", "name_filter", diff --git a/validity/models/test.py b/validity/models/test.py index e124842..668b9be 100644 --- a/validity/models/test.py +++ b/validity/models/test.py @@ -29,6 +29,9 @@ class ComplianceTest(DataSourceMixin, BaseModel): class Meta: ordering = ("name",) + permissions = [ + ("run", "Can run compliance test"), + ] def clean(self): super().clean() diff --git a/validity/netbox_changes/current.py b/validity/netbox_changes/current.py index d08529b..d6533fb 100644 --- a/validity/netbox_changes/current.py +++ b/validity/netbox_changes/current.py @@ -10,6 +10,7 @@ PluginTemplateExtension = __locate("netbox.plugins.PluginTemplateExtension") CF_OBJ_TYPE = "related_object_type" CF_CONTENT_TYPES = "object_types" +htmx_partial = __locate("utilities.htmx.htmx_partial") class BootstrapMixin: diff --git a/validity/netbox_changes/old.py b/validity/netbox_changes/old.py index a575d22..1bb0e05 100644 --- a/validity/netbox_changes/old.py +++ b/validity/netbox_changes/old.py @@ -7,3 +7,4 @@ enqueue_object = __locate("extras.events.enqueue_object") events_queue = __locate("netbox.context.events_queue") EventRulesMixin = __locate("netbox.models.features.EventRulesMixin") +SavedFiltersMixin = __locate("netbox.forms.mixins.SavedFiltersMixin") diff --git a/validity/netbox_changes/oldest.py b/validity/netbox_changes/oldest.py index c626c3a..afba381 100644 --- a/validity/netbox_changes/oldest.py +++ b/validity/netbox_changes/oldest.py @@ -6,9 +6,11 @@ events_queue = __locate("netbox.context.webhooks_queue") EventRulesMixin = __locate("netbox.models.features.WebhooksMixin") BootstrapMixin = __locate("utilities.forms.BootstrapMixin") +SavedFiltersMixin = __locate("extras.forms.mixins.SavedFiltersMixin") plugins = __locate("extras.plugins") ButtonColorChoices = __locate("utilities.choices.ButtonColorChoices") PluginTemplateExtension = __locate("extras.plugins.PluginTemplateExtension") +htmx_partial = __locate("utilities.htmx.is_htmx") class FieldSet: diff --git a/validity/pollers/__init__.py b/validity/pollers/__init__.py index 5eb0350..890f508 100644 --- a/validity/pollers/__init__.py +++ b/validity/pollers/__init__.py @@ -1,2 +1,3 @@ from .cli import NetmikoPoller -from .factory import get_poller +from .http import RequestsPoller +from .netconf import ScrapliNetconfPoller diff --git a/validity/pollers/factory.py b/validity/pollers/factory.py index 4aee467..a4083a3 100644 --- a/validity/pollers/factory.py +++ b/validity/pollers/factory.py @@ -1,20 +1,23 @@ from functools import partial -from typing import TYPE_CHECKING, Sequence +from typing import 
TYPE_CHECKING, Annotated, Sequence -from validity import settings -from validity.choices import ConnectionTypeChoices +from dimi import Singleton + +from validity import di from .base import DevicePoller, ThreadPoller -from .cli import NetmikoPoller -from .http import RequestsPoller -from .netconf import ScrapliNetconfPoller if TYPE_CHECKING: from validity.models import Command +@di.dependency(scope=Singleton) class PollerFactory: - def __init__(self, poller_map: dict, max_threads: int) -> None: + def __init__( + self, + poller_map: Annotated[dict, "poller_map"], + max_threads: Annotated[int, "validity_settings.polling_threads"], + ) -> None: self.poller_map = poller_map self.max_threads = max_threads @@ -24,13 +27,3 @@ def __call__(self, connection_type: str, credentials: dict, commands: Sequence[" poller_cls = partial(poller_cls, thread_workers=self.max_threads) return poller_cls(credentials=credentials, commands=commands) raise KeyError("No poller exists for this connection type", connection_type) - - -get_poller = PollerFactory( - poller_map={ - ConnectionTypeChoices.netmiko: NetmikoPoller, - ConnectionTypeChoices.requests: RequestsPoller, - ConnectionTypeChoices.scrapli_netconf: ScrapliNetconfPoller, - }, - max_threads=settings.polling_threads, -) diff --git a/validity/scripts/__init__.py b/validity/scripts/__init__.py index e69de29..902a651 100644 --- a/validity/scripts/__init__.py +++ b/validity/scripts/__init__.py @@ -0,0 +1,4 @@ +from .data_models import RunTestsParams, ScriptParams, Task +from .launch import Launcher +from .logger import Logger +from .runtests import ApplyWorker, CombineWorker, SplitWorker diff --git a/validity/scripts/data_models.py b/validity/scripts/data_models.py new file mode 100644 index 0000000..462eb73 --- /dev/null +++ b/validity/scripts/data_models.py @@ -0,0 +1,160 @@ +import datetime +import operator +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from functools import reduce +from typing import Callable, ClassVar, Literal +from uuid import UUID + +from core.models import Job +from django.contrib.auth import get_user_model +from django.db.models import Q, QuerySet +from django.utils import timezone +from pydantic import ConfigDict, Field, field_validator +from pydantic.dataclasses import dataclass as py_dataclass +from rq import Callback + +from validity.models import ComplianceSelector + + +@dataclass(slots=True, frozen=True) +class Message: + status: Literal["debug", "info", "failure", "warning", "success", "default"] + message: str + time: datetime.datetime = field(default_factory=lambda: timezone.now()) + script_id: str | None = None + + @property + def serialized(self) -> dict: + msg = self.message + if self.script_id: + msg = f"{self.script_id}, {msg}" + return {"status": self.status, "message": msg, "time": self.time.isoformat()} + + +@dataclass(slots=True) +class SplitResult: + log: list[Message] + slices: list[dict[int, list[int]]] + + +@dataclass(slots=True, frozen=True) +class TestResultRatio: + passed: int + total: int + + def __add__(self, other): + return type(self)(self.passed + other.passed, self.total + other.total) + + @property + def serialized(self): + return asdict(self) + + +@dataclass(slots=True) +class ExecutionResult: + test_stat: TestResultRatio + log: list[Message] + errored: bool = False + + +@dataclass +class RequestInfo: + """ + Pickleable substitution for Django's HttpRequest + """ + + id: UUID + user_id: int + + user_queryset: ClassVar[QuerySet] = get_user_model().objects.all() + 
+ @classmethod + def from_http_request(cls, request): + return cls(id=request.id, user_id=request.user.pk) + + def get_user(self): + return self.user_queryset.get(pk=self.user_id) + + +@py_dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)) +class ScriptParams(ABC): + request: RequestInfo + schedule_at: datetime.datetime | None = Field(default=None, validation_alias="_schedule_at") + schedule_interval: int | None = Field(default=None, validation_alias="_interval") + workers_num: int = 1 + + @field_validator("request", mode="before") + @classmethod + def coerce_request_info(cls, value): + if not isinstance(value, (RequestInfo, dict)): + value = RequestInfo.from_http_request(value) + return value + + @abstractmethod + def with_job_info(self, job: Job) -> "FullScriptParams": ... + + +@py_dataclass(kw_only=True) +class FullScriptParams(ScriptParams): + job_id: int + report_id: int + + job_queryset: ClassVar[QuerySet[Job]] = Job.objects.all() + + def get_job(self): + return self.job_queryset.get(pk=self.job_id) + + +@py_dataclass(kw_only=True) +class RunTestsParams(ScriptParams): + sync_datasources: bool = False + selectors: list[int] = field(default_factory=list) + devices: list[int] = field(default_factory=list) + test_tags: list[int] = field(default_factory=list) + explanation_verbosity: int = 2 + overriding_datasource: int | None = None + + @property + def selector_qs(self) -> QuerySet[ComplianceSelector]: + qs = ( + ComplianceSelector.objects.filter(pk__in=self.selectors) + if self.selectors + else ComplianceSelector.objects.all() + ) + if self.test_tags: + qs = qs.filter(tests__tags__pk__in=self.test_tags).distinct() + return qs + + def get_device_filter(self) -> Q: + selectors = self.selector_qs + if not selectors.exists(): + return Q(pk__in=[]) + filtr = reduce(operator.or_, (selector.filter for selector in selectors.prefetch_filters())) + if self.devices: + filtr &= Q(pk__in=self.devices) + return filtr + + def with_job_info(self, job: Job) -> "FullRunTestsParams": + return FullRunTestsParams(**asdict(self) | {"job_id": job.pk, "report_id": job.object_id}) + + +@py_dataclass(kw_only=True) +class FullRunTestsParams(FullScriptParams, RunTestsParams): + pass + + +@dataclass +class Task: + """ + Represents all the kwargs that can be passed to rq.Queue.enqueue + """ + + func: Callable + job_timeout: int | str + on_failure: Callback | None = None + multi_workers: bool = False + + @property + def as_kwargs(self): + return {"f": self.func, "job_timeout": self.job_timeout, "on_failure": self.on_failure} diff --git a/validity/scripts/exceptions.py b/validity/scripts/exceptions.py new file mode 100644 index 0000000..4d5fc57 --- /dev/null +++ b/validity/scripts/exceptions.py @@ -0,0 +1,12 @@ +from typing import Sequence + +from core.choices import JobStatusChoices + +from .data_models import Message + + +class AbortScript(Exception): + def __init__(self, *args, status: str = JobStatusChoices.STATUS_FAILED, logs: Sequence[Message] = ()) -> None: + self.status = status + self.logs = logs + super().__init__(*args) diff --git a/validity/scripts/install/validity_scripts.py b/validity/scripts/install/validity_scripts.py deleted file mode 100644 index 064b6a7..0000000 --- a/validity/scripts/install/validity_scripts.py +++ /dev/null @@ -1,10 +0,0 @@ -from extras.scripts import Script - -from validity.scripts.run_tests import RunTestsScript - - -class RunTests(RunTestsScript, Script): - pass - - -name = "Validity Scripts" diff --git a/validity/scripts/launch.py 
b/validity/scripts/launch.py new file mode 100644 index 0000000..1ef944c --- /dev/null +++ b/validity/scripts/launch.py @@ -0,0 +1,62 @@ +import datetime +import uuid +from dataclasses import dataclass +from functools import partial +from typing import Callable + +from core.choices import JobStatusChoices +from core.models import Job +from django.contrib.auth.models import AbstractBaseUser +from django.contrib.contenttypes.models import ContentType +from django.db.models import Model +from rq import Queue + +from .data_models import FullScriptParams, ScriptParams, Task + + +@dataclass +class Launcher: + job_name: str + job_object_factory: Callable[[], Model] + rq_queue: Queue + tasks: list[Task] + + def create_netbox_job( + self, schedule_at: datetime.datetime | None, interval: int | None, user: AbstractBaseUser + ) -> Job: + status = JobStatusChoices.STATUS_SCHEDULED if schedule_at else JobStatusChoices.STATUS_PENDING + obj = self.job_object_factory() + content_type = ContentType.objects.get_for_model(type(obj)) + return Job.objects.create( + object_type=content_type, + object_id=obj.pk, + name=self.job_name, + status=status, + scheduled=schedule_at, + interval=interval, + user=user, + job_id=uuid.uuid4(), + ) + + def enqueue(self, params: FullScriptParams, rq_job_id: uuid.UUID) -> None: + prev_job = None + for task_idx, task in enumerate(self.tasks): + enqueue_fn = ( + partial(self.rq_queue.enqueue_at, params.schedule_at) + if params.schedule_at and task_idx == 0 + else self.rq_queue.enqueue + ) + task_kwargs = task.as_kwargs | {"depends_on": prev_job, "params": params} + if task_idx == len(self.tasks) - 1: + task_kwargs["job_id"] = str(rq_job_id) # job id of the last task matches with the job id from the DB + prev_job = ( + [enqueue_fn(**task_kwargs, worker_id=worker_id) for worker_id in range(params.workers_num)] + if task.multi_workers + else enqueue_fn(**task_kwargs) + ) + + def __call__(self, params: ScriptParams) -> Job: + nb_job = self.create_netbox_job(params.schedule_at, params.schedule_interval, params.request.get_user()) + full_params = params.with_job_info(nb_job) + self.enqueue(full_params, nb_job.job_id) + return nb_job diff --git a/validity/scripts/logger.py b/validity/scripts/logger.py new file mode 100644 index 0000000..98baccf --- /dev/null +++ b/validity/scripts/logger.py @@ -0,0 +1,47 @@ +import logging +import traceback as tb +from functools import partialmethod +from typing import Literal + +from extras.choices import LogLevelChoices + +from .data_models import Message + + +logger = logging.getLogger("validity.scripts") + + +class Logger: + """ + Collects the logs in the format of NetBox Custom Script + """ + + SYSTEM_LEVELS = { + "debug": logging.DEBUG, + "default": logging.INFO, + "success": logging.INFO, + "info": logging.INFO, + "warning": logging.WARNING, + "failure": logging.ERROR, + } + + def __init__(self, script_id: str | None = None) -> None: + self.messages = [] + self.script_id = script_id + + def _log(self, message: str, level: Literal["debug", "info", "failure", "warning", "success", "default"]): + msg = Message(level, message, script_id=self.script_id) + self.messages.append(msg) + logger.log(self.SYSTEM_LEVELS[level], message) + + debug = partialmethod(_log, level="debug") + success = partialmethod(_log, level=LogLevelChoices.LOG_SUCCESS) + info = partialmethod(_log, level=LogLevelChoices.LOG_INFO) + warning = partialmethod(_log, level=LogLevelChoices.LOG_WARNING) + failure = partialmethod(_log, level=LogLevelChoices.LOG_FAILURE) + + def 
log_exception(self, exc_value, exc_type=None, exc_traceback=None): + exc_traceback = exc_traceback or exc_value.__traceback__ + exc_type = exc_type or type(exc_value) + stacktrace = "".join(tb.format_tb(exc_traceback)) + self.failure(f"Unhandled error occured: `{exc_type}: {exc_value}`\n```\n{stacktrace}\n```") diff --git a/validity/scripts/parent_jobs.py b/validity/scripts/parent_jobs.py new file mode 100644 index 0000000..1e162e3 --- /dev/null +++ b/validity/scripts/parent_jobs.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field +from functools import cached_property + +from rq.job import Job, get_current_job + + +@dataclass +class JobExtractor: + nesting_level: int = 0 + _job: Job | None = field(default_factory=get_current_job) + + @property + def nesting_name(self) -> str: + if self.nesting_level == 0: + return "Current" + if self.nesting_level == 1: + return "Parent" + return f"x{self.nesting_level} Parent" + + @property + def job(self) -> Job: + if self._job is None: + raise ValueError(f"{self.nesting_name} Job must not be None") + return self._job + + @cached_property + def parents(self) -> list["JobExtractor"]: + result = [self._get_parent(dep) for dep in self.job.fetch_dependencies()] + if result: + self.__dict__["parent"] = result[0] + return result + + @cached_property + def parent(self) -> "JobExtractor": + return self._get_parent(self.job.dependency) + + def _get_parent(self, dependency: Job | None) -> "JobExtractor": + return type(self)(nesting_level=self.nesting_level + 1, _job=dependency) diff --git a/validity/scripts/run_tests.py b/validity/scripts/run_tests.py deleted file mode 100644 index fa95f1f..0000000 --- a/validity/scripts/run_tests.py +++ /dev/null @@ -1,193 +0,0 @@ -import time -from itertools import chain -from typing import Any, Callable, Generator, Iterable - -import yaml -from core.models import DataSource -from dcim.models import Device -from django.db.models import Prefetch, QuerySet -from django.utils.translation import gettext as __ -from extras.choices import ObjectChangeActionChoices -from extras.models import Tag -from extras.scripts import BooleanVar, MultiObjectVar, ObjectVar - -import validity -from validity.choices import ExplanationVerbosityChoices -from validity.compliance.exceptions import EvalError, SerializationError -from validity.models import ( - ComplianceReport, - ComplianceSelector, - ComplianceTest, - ComplianceTestResult, - NameSet, - VDataSource, - VDevice, -) -from validity.netbox_changes import enqueue_object, events_queue -from validity.utils.misc import datasource_sync, null_request -from .script_data import RunTestsScriptData, ScriptDataMixin -from .variables import VerbosityVar - - -class RunTestsScript(ScriptDataMixin[RunTestsScriptData]): - _sleep_between_tests = validity.settings.sleep_between_tests - _result_batch_size = validity.settings.result_batch_size - - sync_datasources = BooleanVar( - required=False, - default=False, - label=__("Sync Data Sources"), - description=__("Sync all referenced Data Sources"), - ) - make_report = BooleanVar(default=True, label=__("Make Compliance Report")) - selectors = MultiObjectVar( - model=ComplianceSelector, - required=False, - label=__("Specific Selectors"), - description=__("Run the tests only for specific selectors"), - ) - devices = MultiObjectVar( - model=Device, - required=False, - label=__("Specific Devices"), - description=__("Run the tests only for specific devices"), - ) - test_tags = MultiObjectVar( - model=Tag, - required=False, - label=__("Specific Test Tags"), - 
description=__("Run the tests which contain specific tags only"), - ) - explanation_verbosity = VerbosityVar( - choices=ExplanationVerbosityChoices.choices, - default=ExplanationVerbosityChoices.maximum, - label=__("Explanation Verbosity Level"), - required=False, - ) - override_datasource = ObjectVar( - model=DataSource, - required=False, - label=__("Override DataSource"), - description=__("Find all devices state/config data in this Data Source instead of bound ones"), - ) - - class Meta: - name = __("Run Compliance Tests") - description = __("Execute compliance tests and save the results") - - def __init__(self, datasource_sync_fn: Callable = datasource_sync): - super().__init__() - self.datasource_sync_fn = datasource_sync_fn - self._nameset_functions = {} - self.global_namesets = NameSet.objects.filter(_global=True) - self.results_count = 0 - self.results_passed = 0 - - def nameset_functions(self, namesets: Iterable[NameSet]) -> dict[str, Callable]: - result = {} - for nameset in chain(namesets, self.global_namesets): - if nameset.name not in self._nameset_functions: - try: - new_functions = nameset.extract() - except Exception as e: - self.log_warning(f"Cannot extract code from nameset {nameset}, {type(e).__name__}: {e}") - new_functions = {} - self._nameset_functions[nameset.name] = new_functions - result |= self._nameset_functions[nameset.name] - return result - - def run_test(self, device: VDevice, test: ComplianceTest) -> tuple[bool, list[tuple[Any, Any]]]: - functions = self.nameset_functions(test.namesets.all()) - return test.run(device, functions, verbosity=self.script_data.explanation_verbosity) - - def run_tests_for_device( - self, - tests_qs: QuerySet[ComplianceTest], - device: VDevice, - report: ComplianceReport | None, - ) -> Generator[ComplianceTestResult, None, None]: - for test in tests_qs: - explanation = [] - try: - device.state # noqa: B018 - passed, explanation = self.run_test(device, test) - except EvalError as exc: - self.log_failure(f"Failed to execute test **{test}** for device **{device}**, `{exc}`") - passed = False - explanation.append((str(exc), None)) - self.results_count += 1 - self.results_passed += int(passed) - yield ComplianceTestResult( - test=test, - device=device, - passed=passed, - explanation=explanation, - report=report, - dynamic_pair=device.dynamic_pair, - ) - time.sleep(self._sleep_between_tests) - - def get_device_qs(self, selector: ComplianceSelector) -> QuerySet[VDevice]: - device_qs = selector.devices.select_related().prefetch_serializer().prefetch_poller() - if self.script_data.override_datasource: - device_qs = device_qs.set_datasource(self.script_data.override_datasource.obj) - else: - device_qs = device_qs.prefetch_datasource() - if self.script_data.devices: - device_qs = device_qs.filter(pk__in=self.script_data.devices) - return device_qs - - def run_tests_for_selector( - self, selector: ComplianceSelector, report: ComplianceReport | None - ) -> Generator[ComplianceTestResult, None, None]: - for device in self.get_device_qs(selector): - try: - yield from self.run_tests_for_device(selector.tests.all(), device, report) - except SerializationError as e: - self.log_failure(f"`{e}`, ignoring all tests for *{device}*") - continue - - def fire_report_webhook(self, report_id: int) -> None: - report = ComplianceReport.objects.filter(pk=report_id).annotate_result_stats().count_devices_and_tests().first() - queue = events_queue.get() - enqueue_object(queue, report, self.request.user, self.request.id, ObjectChangeActionChoices.ACTION_CREATE) - - 
def save_to_db(self, results: Iterable[ComplianceTestResult], report: ComplianceReport | None) -> None: - ComplianceTestResult.objects.bulk_create(results, batch_size=self._result_batch_size) - ComplianceTestResult.objects.delete_old() - if report: - ComplianceReport.objects.delete_old() - - def get_selectors(self) -> QuerySet[ComplianceSelector]: - selectors = self.script_data.selectors.queryset - test_qs = ComplianceTest.objects.all() - if self.script_data.test_tags: - test_qs = test_qs.filter(tags__pk__in=self.script_data.test_tags).distinct() - selectors = selectors.filter(tests__tags__pk__in=self.script_data.test_tags).distinct() - return selectors.prefetch_related(Prefetch("tests", test_qs.prefetch_related("namesets"))) - - def datasources_to_sync(self) -> Iterable[VDataSource]: - if self.script_data.override_datasource: - return [self.script_data.override_datasource.obj] - datasource_ids = ( - VDevice.objects.filter(self.script_data.device_filter) - .annotate_datasource_id() - .values_list("data_source_id", flat=True) - .distinct() - ) - return VDataSource.objects.filter(pk__in=datasource_ids) - - def run(self, data, commit): - self.script_data = self.script_data_cls(data) - selectors = self.get_selectors() - if self.script_data.sync_datasources: - self.datasource_sync_fn(self.datasources_to_sync(), device_filter=self.script_data.device_filter) - with null_request(): - report = ComplianceReport.objects.create() if self.script_data.make_report else None - results = chain.from_iterable(self.run_tests_for_selector(selector, report) for selector in selectors) - self.save_to_db(results, report) - output = {"results": {"all": self.results_count, "passed": self.results_passed}} - if report: - self.log_info(f"See [Compliance Report]({report.get_absolute_url()}) for detailed statistics") - self.fire_report_webhook(report.pk) - return yaml.dump(output, sort_keys=False) diff --git a/validity/scripts/runtests/__init__.py b/validity/scripts/runtests/__init__.py new file mode 100644 index 0000000..2e5d159 --- /dev/null +++ b/validity/scripts/runtests/__init__.py @@ -0,0 +1,3 @@ +from .apply import ApplyWorker +from .combine import CombineWorker +from .split import SplitWorker diff --git a/validity/scripts/runtests/apply.py b/validity/scripts/runtests/apply.py new file mode 100644 index 0000000..1167579 --- /dev/null +++ b/validity/scripts/runtests/apply.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from functools import cached_property +from itertools import chain +from typing import Annotated, Any, Callable, Iterable, Iterator + +from dimi import Singleton +from django.db.models import Prefetch, QuerySet + +from validity import di +from validity.compliance.exceptions import EvalError, SerializationError +from validity.models import ComplianceSelector, ComplianceTest, ComplianceTestResult, NameSet, VDataSource, VDevice +from ..data_models import ExecutionResult, FullRunTestsParams, TestResultRatio +from ..logger import Logger +from ..parent_jobs import JobExtractor + + +class TestExecutor: + """ + Executes all the tests for specified subset of devices + """ + + def __init__(self, worker_id: int, explanation_verbosity: int, report_id: int) -> None: + self.explanation_verbosity = explanation_verbosity + self.report_id = report_id + self.log = Logger(script_id=f"Worker {worker_id}") + self.results_count = 0 + self.results_passed = 0 + self._nameset_functions = {} + self.global_namesets = NameSet.objects.filter(_global=True) + + def nameset_functions(self, namesets: 
Iterable[NameSet]) -> dict[str, Callable]: + result = {} + for nameset in chain(namesets, self.global_namesets): + if nameset.name not in self._nameset_functions: + try: + new_functions = nameset.extract() + except Exception as e: + self.log.warning(f"Cannot extract code from nameset {nameset}, {type(e).__name__}: {e}") + new_functions = {} + self._nameset_functions[nameset.name] = new_functions + result |= self._nameset_functions[nameset.name] + return result + + def run_test(self, device: VDevice, test: ComplianceTest) -> tuple[bool, list[tuple[Any, Any]]]: + functions = self.nameset_functions(test.namesets.all()) + return test.run(device, functions, verbosity=self.explanation_verbosity) + + def run_tests_for_device( + self, + tests_qs: QuerySet[ComplianceTest], + device: VDevice, + ) -> Iterator[ComplianceTestResult]: + for test in tests_qs: + try: + device.state # noqa: B018 + passed, explanation = self.run_test(device, test) + except EvalError as exc: + self.log.failure(f"Failed to execute test **{test}** for device **{device}**, `{exc}`") + passed = False + explanation = [(str(exc), None)] + self.results_count += 1 + self.results_passed += int(passed) + yield ComplianceTestResult( + test=test, + device=device, + passed=passed, + explanation=explanation, + report_id=self.report_id, + dynamic_pair=device.dynamic_pair, + ) + + def __call__(self, devices: QuerySet[VDevice], tests: QuerySet[ComplianceTest]) -> Iterator[ComplianceTestResult]: + for device in devices: + try: + yield from self.run_tests_for_device(tests, device) + except SerializationError as e: + self.log.failure(f"`{e}`, ignoring all tests for *{device}*") + continue + + +class DeviceTestIterator: + """ + Generates pairs of (devices, tests) where each test has to be executed on each of the corresponding devices + """ + + def __init__( + self, selector_devices: dict[int, list[int]], test_tags: list[int], overriding_datasource_id: int | None + ): + self.selector_devices = selector_devices + self.test_tags = test_tags + self.overriding_datasource_id = overriding_datasource_id + self.all_selectors = self._get_selectors().in_bulk() + + def __iter__(self): + return self + + def __next__(self) -> tuple[QuerySet[VDevice], QuerySet[ComplianceTest]]: + if not self.selector_devices: + raise StopIteration + selector_id, device_ids = self.selector_devices.popitem() + selector = self.all_selectors[selector_id] + devices = self._get_device_qs(selector, device_ids) + return devices, selector.tests.all() + + @cached_property + def overriding_datasource(self) -> VDataSource | None: + if self.overriding_datasource_id: + return VDataSource.objects.get(pk=self.overriding_datasource_id) + + def _get_selectors(self): + selectors = ComplianceSelector.objects.all() + test_qs = ComplianceTest.objects.all() + if self.test_tags: + test_qs = test_qs.filter(tags__pk__in=self.test_tags).distinct() + return selectors.prefetch_related(Prefetch("tests", test_qs.prefetch_related("namesets"))) + + def _get_device_qs(self, selector: ComplianceSelector, device_ids: list[int]) -> QuerySet[VDevice]: + device_qs = selector.devices.select_related().prefetch_serializer().prefetch_poller() + if self.overriding_datasource: + device_qs = device_qs.set_datasource(self.overriding_datasource) + else: + device_qs = device_qs.prefetch_datasource() + device_qs = device_qs.filter(pk__in=device_ids) + return device_qs + + +@di.dependency(scope=Singleton) +@dataclass(repr=False, kw_only=True) +class ApplyWorker: + """ + Provides a function to execute specified tests, save the 
results to DB and return ExecutionResult + """ + + test_executor_cls: type[TestExecutor] = TestExecutor + logger_factory: Callable[[str], Logger] = Logger + device_test_gen: type[DeviceTestIterator] = DeviceTestIterator + result_batch_size: Annotated[int, "validity_settings.result_batch_size"] + job_extractor_factory: Callable[[], JobExtractor] = JobExtractor + testresult_queryset: QuerySet[ComplianceTestResult] = field(default_factory=ComplianceTestResult.objects.all) + + def __call__(self, *, params: FullRunTestsParams, worker_id: int) -> ExecutionResult: + try: + executor = self.test_executor_cls(worker_id, params.explanation_verbosity, params.report_id) + test_results = self.get_test_results(params, worker_id, executor) + self.save_results_to_db(test_results) + return ExecutionResult( + TestResultRatio(executor.results_passed, executor.results_count), executor.log.messages + ) + except Exception as err: + logger = self.logger_factory(f"Worker {worker_id}") + logger.log_exception(err) + return ExecutionResult(test_stat=TestResultRatio(0, 0), log=logger.messages, errored=True) + + def get_test_results( + self, params: FullRunTestsParams, worker_id: int, executor: TestExecutor + ) -> Iterator[ComplianceTestResult]: + selector_devices = self.get_selector_devices(worker_id) + test_results = ( + executor(devices, tests) + for devices, tests in self.device_test_gen(selector_devices, params.test_tags, params.overriding_datasource) + ) + return chain.from_iterable(test_results) + + def get_selector_devices(self, worker_id: int) -> dict[int, list[int]]: + job_extractor = self.job_extractor_factory() + return job_extractor.parent.job.result.slices[worker_id] + + def save_results_to_db(self, results: Iterable[ComplianceTestResult]) -> None: + self.testresult_queryset.bulk_create(results, batch_size=self.result_batch_size) diff --git a/validity/scripts/runtests/base.py b/validity/scripts/runtests/base.py new file mode 100644 index 0000000..d71b646 --- /dev/null +++ b/validity/scripts/runtests/base.py @@ -0,0 +1,43 @@ +from contextlib import contextmanager +from dataclasses import dataclass, field + +from core.choices import JobStatusChoices +from core.models import Job +from django.db.models import QuerySet + +from validity.models import ComplianceTestResult +from ..exceptions import AbortScript + + +@dataclass(repr=False, kw_only=True) +class TerminateMixin: + testresult_queryset: QuerySet[ComplianceTestResult] = field(default_factory=ComplianceTestResult.objects.all) + + def terminate_job(self, job: Job, status: str, error: str | None = None, logs=None, output=None): + logs = logs or [] + job.data = {"log": [log.serialized for log in logs], "output": output} + job.terminate(status, error) + + def terminate_errored_job(self, job: Job, error: Exception): + logger = self.log_factory() + if isinstance(error, AbortScript): + logger.messages.extend(error.logs) + logger.failure(str(error)) + status = error.status + else: + logger.log_exception(error) + status = JobStatusChoices.STATUS_ERRORED + logger.info("Database changes have been reverted") + self.revert_db_changes(job) + self.terminate_job(job, status=status, error=repr(error), logs=logger.messages) + + def revert_db_changes(self, job: Job) -> None: + self.testresult_queryset.filter(report_id=job.object_id).raw_delete() + + @contextmanager + def terminate_job_on_error(self, job: Job): + try: + yield + except Exception as err: + self.terminate_errored_job(job, err) + raise diff --git a/validity/scripts/runtests/combine.py 
b/validity/scripts/runtests/combine.py new file mode 100644 index 0000000..03a270d --- /dev/null +++ b/validity/scripts/runtests/combine.py @@ -0,0 +1,93 @@ +import datetime +import operator +from dataclasses import dataclass, field +from functools import reduce +from itertools import chain +from typing import Annotated, Any, Callable + +from core.choices import JobStatusChoices +from core.models import Job +from dimi import Singleton +from django.db.models import QuerySet +from django.http import HttpRequest +from django.urls import reverse +from extras.choices import ObjectChangeActionChoices + +from validity import di +from validity.models import ComplianceReport +from validity.netbox_changes import enqueue_object, events_queue +from ..data_models import FullRunTestsParams, Message, TestResultRatio +from ..exceptions import AbortScript +from ..launch import Launcher +from ..logger import Logger +from ..parent_jobs import JobExtractor +from .base import TerminateMixin + + +def enqueue(report, request, action): + queue = events_queue.get() + enqueue_object(queue, report, request.get_user(), request.id, action) + events_queue.set(queue) + + +@di.dependency(scope=Singleton) +@dataclass(repr=False, kw_only=True) +class CombineWorker(TerminateMixin): + log_factory: Callable[[], Logger] = Logger + job_extractor_factory: Callable[[], JobExtractor] = JobExtractor + enqueue_func: Callable[[ComplianceReport, HttpRequest, str], None] = enqueue + report_queryset: QuerySet[ComplianceReport] = field( + default_factory=ComplianceReport.objects.annotate_result_stats().count_devices_and_tests + ) + + def fire_report_webhook(self, report_id: int, request: HttpRequest) -> None: + report = self.report_queryset.get(pk=report_id) + self.enqueue_func(report, request, ObjectChangeActionChoices.ACTION_CREATE) + + def count_test_stats(self, job_extractor: JobExtractor) -> TestResultRatio: + result_ratios = (parent.job.result.test_stat for parent in job_extractor.parents) + return reduce(operator.add, result_ratios) + + def collect_logs(self, logger: Logger, job_extractor: JobExtractor) -> list[Message]: + parent_logs = chain.from_iterable(extractor.job.result.log for extractor in job_extractor.parents) + grandparent_logs = job_extractor.parent.parent.job.result.log + return [*grandparent_logs, *parent_logs, *logger.messages] + + def compose_logs(self, logger, job_extractor, report_id): + report_url = reverse("plugins:validity:compliancereport", kwargs={"pk": report_id}) + logger.success(f"Job succeeded. 
See [Compliance Report]({report_url}) for detailed statistics") + return self.collect_logs(logger, job_extractor) + + def terminate_succeeded_job(self, job: Job, test_stats: TestResultRatio, logs: list[Message]): + job.data = {"log": [log.serialized for log in logs], "output": {"statistics": test_stats.serialized}} + job.terminate() + + @di.inject + def schedule_next_job( + self, params: FullRunTestsParams, job: Job, launcher: Annotated[Launcher, "runtests_launcher"] + ) -> None: + if params.schedule_interval: + params.schedule_at = job.started + datetime.timedelta(minutes=params.schedule_interval) + launcher(params) + + def abort_if_apply_errors(self, job_extractor: JobExtractor) -> None: + error_logs = list( + chain.from_iterable( + extractor.job.result.log for extractor in job_extractor.parents if extractor.job.result.errored + ) + ) + if error_logs: + raise AbortScript("ApplyWorkerError", status=JobStatusChoices.STATUS_ERRORED, logs=error_logs) + + def __call__(self, params: FullRunTestsParams) -> Any: + netbox_job = params.get_job() + with self.terminate_job_on_error(netbox_job): + job_extractor = self.job_extractor_factory() + self.abort_if_apply_errors(job_extractor) + self.fire_report_webhook(params.report_id, params.request) + self.schedule_next_job(params, netbox_job) + logs = self.compose_logs(self.log_factory(), job_extractor, params.report_id) + test_stats = self.count_test_stats(job_extractor) + self.terminate_job( + netbox_job, JobStatusChoices.STATUS_COMPLETED, logs=logs, output={"statistics": test_stats.serialized} + ) diff --git a/validity/scripts/runtests/split.py b/validity/scripts/runtests/split.py new file mode 100644 index 0000000..287233f --- /dev/null +++ b/validity/scripts/runtests/split.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass, field +from itertools import chain, cycle, groupby, repeat +from typing import Callable, Iterable + +from dimi import Singleton +from django.db.models import Q, QuerySet + +from validity import di +from validity.models import ComplianceSelector, VDataSource, VDevice +from validity.utils.misc import batched, datasource_sync +from ..data_models import FullRunTestsParams, SplitResult +from ..exceptions import AbortScript +from ..logger import Logger +from .base import TerminateMixin + + +@di.dependency(scope=Singleton) +@dataclass(repr=False) +class SplitWorker(TerminateMixin): + log_factory: Callable[[], Logger] = Logger + datasource_sync_fn: Callable[[Iterable[VDataSource], Q], None] = datasource_sync + device_batch_size: int = 2000 + datasource_queryset: QuerySet[VDataSource] = field(default_factory=VDataSource.objects.all) + device_queryset: QuerySet[VDevice] = field(default_factory=VDevice.objects.all) + + def datasources_to_sync(self, overriding_datasource: int | None, device_filter: Q) -> Iterable[VDataSource]: + if overriding_datasource: + return [self.datasource_queryset.get(pk=overriding_datasource)] + datasource_ids = ( + self.device_queryset.filter(device_filter) + .annotate_datasource_id() + .values_list("data_source_id", flat=True) + .distinct() + ) + return self.datasource_queryset.filter(pk__in=datasource_ids) + + def sync_datasources(self, overriding_datasource: int | None, device_filter: Q): + datasources = self.datasources_to_sync(overriding_datasource, device_filter) + self.datasource_sync_fn(datasources, device_filter) + + def _work_slices( + self, selector_qs: QuerySet[ComplianceSelector], specific_devices: list[int], devices_per_worker: int + ): + def get_device_ids(selector): + qs = 
selector.devices.filter(pk__in=specific_devices) if specific_devices else selector.devices + return qs.order_by("pk").values_list("pk", flat=True).iterator(chunk_size=self.device_batch_size) + + selector_device = chain.from_iterable( + zip(repeat(selector.pk), get_device_ids(selector)) for selector in selector_qs + ) + for batch in batched(selector_device, devices_per_worker, tuple): + yield { + selector: device_ids + for selector, grouped_pairs in groupby(batch, key=lambda pair: pair[0]) + if (device_ids := [dev_id for _, dev_id in grouped_pairs]) + } + + def _eliminate_leftover(self, slices): + leftover = slices.pop() + for slice in cycle(slices): + if not leftover: + break + selector, devices = leftover.popitem() + slice.setdefault(selector, []) + slice[selector].extend(devices) + + def distribute_work( + self, params: FullRunTestsParams, logger: Logger, device_filter: Q + ) -> list[dict[int, list[int]]]: + """ + Split all the devices under test into N slices where N is the number of workers + Returns list of {selector_id: [device_id_1, device_id_2, ...]} + """ + device_count = self.device_queryset.filter(device_filter).count() + if not (devices_per_worker := device_count // params.workers_num): + raise AbortScript( + f"The number of workers ({params.workers_num}) " + f"cannot be larger than the number of devices ({device_count})" + ) + logger.info(f"Running the tests for *{device_count} devices*") + if params.workers_num > 1: + logger.info( + f"Distributing the work among {params.workers_num} workers. " + f"Each worker handles {devices_per_worker} device(s) in average" + ) + + slices = [*self._work_slices(params.selector_qs, params.devices, devices_per_worker)] + + # distribute the leftover among other slices + if len(slices) > params.workers_num: + self._eliminate_leftover(slices) + return slices + + def __call__(self, params: FullRunTestsParams) -> SplitResult: + job = params.get_job() + with self.terminate_job_on_error(job): + job.start() + job.object_type.model_class().objects.delete_old() + logger = self.log_factory() + device_filter = params.get_device_filter() + if params.sync_datasources: + self.sync_datasources(params.overriding_datasource, device_filter) + slices = self.distribute_work(params, logger, device_filter) + return SplitResult(log=logger.messages, slices=slices) diff --git a/validity/scripts/script_data.py b/validity/scripts/script_data.py deleted file mode 100644 index 26371c0..0000000 --- a/validity/scripts/script_data.py +++ /dev/null @@ -1,128 +0,0 @@ -import operator -from functools import cached_property, reduce -from typing import Generic, TypeVar, get_args - -from django.db.models import Model, Q, QuerySet -from django.utils.functional import classproperty -from extras.models import Tag - -from validity import models - - -class DBObject(int): - def __new__(cls, value, model): - return super().__new__(cls, value) - - def __init__(self, value, model): - self.model = model - super().__init__() - - @cached_property - def obj(self): - return self.model.objects.filter(pk=self).first() - - -class QuerySetObject(list): - def __init__(self, iterable, model=None): - self.model = model - super().__init__(iterable) - - -class AllQuerySetObject(QuerySetObject): - """ - Defaults to "all" if empty - """ - - @property - def queryset(self): - if not self: - return self.model.objects.all() - return self.model.objects.filter(pk__in=iter(self)) - - -class EmptyQuerySetObject(QuerySetObject): - """ - Defaults to "none" if empty - """ - - @property - def queryset(self): - if not 
self: - return self.model.objects.none() - return self.model.objects.filter(pk__in=iter(self)) - - -class DBField: - def __init__(self, model, object_cls, default=None) -> None: - self.model = model - self.object_cls = object_cls - self.attr_name = None - if default is not None and not isinstance(default, object_cls): - default = object_cls(default, model) - self.default = default - - def __set_name__(self, parent_cls, attr_name): - self.attr_name = attr_name - - def __get__(self, instance, type_): - return instance.__dict__.get(self.attr_name, self.default) - - def __set__(self, instance, value): - if value is not None: - value = self.object_cls(value, self.model) - instance.__dict__[self.attr_name] = value - - -class ScriptData: - def from_queryset(self, queryset: QuerySet) -> list[int]: - """ - Extract primary keys from queryset - """ - return list(queryset.values_list("pk", flat=True)) - - def __init__(self, data) -> None: - for k, v in data.items(): - if isinstance(v, QuerySet): - v = self.from_queryset(v) - elif isinstance(v, Model): - v = v.pk - setattr(self, k, v) - - -_ScriptData = TypeVar("_ScriptData", bound=ScriptData) - - -class ScriptDataMixin(Generic[_ScriptData]): - """ - Mixin for Script. Allows to define script data cls in class definition and later use it. - Example: - self.script_data = self.script_data_cls(data) - """ - - script_data: _ScriptData - - @classproperty - def script_data_cls(cls) -> type[_ScriptData]: - for base_classes in cls.__orig_bases__: - if (args := get_args(base_classes)) and issubclass(args[0], ScriptData): - return args[0] - raise AttributeError(f"No ScriptData definition found for {cls.__name__}") - - -class RunTestsScriptData(ScriptData): - sync_datasources = False - make_report = True - selectors = DBField(models.ComplianceSelector, AllQuerySetObject, default=[]) - devices = DBField(models.VDevice, AllQuerySetObject, default=[]) - test_tags = DBField(Tag, EmptyQuerySetObject, default=[]) - explanation_verbosity = 2 - override_datasource = DBField(models.VDataSource, DBObject, default=None) - - @cached_property - def device_filter(self) -> Q: - filtr = Q() - if self.selectors: - filtr &= reduce(operator.or_, (qs.filter for qs in self.selectors.queryset)) - if self.devices: - filtr &= reduce(operator.or_, (Q(pk=pk) for pk in self.devices)) - return filtr diff --git a/validity/scripts/variables.py b/validity/scripts/variables.py deleted file mode 100644 index 5bd26cf..0000000 --- a/validity/scripts/variables.py +++ /dev/null @@ -1,13 +0,0 @@ -from extras.scripts import ChoiceVar - -from validity.forms.fields import IntegerChoiceField - - -class NoNullChoiceVar(ChoiceVar): - def __init__(self, choices, *args, **kwargs): - super().__init__(choices, *args, **kwargs) - self.field_attrs["choices"] = choices - - -class VerbosityVar(NoNullChoiceVar): - form_field = IntegerChoiceField diff --git a/validity/settings.py b/validity/settings.py new file mode 100644 index 0000000..e3981d7 --- /dev/null +++ b/validity/settings.py @@ -0,0 +1,30 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from validity import di + + +class ScriptTimeouts(BaseModel): + """ + Timeout syntax complies with rq timeout format + """ + + runtests_split: int | str = "15m" + runtests_apply: int | str = "30m" + runtests_combine: int | str = "15m" + + +class ValiditySettings(BaseModel): + store_reports: int = Field(default=5, gt=0, lt=1001) + result_batch_size: int = Field(default=500, ge=1) + polling_threads: int = Field(default=500, ge=1) + runtests_queue: 
str = "default" + script_timeouts: ScriptTimeouts = ScriptTimeouts() + + +class ValiditySettingsMixin: + @property + @di.inject + def v_settings(self, _settings: Annotated[ValiditySettings, "validity_settings"]) -> ValiditySettings: + return _settings diff --git a/validity/signals.py b/validity/signals.py new file mode 100644 index 0000000..d3b38a2 --- /dev/null +++ b/validity/signals.py @@ -0,0 +1,9 @@ +from django.db.models.signals import pre_delete +from django.dispatch import receiver + +from validity.models import ComplianceReport + + +@receiver(pre_delete, sender=ComplianceReport) +def delete_bound_jobs(sender, instance, **kwargs): + instance.jobs.all().delete() diff --git a/validity/tables.py b/validity/tables.py index 0162f8b..8300e03 100644 --- a/validity/tables.py +++ b/validity/tables.py @@ -1,3 +1,4 @@ +import datetime import itertools from functools import partial @@ -9,13 +10,12 @@ from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ from django_tables2 import Column, RequestConfig, Table, TemplateColumn -from netbox.tables import BooleanColumn as BooleanColumn -from netbox.tables import ChoiceFieldColumn, ManyToManyColumn, NetBoxTable -from netbox.tables.columns import ActionsColumn, LinkedCountColumn +from netbox.tables import BaseTable, BooleanColumn, ChoiceFieldColumn, ManyToManyColumn, NetBoxTable +from netbox.tables.columns import ActionsColumn, LinkedCountColumn, MarkdownColumn from utilities.paginator import EnhancedPaginator from validity import models -from validity.templatetags.validity import colorful_percentage +from validity.templatetags.validity import colorful_percentage, isodatetime class SelectorTable(NetBoxTable): @@ -173,6 +173,7 @@ def get_table_attr(obj, attr_name): class ComplianceReportTable(NetBoxTable): id = Column(linkify=True) + job_status = ChoiceFieldColumn(verbose_name=_("Job Status"), accessor="jobs__first__status") groupby_value = Column( verbose_name=_("GroupBy Value"), linkify=lambda record: reverse(record["viewname"], kwargs={"pk": record["groupby_pk"]}), @@ -190,6 +191,7 @@ class Meta(NetBoxTable.Meta): model = models.ComplianceReport fields = ( "id", + "job_status", "groupby_value", "device_count", "test_count", @@ -201,6 +203,11 @@ class Meta(NetBoxTable.Meta): ) default_columns = fields + def render_job_status(self, column, bound_column, record, value): + record = record.jobs.first() + bound_column.name = "status" + return column.render(record, bound_column, value) + class DeviceReportM2MColumn(ManyToManyColumn): def __init__(self, *args, badge_color: str = "", **kwargs): @@ -279,3 +286,19 @@ def get_page_lengths(self): paginate_by = self.get_paginate_by(request, max_paginate_by) paginate = {"paginator_class": paginator_class, "per_page": paginate_by} RequestConfig(request, paginate).configure(self) + + +class ScriptResultTable(BaseTable): + index = Column(verbose_name=_("Line"), empty_values=()) + time = Column(verbose_name=_("Time")) + status = TemplateColumn( + template_code="""{% load log_levels %}{% log_level record.status %}""", verbose_name=_("Level") + ) + message = MarkdownColumn(verbose_name=_("Message")) + + class Meta(BaseTable.Meta): + empty_text = _("No results found") + fields = ("index", "time", "status", "message") + + def render_time(self, value): + return isodatetime(datetime.datetime.fromisoformat(value)) diff --git a/validity/template_content.py b/validity/template_content.py index 3f6502a..fb4bcd2 100644 --- a/validity/template_content.py +++ 
b/validity/template_content.py @@ -1,4 +1,5 @@ import yaml +from django.urls import reverse from django.utils.translation import gettext_lazy as _ from tenancy.models import Tenant @@ -47,4 +48,13 @@ def right_page(self): ) -template_extensions = [DataSourceTenantExtension, PollingInfoExtension] +class ComplianceTestExtension(PluginTemplateExtension): + model = "validity.compliancetest" + + def list_buttons(self): + run_tests_url = reverse("plugins:validity:compliancetest_run") + icon = '' + return f'{icon} Run Tests' + + +template_extensions = [DataSourceTenantExtension, PollingInfoExtension, ComplianceTestExtension] diff --git a/validity/templates/validity/compliancereport.html b/validity/templates/validity/compliancereport.html index 11c6c0d..ff08c1f 100644 --- a/validity/templates/validity/compliancereport.html +++ b/validity/templates/validity/compliancereport.html @@ -26,6 +26,18 @@
Compliance Report
Created {{ object.created | date:"Y-m-d G:i:s" }} + {% with job=object.jobs.first %} + + Job + {% if job %}{{ job | linkify }} | {{ job | colored_choice:"status" }}{% else %}—{% endif %} + + {% if job.error %} + + Job Error + {{ job.error }} + + {% endif %} + {% endwith %} Devices involved {{ object.device_count }} diff --git a/validity/templates/validity/compliancetestresult.html b/validity/templates/validity/compliancetestresult.html index c5bfbe5..6ad64e0 100644 --- a/validity/templates/validity/compliancetestresult.html +++ b/validity/templates/validity/compliancetestresult.html @@ -51,6 +51,10 @@
Compliance Test Results
Created {{ object.created | date:"Y-m-d G:i:s" }} + + Report + {{ object.report | linkify | placeholder }} + diff --git a/validity/templates/validity/inc/fieldset.html b/validity/templates/validity/inc/fieldset.html new file mode 100644 index 0000000..99d7053 --- /dev/null +++ b/validity/templates/validity/inc/fieldset.html @@ -0,0 +1,11 @@ +{% load form_helpers %} +
+
+
{{ group }}
+
+ {% for name in fields %} + {% with field=form|getfield:name %} + {% render_field field %} + {% endwith %} + {% endfor %} +
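The fieldset.html include above is driven by the render_fieldset backport added later in this patch, which accepts both the pre-4.0 and the 4.0+ fieldset shapes. A small illustration of the two shapes a form's fieldsets entry may take; the group title and field names are placeholders, not the actual run-form fields:

from dataclasses import dataclass

# NetBox < 4.0: a fieldset is a plain (name, field-names) pair
legacy_fieldset = ("Tests", ("selectors", "test_tags", "explanation_verbosity"))


# NetBox >= 4.0: an object exposing .name and .items
# (utilities.forms.rendering.FieldSet in NetBox itself; a stand-in is used here)
@dataclass
class FieldSetStandIn:
    name: str
    items: tuple

modern_fieldset = FieldSetStandIn(name="Tests", items=("selectors", "test_tags", "explanation_verbosity"))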
diff --git a/validity/templates/validity/scripts/result.html b/validity/templates/validity/scripts/result.html new file mode 100644 index 0000000..cb83547 --- /dev/null +++ b/validity/templates/validity/scripts/result.html @@ -0,0 +1,43 @@ +{% extends 'extras/script_result.html' %} +{% load validity %} +{% block title %}Script Execution Result{% endblock title %} + +{% block header %} + +{% endblock header %} + +{% block content-wrapper %} +{% block content %} +
+
+
+ {% block htmx-template %} + {% include 'validity/scripts/result_htmx.html' %} + {% endblock htmx-template %} +
+
+
+ +{% endblock content %} +{% endblock content-wrapper %} diff --git a/validity/templates/validity/scripts/result_htmx.html b/validity/templates/validity/scripts/result_htmx.html new file mode 100644 index 0000000..612defd --- /dev/null +++ b/validity/templates/validity/scripts/result_htmx.html @@ -0,0 +1,36 @@ +{% load helpers %} +{% load log_levels %} +{% load i18n %} +{% load validity %} + +
+

+ {% if job.started %} + Started: {{ job.started|isodatetime }} + {% elif job.scheduled %} + Scheduled: {{ job.scheduled|isodatetime }} + {% else %} + Created: {{ job.created|isodatetime }} + {% endif %} + {% if job.completed %} + Duration: {{ job.duration }} + {% endif %} + {% badge job.get_status_display job.get_status_color %} +

+ {% if job.completed %} +
+
+
Log
+ {% include 'inc/table.html' %} +
+
+ +
+
Output
+
{{ job.data | get_key:"output" | yaml }}
+
+ + {% elif job.started %} + {% include 'extras/inc/result_pending.html' %} + {% endif %} +
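Both result templates above render the Job.data payload that the workers build via terminate_job: a "log" list of serialized Message dicts and an "output" mapping shown through the yaml filter. A sketch of the expected shape, with illustrative values taken from the test expectations further down:

# Job.data as rendered by result_htmx.html: "log" -> script result table rows, "output" -> YAML block
job_data = {
    "log": [
        {"time": "2020-01-01T00:00:00", "status": "info", "message": "Running the tests for *2 devices*"},
        {
            "time": "2020-01-01T00:05:00",
            "status": "success",
            "message": "Job succeeded. See [Compliance Report](/plugins/validity/reports/1/) for detailed statistics",
        },
    ],
    "output": {"statistics": {"passed": 3, "total": 7}},
}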
diff --git a/validity/templates/validity/scripts/run.html b/validity/templates/validity/scripts/run.html new file mode 100644 index 0000000..cf710f6 --- /dev/null +++ b/validity/templates/validity/scripts/run.html @@ -0,0 +1,44 @@ +{% extends 'generic/object.html' %} +{% load helpers %} +{% load log_levels %} +{% load validity %} + +{% block title %}Run Compliance Tests{% endblock %} +{% block breadcrumbs %} + +{% endblock breadcrumbs %} +{% block subtitle %} +
+ Execute all or selected Compliance Tests against a specified scope of devices and generate a Compliance Report
+{% endblock subtitle %} +{% block controls %}{% endblock %} +{% block object_identifier %}{% endblock object_identifier %} +{% block tabs %}{% endblock tabs %} + + +{% block content %} +
+
+ {% if not perms.extras.run_compliancetest %} +
+ You do not have permission to run Compliance Tests. +
+ {% endif %} +
+ {% csrf_token %} +
+ {% for fieldset in form.fieldsets %} + {% render_fieldset form fieldset %} + {% endfor %} +
+
+ Cancel + +
+
+
+
+{% endblock content %} diff --git a/validity/templatetags/validity.py b/validity/templatetags/validity.py index a1113fa..b5c2211 100644 --- a/validity/templatetags/validity.py +++ b/validity/templatetags/validity.py @@ -1,10 +1,12 @@ from typing import Any from django import template +from django.contrib.humanize.templatetags.humanize import naturaltime from django.db.models import Model from django.http.request import HttpRequest from django.utils.html import format_html from django.utils.safestring import mark_safe +from django.utils.timezone import localtime from django.utils.translation import gettext_lazy as _ from utilities.templatetags.builtins.filters import linkify, placeholder @@ -91,3 +93,18 @@ def bg(): @register.filter def nb_version(): return config.netbox_version + + +@register.inclusion_tag("validity/inc/fieldset.html") +def render_fieldset(form, fieldset): + # backport of the native render_fieldset appeared in 4.0 + name, items = fieldset if config.netbox_version < "4.0.0" else (fieldset.name, fieldset.items) + return {"group": name, "fields": items, "form": form} + + +@register.filter() +def isodatetime(value, spec="seconds"): + # backport of the native isodatetime in 4.0 + value = localtime(value) if value.tzinfo else value + text = f"{value.date().isoformat()} {value.time().isoformat(spec)}" + return mark_safe(f'{text}') diff --git a/validity/tests/conftest.py b/validity/tests/conftest.py index 490772e..cb22dd9 100644 --- a/validity/tests/conftest.py +++ b/validity/tests/conftest.py @@ -4,13 +4,13 @@ from core.models import DataSource from dcim.models import Device, DeviceType, Manufacturer from django.contrib.contenttypes.models import ContentType -from extras.models import CustomField, ScriptModule +from extras.models import CustomField from graphene_django.utils.testing import graphql_query from tenancy.models import Tenant import validity import validity.scripts -from validity.models import Poller, Serializer, VDataSource +from validity.models import Poller, Serializer from validity.utils.orm import CustomFieldBuilder @@ -77,28 +77,14 @@ def create_custom_fields(db): ) -@pytest.fixture -def setup_runtests_script(): - ds_path = Path(validity.scripts.__file__).parent.resolve() / "install" - datasource = VDataSource.objects.create( - name="validity_scripts", - type="local", - source_url=f"file://{ds_path}", - ) - datasource.sync() - module = ScriptModule( - data_source=datasource, - data_file=datasource.datafiles.get(path="validity_scripts.py"), - file_root="scripts", - auto_sync_enabled=True, - ) - module.clean() - module.save() - - @pytest.fixture def gql_query(admin_client): def func(*args, **kwargs): return graphql_query(*args, **kwargs, client=admin_client, graphql_url="/graphql/") return func + + +@pytest.fixture +def di(): + return validity.di diff --git a/validity/tests/factories.py b/validity/tests/factories.py index def7712..c0065a2 100644 --- a/validity/tests/factories.py +++ b/validity/tests/factories.py @@ -1,8 +1,12 @@ import datetime +import uuid import django import factory +from core.models import Job from dcim.models import DeviceRole, DeviceType, Location, Manufacturer, Platform, Site +from django.contrib.auth import get_user_model +from django.contrib.contenttypes.models import ContentType from extras.models import Tag from factory.django import DjangoModelFactory from tenancy.models import Tenant @@ -232,6 +236,31 @@ class Meta: model = models.Poller +class UserFactory(DjangoModelFactory): + email = "su@admin.com" + username = "su" + password = 
factory.PostGenerationMethodCall("set_password", "admin") + + is_superuser = True + is_staff = True + is_active = True + + class Meta: + model = get_user_model() + + +class RunTestsJobFactory(DjangoModelFactory): + name = "RunTests" + object = factory.SubFactory(ReportFactory) + object_id = factory.SelfAttribute("object.pk") + object_type = factory.LazyAttribute(lambda obj: ContentType.objects.get_for_model(type(obj.object))) + user = factory.SubFactory(UserFactory) + job_id = factory.LazyFunction(uuid.uuid4) + + class Meta: + model = Job + + _NOT_DEFINED = object() diff --git a/validity/tests/test_api.py b/validity/tests/test_api.py index 20e9a27..f0e4e2b 100644 --- a/validity/tests/test_api.py +++ b/validity/tests/test_api.py @@ -1,4 +1,5 @@ from http import HTTPStatus +from unittest.mock import Mock import pytest from base import ApiGetTest, ApiPostGetTest @@ -14,6 +15,7 @@ ManufacturerFactory, PlatformFactory, ReportFactory, + RunTestsJobFactory, SelectorFactory, SerializerDBFactory, SiteFactory, @@ -22,6 +24,7 @@ state_item, ) +from validity import dependencies from validity.models import VDevice @@ -202,3 +205,20 @@ def test_report_devices(admin_client): for device in results: assert len(device["results"]) == 1 assert device["results_count"] == 1 + + +@pytest.mark.parametrize( + "post_body, status_code", + [ + ({}, HTTPStatus.OK), + ({"devices": [1, 2]}, HTTPStatus.BAD_REQUEST), # devices do not exist + ({"schedule_interval": 1, "sync_datasources": True, "explanation_verbosity": 2}, HTTPStatus.OK), + ], +) +def test_run_tests(admin_client, di, post_body, status_code): + launcher = Mock(return_value=RunTestsJobFactory()) + with di.override({dependencies.runtests_launcher: lambda: launcher}): + resp = admin_client.post("/api/plugins/validity/tests/run/", post_body, content_type="application/json") + assert resp.status_code == status_code + if resp.status_code == HTTPStatus.OK: + launcher.assert_called_once() diff --git a/validity/tests/test_managers.py b/validity/tests/test_managers.py index c6bb022..09fb63e 100644 --- a/validity/tests/test_managers.py +++ b/validity/tests/test_managers.py @@ -1,47 +1,22 @@ -from itertools import product from unittest.mock import Mock import pytest -from factories import CommandFactory, CompTestDBFactory, DataSourceFactory, DeviceFactory +from factories import CommandFactory, DataSourceFactory, DeviceFactory -from validity.models import Command, ComplianceReport, ComplianceTestResult, VDevice - - -@pytest.mark.parametrize("store_results", [3, 2, 1]) -@pytest.mark.django_db -def test_delete_old_results(store_results): - report = ComplianceReport.objects.create() - device1 = DeviceFactory() - device2 = DeviceFactory() - test1 = CompTestDBFactory() - test2 = CompTestDBFactory() - report_results = [ - ComplianceTestResult.objects.create(passed=True, device=device1, test=test1, explanation=[], report=report).pk, - ComplianceTestResult.objects.create(passed=True, device=device2, test=test2, explanation=[], report=report).pk, - ] - result_per_devtest = 5 - for test, device in product([test1, test2], [device1, device2]): - for i in range(result_per_devtest): - ComplianceTestResult.objects.create(passed=True, device=device, test=test, explanation=i) - - assert ComplianceTestResult.objects.count() == 4 * result_per_devtest + len(report_results) - ComplianceTestResult.objects.delete_old(_settings=Mock(store_last_results=store_results)) - assert ComplianceTestResult.objects.count() == 4 * store_results + len(report_results) - assert 
ComplianceTestResult.objects.filter(pk__in=report_results).count() == len(report_results) - for test, device in product([test1, test2], [device1, device2]): - assert [ - *ComplianceTestResult.objects.filter(report=None, test=test, device=device) - .order_by("created") - .values_list("explanation", flat=True) - ] == [*range(result_per_devtest - store_results, result_per_devtest)] +from validity.dependencies import validity_settings +from validity.managers import ComplianceSelectorQS +from validity.models import Command, ComplianceReport, ComplianceSelector, VDevice +from validity.settings import ValiditySettings @pytest.mark.parametrize("store_reports", [3, 2, 1]) @pytest.mark.django_db -def test_delete_old_reports(store_reports): +def test_delete_old_reports(store_reports, di): reports = [ComplianceReport.objects.create() for _ in range(10)] - ComplianceReport.objects.delete_old(_settings=Mock(store_reports=store_reports)) - assert list(ComplianceReport.objects.order_by("created")) == reports[-store_reports:] + settings = ValiditySettings(store_reports=store_reports) + with di.override({validity_settings: lambda: settings}): + ComplianceReport.objects.delete_old() + assert list(ComplianceReport.objects.order_by("created")) == reports[-store_reports:] @pytest.mark.django_db @@ -67,3 +42,18 @@ def test_set_attribute(): assert device.attr1 == "val1" and device.attr2 == "val2" for device in device_qs.filter(name__startswith="d"): assert device.attr1 == "val1" and device.attr2 == "val2" + + +def test_prefetch_filters(monkeypatch): + monkeypatch.setattr(ComplianceSelectorQS, "prefetch_related", Mock()) + ComplianceSelector.objects.all().prefetch_filters() + ComplianceSelectorQS.prefetch_related.assert_called_once() + assert set(ComplianceSelectorQS.prefetch_related.call_args.args) == { + "tag_filter", + "manufacturer_filter", + "type_filter", + "platform_filter", + "location_filter", + "site_filter", + "tenant_filter", + } diff --git a/validity/tests/test_scripts/conftest.py b/validity/tests/test_scripts/conftest.py index 055026b..a0035e6 100644 --- a/validity/tests/test_scripts/conftest.py +++ b/validity/tests/test_scripts/conftest.py @@ -1,10 +1,17 @@ -from unittest.mock import Mock +import uuid import pytest -from extras.scripts import Script +from factories import RunTestsJobFactory + +from validity.scripts.data_models import RequestInfo, RunTestsParams + + +@pytest.fixture +def runtests_params(): + return RunTestsParams(request=RequestInfo(id=uuid.uuid4(), user_id=1)) @pytest.fixture -def mock_script_logging(monkeypatch): - for log_func in ["log_debug", "log_info", "log_failure", "log_success", "log_warning"]: - monkeypatch.setattr(Script, log_func, Mock()) +def full_runtests_params(runtests_params): + job = RunTestsJobFactory() + return runtests_params.with_job_info(job) diff --git a/validity/tests/test_scripts/runtests/test_apply.py b/validity/tests/test_scripts/runtests/test_apply.py new file mode 100644 index 0000000..6e537cf --- /dev/null +++ b/validity/tests/test_scripts/runtests/test_apply.py @@ -0,0 +1,175 @@ +from dataclasses import dataclass, field +from unittest.mock import Mock + +import pytest +from factories import CompTestDBFactory, CompTestResultFactory, DeviceFactory, NameSetDBFactory, SelectorFactory + +from validity.compliance.exceptions import EvalError +from validity.models import ComplianceTest +from validity.scripts.data_models import ExecutionResult +from validity.scripts.data_models import TestResultRatio as ResultRatio +from validity.scripts.runtests.apply import 
ApplyWorker, DeviceTestIterator +from validity.scripts.runtests.apply import TestExecutor as TExecutor + + +NS_1 = """ +__all__ = ["func1", "var", "func2"] + +def func1(var): pass + +var = 1234 + +def func2(var): pass +""" + +NS_2 = """ +from collections import Counter +import itertools + +__all__ = ["func3", "non_existing_func", "Counter", "itertools"] + +def func3(): pass +""" + +NS_3 = "some wrong syntax" + + +@pytest.mark.parametrize( + "nameset_texts, extracted_fn_names, warning_calls", + [ + pytest.param(["", ""], set(), 0, id="empty"), + pytest.param([NS_1], {"func1", "func2"}, 0, id="NS_1"), + pytest.param([NS_2], {"func3", "Counter"}, 0, id="NS_2"), + pytest.param([NS_1, NS_2], {"func1", "func2", "func3", "Counter"}, 0, id="NS_1, NS_2"), + pytest.param([NS_3], set(), 1, id="NS_3"), + pytest.param([NS_3, NS_1, NS_3], {"func1", "func2"}, 2, id="NS3, NS_1, NS_3"), + ], +) +@pytest.mark.django_db +def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls): + script = TExecutor(1, 2, 10) + namesets = [NameSetDBFactory(definitions=d) for d in nameset_texts] + functions = script.nameset_functions(namesets) + assert extracted_fn_names == functions.keys() + assert len(script.log.messages) == warning_calls + for fn_name, fn in functions.items(): + assert fn_name == fn.__name__ + assert callable(fn) + + +FUNC = """ +__all__ = ['func'] +{} +""" + + +@pytest.mark.parametrize( + "definitions", + [ + pytest.param(FUNC.format("def func(): return max(1, 10)"), id="max"), + pytest.param(FUNC.format('def func(): return jq.first(".data", {"data": [1,2,3]})'), id="jq"), + ], +) +@pytest.mark.django_db +def test_builtins_are_available_in_nameset(definitions): + script = TExecutor(10, 20, 30) + namesets = [NameSetDBFactory(definitions=definitions)] + functions = script.nameset_functions(namesets) + functions["func"]() + + +@pytest.mark.django_db +def test_run_tests_for_device(): + device = DeviceFactory() + device.__dict__["state"] = {"a": "b"} + device.__dict__["dynamic_pair"] = DeviceFactory(name="dynpair") + namesets = Mock(**{"all.return_value": []}) + tests = [Mock(namesets=namesets, spec=ComplianceTest, _state=Mock(db="default")) for _ in range(3)] + tests[0].run.return_value = True, [] + tests[1].run.return_value = False, [("some", "explanation")] + tests[2].run.side_effect = EvalError("some test error") + executor = TExecutor(10, explanation_verbosity=2, report_id=30) + results = [ + { + "passed": r.passed, + "explanation": r.explanation, + "report_id": r.report_id, + "dynamic_pair": r.dynamic_pair.name, + } + for r in executor.run_tests_for_device(tests, device) + ] + assert results == [ + {"passed": True, "explanation": [], "report_id": 30, "dynamic_pair": "dynpair"}, + {"passed": False, "explanation": [("some", "explanation")], "report_id": 30, "dynamic_pair": "dynpair"}, + {"passed": False, "explanation": [("some test error", None)], "report_id": 30, "dynamic_pair": "dynpair"}, + ] + for test in tests: + test.run.assert_called_once_with(device, {}, verbosity=2) + assert executor.results_passed == 1 + assert executor.results_count == 3 + + +@pytest.mark.django_db +def test_devicetest_iterator(): + devices = [DeviceFactory() for _ in range(3)] + selectors = [SelectorFactory(), SelectorFactory()] + tests = [CompTestDBFactory() for _ in range(5)] + selectors[0].tests.set(tests[:2]) + selectors[1].tests.set(tests[2:]) + selector_devices = {selectors[0].pk: [devices[0].pk, devices[1].pk], selectors[1].pk: [d.pk for d in devices[2:]]} + iterator = 
DeviceTestIterator(selector_devices, [], None) + iter_values = [(list(d.order_by("pk")), list(t.order_by("pk"))) for d, t in iterator] + assert iter_values[::-1] == [(devices[:2], tests[:2]), (devices[2:], tests[2:])] + + +@pytest.fixture +def apply_worker(): + test_results = CompTestResultFactory.build_batch(size=3) + executor = Mock(results_passed=5, results_count=10, return_value=test_results) + executor.log.messages = ["log1", "log2"] + device_test_gen = Mock(return_value=[(["device1"], ["test1"]), (["device2"], ["test2"])]) + job_extractor_factory = Mock() + job_extractor_factory.return_value.parent.job.result.slices = [None, {1: [1, 2, 3]}] + return ApplyWorker( + testresult_queryset=Mock(), + test_executor_cls=Mock(return_value=executor), + result_batch_size=100, + job_extractor_factory=job_extractor_factory, + device_test_gen=device_test_gen, + ) + + +@pytest.mark.django_db +def test_applyworker_success(full_runtests_params, apply_worker): + full_runtests_params.overriding_datasource = 10 + full_runtests_params.test_tags = [555] + device_test_gen = apply_worker.device_test_gen + executor = apply_worker.test_executor_cls.return_value + test_results = executor.return_value + result = apply_worker(params=full_runtests_params, worker_id=1) + assert result == ExecutionResult(test_stat=ResultRatio(passed=5, total=10), log=["log1", "log2"]) + device_test_gen.assert_called_once_with( + {1: [1, 2, 3]}, full_runtests_params.test_tags, full_runtests_params.overriding_datasource + ) + apply_worker.testresult_queryset.bulk_create.assert_called_once() + assert list(apply_worker.testresult_queryset.bulk_create.call_args.args[0]) == test_results * len( + device_test_gen.return_value + ) + assert executor.call_count == len(device_test_gen.return_value) + + +@dataclass +class MockLogger: + script_id: str + messages: list = field(default_factory=list, init=False) + + def log_exception(self, m): + self.messages.append(str(m)) + + +@pytest.mark.django_db +def test_applyworker_exception(full_runtests_params, apply_worker): + apply_worker.test_executor_cls = Mock(side_effect=ValueError("some error")) + apply_worker.logger_factory = MockLogger + result = apply_worker(params=full_runtests_params, worker_id=1) + assert result == ExecutionResult(test_stat=ResultRatio(passed=0, total=0), log=["some error"], errored=True) diff --git a/validity/tests/test_scripts/runtests/test_combine.py b/validity/tests/test_scripts/runtests/test_combine.py new file mode 100644 index 0000000..a0e15da --- /dev/null +++ b/validity/tests/test_scripts/runtests/test_combine.py @@ -0,0 +1,96 @@ +import datetime +from dataclasses import replace +from unittest.mock import Mock + +import pytest +from django.utils import timezone + +from validity.scripts.data_models import ExecutionResult, Message +from validity.scripts.data_models import TestResultRatio as ResultRatio +from validity.scripts.exceptions import AbortScript +from validity.scripts.runtests.combine import CombineWorker + + +@pytest.fixture +def worker(): + return CombineWorker( + testresult_queryset=Mock(), job_extractor_factory=Mock(), enqueue_func=Mock(), report_queryset=Mock() + ) + + +@pytest.fixture +def messages(): + time = datetime.datetime(2000, 1, 1) + return [Message(status="info", message=f"m-{i}", time=time) for i in range(5)] + + +@pytest.fixture +def job_extractor(messages): + extractor = Mock() + extractor.parents = [Mock(), Mock()] + extractor.parent.parent.job.result.log = messages[:1] + extractor.parents[0].job.result = 
ExecutionResult(test_stat=ResultRatio(2, 2), log=messages[1:3]) + extractor.parents[1].job.result = ExecutionResult(test_stat=ResultRatio(1, 5), log=messages[3:]) + return extractor + + +# the test itself does not require db access, +# but according to netbox4.0 strange behaviour reverse() finally causes it +@pytest.mark.django_db +def test_compose_logs(worker, messages, job_extractor): + logger = worker.log_factory() + time = messages[0].time + logs = worker.compose_logs(logger, job_extractor, report_id=10) + assert len(logs) == 6 + assert logs[:5] == messages + last_msg = replace(logs[-1], time=time) + assert last_msg == Message( + status="success", + message="Job succeeded. See [Compliance Report](/plugins/validity/reports/10/) for detailed statistics", + time=time, + ) + + +@pytest.mark.django_db +def test_call_abort(worker, full_runtests_params, job_extractor, monkeypatch): + job_extractor.parents[1].job.result.errored = True + monkeypatch.setattr(timezone, "now", lambda: datetime.datetime(2020, 1, 1)) + worker.job_extractor_factory = lambda: job_extractor + with pytest.raises(AbortScript): + worker(full_runtests_params) + job = full_runtests_params.get_job() + assert job.status == "errored" + assert job.data == { + "log": [ + {"message": "m-3", "status": "info", "time": "2000-01-01T00:00:00"}, + {"message": "m-4", "status": "info", "time": "2000-01-01T00:00:00"}, + {"message": "ApplyWorkerError", "status": "failure", "time": "2020-01-01T00:00:00"}, + {"message": "Database changes have been reverted", "status": "info", "time": "2020-01-01T00:00:00"}, + ], + "output": None, + } + assert job.error == "AbortScript('ApplyWorkerError')" + + +@pytest.mark.django_db(transaction=True, reset_sequences=True) +def test_successful_call(worker, full_runtests_params, job_extractor, monkeypatch, messages): + monkeypatch.setattr(timezone, "now", lambda: datetime.datetime(2020, 1, 1)) + job = full_runtests_params.get_job() + worker.job_extractor_factory = lambda: job_extractor + worker.report_queryset.get.return_value = job.object + worker(full_runtests_params) + job.refresh_from_db() + assert job.status == "completed" + assert job.data == { + "log": [ + *[m.serialized for m in messages], + { + "time": "2020-01-01T00:00:00", + "status": "success", + "message": "Job succeeded. 
See [Compliance Report](/plugins/validity/reports/1/) for detailed statistics", + }, + ], + "output": {"statistics": {"total": 7, "passed": 3}}, + } + assert job.error == "" + worker.enqueue_func.assert_called_once_with(job.object, full_runtests_params.request, "create") diff --git a/validity/tests/test_scripts/runtests/test_split.py b/validity/tests/test_scripts/runtests/test_split.py new file mode 100644 index 0000000..9a75725 --- /dev/null +++ b/validity/tests/test_scripts/runtests/test_split.py @@ -0,0 +1,112 @@ +import datetime +from unittest.mock import Mock + +import factory +import pytest +from django.db.models import Q +from django.db.models.signals import post_save +from django.utils import timezone +from factories import DataSourceFactory, DeviceFactory, RunTestsJobFactory, SelectorFactory, TenantFactory + +from validity.scripts.data_models import Message, SplitResult +from validity.scripts.runtests.split import SplitWorker + + +@pytest.fixture +@factory.django.mute_signals(post_save) +def selectors(db): + s1 = SelectorFactory(name_filter="g1-.*") + s2 = SelectorFactory(name_filter="g2-.*") + return s1, s2 + + +@pytest.fixture +@factory.django.mute_signals(post_save) +def devices(device_num): + for i in range(device_num // 2): + DeviceFactory(name=f"g1-{i}") + for i in range(device_num // 2, device_num): + DeviceFactory(name=f"g2-{i}") + + +@pytest.fixture +def split_worker(): + return SplitWorker() + + +@pytest.mark.parametrize( + "worker_num, device_num, expected_result", + [ + (1, 6, [{1: [1, 2, 3], 2: [4, 5, 6]}]), + (2, 6, [{1: [1, 2, 3]}, {2: [4, 5, 6]}]), + (3, 6, [{1: [1, 2]}, {1: [3], 2: [4]}, {2: [5, 6]}]), + (4, 6, [{1: [1], 2: [6]}, {1: [2]}, {1: [3]}, {2: [4]}, {2: [5]}]), + (2, 3, [{1: [1], 2: [3]}, {2: [2]}]), + (5, 9, [{1: [1], 2: [9]}, {1: [2]}, {1: [3]}, {1: [4]}, {2: [5]}, {2: [6]}, {2: [7]}, {2: [8]}]), + ], +) +@pytest.mark.django_db(transaction=True, reset_sequences=True) +@factory.django.mute_signals(post_save) +def test_distribute_work(split_worker, selectors, worker_num, runtests_params, expected_result, devices): + runtests_params.workers_num = worker_num + runtests_params.selectors = [s.pk for s in selectors] + result = split_worker.distribute_work( + runtests_params, split_worker.log_factory(), runtests_params.get_device_filter() + ) + assert result == expected_result + + +@pytest.mark.parametrize("overriding_datasource", [None, DataSourceFactory]) +@pytest.mark.django_db +def test_sync_datasources(create_custom_fields, overriding_datasource): + if overriding_datasource: + overriding_datasource = overriding_datasource() + ds1 = DataSourceFactory() + ds2 = DataSourceFactory() + DataSourceFactory() + DeviceFactory(name="d1", tenant=TenantFactory(custom_field_data={"data_source": ds1.pk})) + DeviceFactory(name="d2", tenant=TenantFactory(custom_field_data={"data_source": ds2.pk})) + DeviceFactory() + + worker = SplitWorker(datasource_sync_fn=Mock()) + overriding_pk = overriding_datasource.pk if overriding_datasource else None + worker.sync_datasources(overriding_datasource=overriding_pk, device_filter=Q(name__in=["d1", "d2"])) + worker.datasource_sync_fn.assert_called_once() + datasources, device_filter = worker.datasource_sync_fn.call_args.args + assert device_filter == Q(name__in=["d1", "d2"]) + expected_result = [overriding_datasource] if overriding_datasource else [ds1, ds2] + assert list(datasources) == expected_result + + +@pytest.mark.parametrize("device_num", [2]) +@pytest.mark.django_db(transaction=True, reset_sequences=True) +def 
test_call(selectors, devices, runtests_params, monkeypatch): + time = timezone.datetime(year=2000, month=1, day=1) + monkeypatch.setattr(timezone, "now", lambda: time) + job = RunTestsJobFactory() + runtests_params = runtests_params.with_job_info(job) + runtests_params.workers_num = 2 + runtests_params.sync_datasources = True + runtests_params.selectors = [s.pk for s in selectors] + worker = SplitWorker(datasource_sync_fn=Mock()) + result = worker(runtests_params) + assert result == SplitResult( + log=[ + Message( + status="info", + message="Running the tests for *2 devices*", + time=datetime.datetime(2000, 1, 1, 0, 0), + script_id=None, + ), + Message( + status="info", + message="Distributing the work among 2 workers. Each worker handles 1 device(s) in average", + time=datetime.datetime(2000, 1, 1, 0, 0), + script_id=None, + ), + ], + slices=[{1: [1]}, {2: [2]}], + ) + worker.datasource_sync_fn.assert_called_once() + job.refresh_from_db() + assert job.status == "running" diff --git a/validity/tests/test_scripts/test_data_models.py b/validity/tests/test_scripts/test_data_models.py new file mode 100644 index 0000000..07d9a81 --- /dev/null +++ b/validity/tests/test_scripts/test_data_models.py @@ -0,0 +1,83 @@ +import datetime +from zoneinfo import ZoneInfo + +import pytest +from dcim.models import Device +from django.db.models import Q +from factories import CompTestDBFactory, DeviceFactory, SelectorFactory, TagFactory + +from validity.scripts import data_models + + +def test_serialized_message(): + time = datetime.datetime(year=2000, month=2, day=3, hour=14, minute=10, second=35, tzinfo=ZoneInfo("UTC")) + msg = data_models.Message(status="info", message="hello", time=time) + assert msg.serialized == {"status": "info", "message": "hello", "time": "2000-02-03T14:10:35+00:00"} + + msg2 = data_models.Message(status="info", message="hello2", time=time, script_id="My Script") + assert msg2.serialized["message"] == "My Script, hello2" + + +def test_resultratio_sum(): + r1 = data_models.TestResultRatio(passed=1, total=5) + r2 = data_models.TestResultRatio(passed=3, total=3) + r3 = data_models.TestResultRatio(passed=0, total=1) + + assert r1 + r2 == data_models.TestResultRatio(4, 8) + assert r1 + r2 + r3 == data_models.TestResultRatio(4, 9) + + +class TestRunTestsParams: + @pytest.fixture + def selectors(self, db): + return [SelectorFactory() for _ in range(10)] + + def test_selector_qs(self, runtests_params, selectors): + assert list(runtests_params.selector_qs) == selectors + + runtests_params.selectors = [selectors[0].id, selectors[2].id] + assert list(runtests_params.selector_qs) == [selectors[0], selectors[2]] + + @pytest.mark.django_db(transaction=True, reset_sequences=True) + def test_selector_qs_with_tags(self, runtests_params, selectors): + tag1 = TagFactory() + tag2 = TagFactory() + test1 = CompTestDBFactory() + test1.tags.add(tag1) + test2 = CompTestDBFactory() + test2.tags.add(tag2) + + test1.selectors.set([selectors[0], selectors[1]]) + test2.selectors.set([selectors[1], selectors[2], selectors[3]]) + runtests_params.test_tags = [test1.pk, test2.pk] + assert list(runtests_params.selector_qs) == selectors[:4] + + runtests_params.selectors = [selectors[0].pk, selectors[6].pk] + + assert list(runtests_params.selector_qs) == [selectors[0]] + + @pytest.mark.django_db + def test_get_device_filter_empty_selectors(self, runtests_params): + assert runtests_params.get_device_filter() == Q(pk__in=[]) + + @pytest.mark.django_db + def test_get_device_filter_with_selectors(self, runtests_params, 
selectors): + selectors[0].name_filter = "g1-.*" + selectors[0].save() + selectors[1].name_filter = "g2-.*" + selectors[1].save() + d1 = DeviceFactory(name="g1-dev1") + DeviceFactory(name="g1-dev2") + DeviceFactory(name="g2-dev1") + d2 = DeviceFactory(name="some_device") + runtests_params.selectors = [selectors[0].pk, selectors[1].pk] + device_filter = runtests_params.get_device_filter() + assert {*Device.objects.filter(device_filter).values_list("name", flat=True)} == { + "g1-dev1", + "g1-dev2", + "g2-dev1", + } + + runtests_params.devices = [d1.pk, d2.pk] + device_filter = runtests_params.get_device_filter() + assert {*Device.objects.filter(device_filter).values_list("name", flat=True)} == {"g1-dev1"} diff --git a/validity/tests/test_scripts/test_launcher.py b/validity/tests/test_scripts/test_launcher.py new file mode 100644 index 0000000..dddf39d --- /dev/null +++ b/validity/tests/test_scripts/test_launcher.py @@ -0,0 +1,68 @@ +import uuid +from dataclasses import asdict +from unittest.mock import Mock + +import pytest +from core.models import Job +from django.utils import timezone +from factories import UserFactory + +from validity.models import ComplianceReport +from validity.scripts.data_models import RequestInfo, ScriptParams, Task +from validity.scripts.launch import Launcher + + +class ConcreteScriptParams(ScriptParams): + def with_job_info(self, job: Job): + return FullParams(**asdict(self) | {"job": job}) + + +class FullParams: + def __init__(self, **kwargs): + self.__dict__ = kwargs.copy() + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + +@pytest.fixture +def launcher(db): + report = ComplianceReport.objects.create() + return Launcher(job_name="test_launcher", job_object_factory=lambda: report, rq_queue=Mock(), tasks=[]) + + +@pytest.fixture +def params(db): + user = UserFactory() + return ConcreteScriptParams(request=RequestInfo(id=uuid.uuid4(), user_id=user.pk), workers_num=1) + + +@pytest.mark.parametrize("schedule_at", [None, timezone.now()]) +@pytest.mark.django_db +def test_launcher(launcher, params, schedule_at): + def task_func(): ... + + params.schedule_at = schedule_at + launcher.tasks = [Task(task_func, job_timeout=60)] + job = launcher(params) + assert isinstance(job, Job) and job.object == launcher.job_object_factory() + enqueue_fn = getattr(launcher.rq_queue, "enqueue_at" if schedule_at else "enqueue") + enqueue_fn.assert_called_once() + enqueue_kwargs = enqueue_fn.call_args.kwargs + assert enqueue_kwargs["job_id"] == str(job.job_id) + assert enqueue_kwargs["params"] == params.with_job_info(job) + assert enqueue_kwargs["f"] == task_func + assert enqueue_kwargs["job_timeout"] == 60 + assert enqueue_kwargs["depends_on"] is None + + +@pytest.mark.django_db +def test_multi_tasks(launcher, params): + def task_func_1(): ... + def task_func_2(): ... 
+ + params.workers_num = 3 + launcher.tasks = [Task(task_func_1, job_timeout=10, multi_workers=True), Task(task_func_2, job_timeout=20)] + launcher(params) + assert launcher.rq_queue.enqueue.call_count == 4 + assert launcher.rq_queue.enqueue.call_args.kwargs["depends_on"] == [launcher.rq_queue.enqueue.return_value] * 3 diff --git a/validity/tests/test_scripts/test_logger.py b/validity/tests/test_scripts/test_logger.py new file mode 100644 index 0000000..4c6bb5e --- /dev/null +++ b/validity/tests/test_scripts/test_logger.py @@ -0,0 +1,36 @@ +import pytest + +from validity.scripts.logger import Logger + + +@pytest.fixture +def error_with_traceback(): + # moving this function up/down will cause the test to fail + try: + raise ValueError("error") + except ValueError as e: + return e + + +def test_logger(error_with_traceback): + logger = Logger() + logger.info("info-msg") + logger.failure("failure-msg") + logger.log_exception(error_with_traceback) + + assert len(logger.messages) == 3 + serialized_logs = [m.serialized for m in logger.messages] + for log in serialized_logs: + del log["time"] + assert serialized_logs == [ + {"status": "info", "message": "info-msg"}, + {"status": "failure", "message": "failure-msg"}, + { + "status": "failure", + "message": ( + "Unhandled error occured: `: error`\n```\n " + 'File "/plugin/validity/validity/tests/test_scripts/test_logger.py", ' + """line 10, in error_with_traceback\n raise ValueError("error")\n\n```""" + ), + }, + ] diff --git a/validity/tests/test_scripts/test_parent_jobs.py b/validity/tests/test_scripts/test_parent_jobs.py new file mode 100644 index 0000000..41f2fc4 --- /dev/null +++ b/validity/tests/test_scripts/test_parent_jobs.py @@ -0,0 +1,32 @@ +from unittest.mock import Mock + +import pytest + +from validity.scripts.parent_jobs import JobExtractor + + +class TestJobExtractor: + def test_job(self): + extractor = JobExtractor(_job=None) + with pytest.raises(ValueError): + extractor.job # noqa: B018 + extractor = JobExtractor(_job=Mock()) + assert extractor.job == extractor._job + + def test_parent(self): + extractor = JobExtractor(_job=Mock()) + assert extractor.parent.nesting_level == 1 + assert extractor.parent.parent.nesting_level == 2 + assert extractor.parent.job == extractor._job.dependency + + def test_parents(self): + extractor = JobExtractor(_job=Mock()) + extractor._job.fetch_dependencies.return_value = [10, 20, 30] + assert [p.job for p in extractor.parents] == [10, 20, 30] + assert extractor.parents[0].nesting_level == 1 + + def test_nesting_name(self): + extractor = JobExtractor(_job=Mock()) + assert extractor.nesting_name == "Current" + assert extractor.parent.nesting_name == "Parent" + assert extractor.parent.parent.nesting_name == "x2 Parent" diff --git a/validity/tests/test_scripts/test_run_tests.py b/validity/tests/test_scripts/test_run_tests.py deleted file mode 100644 index b241e27..0000000 --- a/validity/tests/test_scripts/test_run_tests.py +++ /dev/null @@ -1,199 +0,0 @@ -from collections import namedtuple -from unittest.mock import Mock -from uuid import uuid4 - -import pytest -from django.db.models import Q, QuerySet -from extras.scripts import Script -from factories import ( - CompTestDBFactory, - DataSourceFactory, - DeviceFactory, - NameSetDBFactory, - ReportFactory, - SelectorFactory, - TenantFactory, -) -from simpleeval import InvalidExpression - -from validity.compliance.exceptions import EvalError -from validity.models import ComplianceReport, ComplianceTestResult, VDevice -from validity.scripts import run_tests 
-from validity.scripts.run_tests import RunTestsScript as RunTestsMixin -from validity.utils.misc import null_request - - -class RunTestsScript(RunTestsMixin, Script): - pass - - -NS_1 = """ -__all__ = ["func1", "var", "func2"] - -def func1(var): pass - -var = 1234 - -def func2(var): pass -""" - -NS_2 = """ -from collections import Counter -import itertools - -__all__ = ["func3", "non_existing_func", "Counter", "itertools"] - -def func3(): pass -""" - -NS_3 = "some wrong syntax" - - -@pytest.mark.parametrize( - "nameset_texts, extracted_fn_names, warning_calls", - [ - pytest.param(["", ""], set(), 0, id="empty"), - pytest.param([NS_1], {"func1", "func2"}, 0, id="NS_1"), - pytest.param([NS_2], {"func3", "Counter"}, 0, id="NS_2"), - pytest.param([NS_1, NS_2], {"func1", "func2", "func3", "Counter"}, 0, id="NS_1, NS_2"), - pytest.param([NS_3], set(), 1, id="NS_3"), - pytest.param([NS_3, NS_1, NS_3], {"func1", "func2"}, 2, id="NS3, NS_1, NS_3"), - ], -) -@pytest.mark.django_db -def test_nameset_functions(nameset_texts, extracted_fn_names, warning_calls, mock_script_logging): - script = RunTestsScript() - namesets = [NameSetDBFactory(definitions=d) for d in nameset_texts] - functions = script.nameset_functions(namesets) - assert extracted_fn_names == functions.keys() - assert script.log_warning.call_count == warning_calls - for fn_name, fn in functions.items(): - assert fn_name == fn.__name__ - assert callable(fn) - - -FUNC = """ -__all__ = ['func'] -{} -""" - - -@pytest.mark.parametrize( - "definitions", - [ - pytest.param(FUNC.format("def func(): return max(1, 10)"), id="max"), - pytest.param(FUNC.format('def func(): return jq.first(".data", {"data": [1,2,3]})'), id="jq"), - ], -) -@pytest.mark.django_db -def test_builtins_are_available_in_nameset(definitions): - script = RunTestsScript() - namesets = [NameSetDBFactory(definitions=definitions)] - functions = script.nameset_functions(namesets) - functions["func"]() - - -@pytest.mark.parametrize( - "run_test_mock", - [ - Mock(return_value=(True, [("expla", "nation")])), - Mock(return_value=(False, [("1", "2"), ("3", "4")])), - Mock(side_effect=EvalError(orig_error=InvalidExpression())), - ], -) -def test_run_tests_for_device(mock_script_logging, run_test_mock, monkeypatch): - result_cls = namedtuple("MockResult", "passed explanation device test report dynamic_pair") - monkeypatch.setattr(run_tests, "ComplianceTestResult", result_cls) - script = RunTestsScript() - script._sleep_between_tests = 0 - monkeypatch.setattr(script, "run_test", run_test_mock) - tests = ["test1", "test2", "test3"] - device = Mock() - report = Mock() - results = list(script.run_tests_for_device(tests, device, report)) - assert len(results) == len(tests) - is_error = isinstance(run_test_mock.side_effect, Exception) - for test, result in zip(tests, results): - assert result.test == test - if is_error: - assert script.log_failure.call_count == len(tests) - assert result.passed is False - assert len(result.explanation) == 1 and result.explanation[0][1] is None - else: - assert result.passed == run_test_mock.return_value[0] - assert result.explanation == run_test_mock.return_value[1] - assert result.report == report - assert run_test_mock.call_count == len(tests) - - -def test_run_tests_for_selector(mock_script_logging, monkeypatch): - script = RunTestsScript() - devices = [Mock(name="device1"), Mock(name="device2")] - monkeypatch.setattr(script, "run_tests_for_device", Mock(return_value=range(3))) - monkeypatch.setattr(script, "get_device_qs", Mock(return_value=devices)) - 
selector = Mock() - report = Mock() - list(script.run_tests_for_selector(selector, report)) - assert script.run_tests_for_device.call_count == len(devices) - script.run_tests_for_device.assert_any_call(selector.tests.all(), devices[0], report) - script.run_tests_for_device.assert_any_call(selector.tests.all(), devices[1], report) - script.get_device_qs.assert_called_once_with(selector) - - -@pytest.mark.django_db -def test_webhook_without_ctx_is_not_fired(monkeypatch): - enq_obj = Mock() - monkeypatch.setattr(run_tests, "enqueue_object", enq_obj) - with null_request(): - ComplianceReport.objects.create() - enq_obj.assert_not_called() - - -@pytest.mark.django_db -def test_fire_report_webhook(monkeypatch): - enq_obj = Mock() - monkeypatch.setattr(run_tests, "enqueue_object", enq_obj) - script = RunTestsScript() - script.request = Mock(id=uuid4(), user=Mock(username="admin")) - report = ReportFactory() - script.fire_report_webhook(report.pk) - enq_obj.assert_called_once() - - -@pytest.mark.django_db -def test_datasources_to_sync(create_custom_fields): - script = RunTestsScript() - script.script_data = Mock() - assert script.datasources_to_sync() == [script.script_data.override_datasource.obj] - - device = DeviceFactory() - DeviceFactory() - datasource = DataSourceFactory() - DataSourceFactory() - device.tenant = TenantFactory(custom_field_data={"data_source": datasource.pk}) - device.save() - script.script_data = Mock(override_datasource=None, device_filter=Q(pk=device.pk)) - datasources_to_sync = script.datasources_to_sync() - assert isinstance(datasources_to_sync, QuerySet) - assert list(datasources_to_sync) == [datasource] - - -@pytest.mark.django_db -def test_full_run(monkeypatch): - DeviceFactory(name="device1") - DeviceFactory(name="device2") - selector = SelectorFactory(name_filter="device([0-9])", dynamic_pairs="NAME") - monkeypatch.setattr( - VDevice, - "config", - property(lambda self: {"key2": "somevalue"} if self.name == "device2" else {"key1": "somevalue"}), - ) - test = CompTestDBFactory( - expression='jq.first(".key1", device.config) == jq.first(".key2", device.dynamic_pair.config) != None' - ) - test.selectors.set([selector]) - script = RunTestsScript() - script.run(data={"make_report": False}, commit=True) - results = [*ComplianceTestResult.objects.order_by("device__name")] - assert len(results) == 2 - assert results[0].passed and not results[1].passed diff --git a/validity/tests/test_scripts/test_script_data.py b/validity/tests/test_scripts/test_script_data.py deleted file mode 100644 index 04e314d..0000000 --- a/validity/tests/test_scripts/test_script_data.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest -from django.db.models import Q -from factories import SelectorFactory - -from validity.models import ComplianceSelector -from validity.scripts.script_data import RunTestsScriptData - - -@pytest.fixture -def mkpatch_selector_filter(monkeypatch): - def get_filter(self): - nonlocal counter - counter += 1 - return Q(pk=counter) - - counter = 100 - monkeypatch.setattr(ComplianceSelector, "filter", property(get_filter)) - - -@pytest.mark.parametrize( - "selectors, devices, result", - [ - ([], [], Q()), - ([SelectorFactory, SelectorFactory], [], Q(pk=101) | Q(pk=102)), - ([], [10, 11, 12], Q(pk=10) | Q(pk=11) | Q(pk=12)), - ([SelectorFactory, SelectorFactory], [21, 22], (Q(pk=101) | Q(pk=102)) & (Q(pk=21) | Q(pk=22))), - ], -) -@pytest.mark.django_db -def test_get_filter(selectors, devices, result, mkpatch_selector_filter): - selectors = [s().pk for s in selectors] - script_data 
= RunTestsScriptData({"selectors": selectors, "devices": devices}) - assert script_data.device_filter == result diff --git a/validity/tests/test_views.py b/validity/tests/test_views.py index f972414..08406b1 100644 --- a/validity/tests/test_views.py +++ b/validity/tests/test_views.py @@ -1,5 +1,6 @@ import textwrap from http import HTTPStatus +from unittest.mock import Mock import pytest from base import ViewTest @@ -18,6 +19,7 @@ PlatformFactory, PollerFactory, ReportFactory, + RunTestsJobFactory, SelectorFactory, SerializerDBFactory, SerializerDSFactory, @@ -27,8 +29,9 @@ state_item, ) -from validity import models +from validity import dependencies, models from validity.compliance.state import State +from validity.scripts.data_models import RunTestsParams class TestDBNameSet(ViewTest): @@ -204,7 +207,34 @@ def test_datasource_devices(admin_client): assert resp.status_code == HTTPStatus.OK -@pytest.mark.django_db -def test_run_tests(admin_client, setup_runtests_script): - resp = admin_client.get("/plugins/validity/tests/run/", follow=True) +class TestRunTests: + url = "/plugins/validity/tests/run/" + + def test_get(self, admin_client): + resp = admin_client.get(self.url) + assert resp.status_code == HTTPStatus.OK + + @pytest.mark.parametrize( + "form_data, status_code, worker_count", + [ + ({}, HTTPStatus.FOUND, 1), + ({}, HTTPStatus.OK, 0), + ({"devices": [1, 2]}, HTTPStatus.OK, 1), # devices do not exist + ], + ) + def test_post(self, admin_client, di, form_data, status_code, worker_count): + launcher = Mock(**{"rq_queue.name": "queue_1", "return_value.pk": 1}) + with di.override( + {dependencies.runtests_launcher: lambda: launcher, dependencies.runtests_worker_count: lambda: worker_count} + ): + result = admin_client.post(self.url, form_data) + assert result.status_code == status_code + if status_code == HTTPStatus.FOUND: # if form is valid + launcher.assert_called_once() + assert isinstance(launcher.call_args.args[0], RunTestsParams) + + +def test_testresult(admin_client): + job = RunTestsJobFactory() + resp = admin_client.get(f"/plugins/validity/scripts/results/{job.pk}/") assert resp.status_code == HTTPStatus.OK diff --git a/validity/urls.py b/validity/urls.py index 44fcabb..405687b 100644 --- a/validity/urls.py +++ b/validity/urls.py @@ -10,7 +10,7 @@ path("selectors/delete/", views.ComplianceSelectorBulkDeleteView.as_view(), name="complianceselector_bulk_delete"), path("selectors//", include(get_model_urls("validity", "complianceselector"))), path("tests/", views.ComplianceTestListView.as_view(), name="compliancetest_list"), - path("tests/run/", views.run_tests, name="compliancetest_run"), + path("tests/run/", views.RunTestsView.as_view(), name="compliancetest_run"), path("tests/add/", views.ComplianceTestEditView.as_view(), name="compliancetest_add"), path("tests/delete/", views.ComplianceTestBulkDeleteView.as_view(), name="compliancetest_bulk_delete"), path("tests//", include(get_model_urls("validity", "compliancetest"))), @@ -26,6 +26,8 @@ path("namesets//", include(get_model_urls("validity", "nameset"))), path("reports/", views.ComplianceReportListView.as_view(), name="compliancereport_list"), path("reports//", include(get_model_urls("validity", "compliancereport"))), + # hack to display NetBox Job view without an error + path("reports//", views.ComplianceReportView.as_view(), name="compliancereport_jobs"), path("pollers/", views.PollerListView.as_view(), name="poller_list"), path("pollers/add/", views.PollerEditView.as_view(), name="poller_add"), path("pollers/delete/", 
views.PollerBulkDeleteView.as_view(), name="poller_bulk_delete"), @@ -34,4 +36,5 @@ path("commands/add/", views.CommandEditView.as_view(), name="command_add"), path("commands/delete/", views.CommandBulkDeleteView.as_view(), name="command_bulk_delete"), path("commands//", include(get_model_urls("validity", "command"))), + path("scripts/results//", views.ScriptResultView.as_view(), name="script_result"), ] diff --git a/validity/utils/orm.py b/validity/utils/orm.py index 5b20d75..e1e9bad 100644 --- a/validity/utils/orm.py +++ b/validity/utils/orm.py @@ -200,6 +200,10 @@ def model_to_proxy(model: Model, proxy_type: type[M]) -> M: @dataclass class CustomFieldBuilder: + """ + Facilitates CustomField creation procedure + """ + cf_model: type content_type_model: type db_alias: str = "" diff --git a/validity/views/__init__.py b/validity/views/__init__.py index ba69f1f..258582c 100644 --- a/validity/views/__init__.py +++ b/validity/views/__init__.py @@ -4,6 +4,7 @@ from .nameset import NameSetBulkDeleteView, NameSetDeleteView, NameSetEditView, NameSetListView, NameSetView from .poller import PollerBulkDeleteView, PollerDeleteView, PollerEditView, PollerListView, PollerView from .report import ComplianceReportListView, ComplianceReportView +from .script import RunTestsView, ScriptResultView from .selector import ( ComplianceSelectorBulkDeleteView, ComplianceSelectorDeleteView, @@ -24,6 +25,5 @@ ComplianceTestEditView, ComplianceTestListView, ComplianceTestView, - run_tests, ) from .test_result import ComplianceResultListView, ComplianceResultView diff --git a/validity/views/report.py b/validity/views/report.py index 3e75674..d557078 100644 --- a/validity/views/report.py +++ b/validity/views/report.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, Iterable, Iterator +from typing import Any, Iterable, Iterator from django.db.models.query import QuerySet from django.shortcuts import get_object_or_404 @@ -14,8 +14,15 @@ class ComplianceReportListView(generic.ObjectListView): - queryset = models.ComplianceReport.objects.annotate_result_stats().count_devices_and_tests().order_by("-created") + queryset = ( + models.ComplianceReport.objects.prefetch_related("jobs") + .annotate_result_stats() + .count_devices_and_tests() + .order_by("-created") + ) table = tables.ComplianceReportTable + filterset = filtersets.ComplianceReportFilterSet + filterset_form = forms.ComplianceReportFilerForm def get_table(self, data, request, bulk_actions=True): table = super().get_table(data, request, bulk_actions) @@ -30,7 +37,7 @@ class ComplianceReportView(generic.ObjectView): def get_table(self, groupby_qs): table = tables.ComplianceReportTable(data=groupby_qs) - table.exclude += ("id", "created", "test_count") + table.exclude += ("id", "created", "test_count", "job_status") return table def transform_groupby_qs(self, groupby_qs: Iterable[dict], groupby_field: DeviceGroupByChoices) -> Iterator[dict]: @@ -94,7 +101,7 @@ def get_table(self, **kwargs): table.configure(self.request) return table - def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: return super().get_context_data(**kwargs) | { "object": self.object, "tab": self.tab, diff --git a/validity/views/script.py b/validity/views/script.py new file mode 100644 index 0000000..379ccd7 --- /dev/null +++ b/validity/views/script.py @@ -0,0 +1,90 @@ +from typing import Annotated, Any + +from core.models import Job +from django.contrib import messages +from django.forms import Form +from 
django.http import HttpResponse, HttpResponseRedirect +from django.shortcuts import render +from django.urls import reverse +from django.utils.translation import gettext_lazy as _ +from django.views.generic.edit import FormView +from netbox.views.generic import ObjectView +from netbox.views.generic.mixins import TableMixin + +from validity import di +from validity.forms import RunTestsForm +from validity.netbox_changes import htmx_partial +from validity.scripts.data_models import RunTestsParams, ScriptParams +from validity.scripts.launch import Launcher +from validity.tables import ScriptResultTable + + +class RunScriptView(FormView): + template_name = "validity/scripts/run.html" + redirect_viewname = "plugins:validity:script_result" + params_class: type[ScriptParams] + empty_form_values = ("", None) + + # these params must be injected into __init__ + launcher: Launcher + worker_count: int + + def get_params(self, form: Form): + form_data = {field: value for field, value in form.cleaned_data.items() if value not in self.empty_form_values} + return self.params_class(request=self.request, **form_data) + + def get_success_url(self, job_id: int) -> str: + return reverse(self.redirect_viewname, kwargs={"pk": job_id}) + + def form_valid(self, form: Form) -> HttpResponse: + if self.worker_count == 0: + messages.error( + self.request, + _('Unable to run script: no running RQ worker found for the queue "{}"').format( + self.launcher.rq_queue.name + ), + ) + return self.render_to_response(self.get_context_data()) + job = self.launcher(self.get_params(form)) + return HttpResponseRedirect(self.get_success_url(job.pk)) + + +class RunTestsView(RunScriptView): + params_class = RunTestsParams + form_class = RunTestsForm + + @di.inject + def __init__( + self, + launcher: Annotated[Launcher, "runtests_launcher"], + worker_count: Annotated[int, "runtests_worker_count"], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.launcher = launcher + self.worker_count = worker_count + + +class ScriptResultView(TableMixin, ObjectView): + queryset = Job.objects.filter(object_type__model="compliancereport", object_type__app_label="validity") + table_class = ScriptResultTable + template_name = "validity/scripts/result.html" + htmx_template_name = "validity/scripts/result_htmx.html" + + def get_table(self, job, request, bulk_actions=False): + logs = [entry | {"index": i} for i, entry in enumerate(job.data["log"], start=1)] + table = self.table_class(logs, user=request.user) + table.configure(request) + return table + + def get(self, request, **kwargs): + job = self.get_object(**kwargs) + table = self.get_table(job, request) if job.completed else None + context = {"job": job, "table": table} + if htmx_partial(request): + response = render(request, self.htmx_template_name, context) + if job.completed or not job.started: + response.status_code = 286 # cancel HTMX polling + return response + + return render(request, self.template_name, context) diff --git a/validity/views/test.py b/validity/views/test.py index 671e955..61afdd3 100644 --- a/validity/views/test.py +++ b/validity/views/test.py @@ -1,9 +1,8 @@ from django.db.models import Count, Q -from django.shortcuts import get_object_or_404, redirect from netbox.views import generic from utilities.views import register_model_view -from validity import config, filtersets, forms, models, tables +from validity import filtersets, forms, models, tables from .base import TableMixin, TestResultBaseView @@ -49,12 +48,3 @@ class 
ComplianceTestBulkDeleteView(generic.BulkDeleteView): class ComplianceTestEditView(generic.ObjectEditView): queryset = models.ComplianceTest.objects.all() form = forms.ComplianceTestForm - - -def run_tests(request): - if config.netbox_version < "4.0.0": - return redirect("extras:script", module="validity_scripts", name="RunTests") - from extras.models import Script - - script = get_object_or_404(Script, name="RunTests", module__data_file__path="validity_scripts.py") - return redirect("extras:script", pk=script.pk)
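
Note for reviewers of this patch: the sketch below pulls together how the new Launcher/Task pipeline introduced here could be wired end to end. It is a minimal sketch based only on what test_launcher.py, test_views.py and views/script.py in this diff assert: Launcher, Task, RequestInfo, RunTestsParams and SplitWorker appear verbatim in the tests, while the queue name, the timeouts, the "RunTests" label and the ApplyWorker/CombineWorker class names are assumptions. The real view builds the params from the HTTP request via dependency injection; here a RequestInfo is constructed directly, as in test_launcher.py.

    # Illustrative only; anything marked "assumed" is a guess, not part of the patch.
    from uuid import uuid4

    from django_rq import get_queue  # assumed source of the RQ queue; NetBox ships django_rq

    from validity.models import ComplianceReport
    from validity.scripts.data_models import RequestInfo, RunTestsParams, Task
    from validity.scripts.launch import Launcher
    from validity.scripts.runtests import ApplyWorker, CombineWorker, SplitWorker  # class names assumed

    launcher = Launcher(
        job_name="RunTests",                                  # label assumed
        job_object_factory=lambda: ComplianceReport.objects.create(),
        rq_queue=get_queue("default"),                        # queue name assumed
        tasks=[
            Task(SplitWorker(), job_timeout=300),             # constructor args (e.g. datasource_sync_fn) omitted
            Task(ApplyWorker(), job_timeout=3600, multi_workers=True),  # enqueued once per worker
            Task(CombineWorker(), job_timeout=300),           # runs after every apply job
        ],
    )

    params = RunTestsParams(
        request=RequestInfo(id=uuid4(), user_id=1),
        workers_num=2,
    )

    # Returns a core.models.Job bound to the report. test_multi_tasks above shows
    # that a multi_workers task is enqueued workers_num times and the next task is
    # enqueued with depends_on=[<those jobs>], so combine waits for all workers.
    job = launcher(params)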
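
In the same spirit, a minimal sketch of how a task running inside an RQ worker might use JobExtractor from parent_jobs.py to reach the jobs it depends on. Only the attributes asserted in test_parent_jobs.py are used; get_current_job comes from the rq package, and reading a dependency's result via return_value() is an assumption about the rq API version in use.

    # Conceptual; meaningful only inside a running RQ task.
    from rq import get_current_job

    from validity.scripts.parent_jobs import JobExtractor

    extractor = JobExtractor(_job=get_current_job())

    extractor.nesting_name         # "Current"
    extractor.parent.job           # the job this one depends on (== _job.dependency)
    extractor.parent.nesting_name  # "Parent"

    # One JobExtractor per fetched dependency, e.g. every apply job feeding combine.
    parent_results = [p.job.return_value() for p in extractor.parents]  # return_value() assumed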