Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Oct 22, 2024
1 parent dbac934 commit 905bdd3
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 15 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
deepdiff>=6.2.0,<7
dimi >=1.2.0,< 2
dimi >=1.3.0,< 2
django-bootstrap5 >=24.2,<25
dulwich # Core NetBox "optional" requirement
jq>=1.4.0,<2
Expand Down
1 change: 0 additions & 1 deletion validity/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def pollers_info(custom_pollers: Annotated[list[PollerInfo], "validity_settings.
] + custom_pollers


import validity.choices # noqa
import validity.pollers.factory # noqa
from validity.scripts import ApplyWorker, CombineWorker, Launcher, SplitWorker, Task # noqa

Expand Down
2 changes: 1 addition & 1 deletion validity/forms/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from validity import di, models
from validity.choices import ExplanationVerbosityChoices
from validity.netbox_changes import FieldSet
from ..utils.misc import LazyIterator
from validity.utils.misc import LazyIterator
from .fields import DynamicModelChoicePropertyField, DynamicModelMultipleChoicePropertyField
from .mixins import PollerCleanMixin, SubformMixin
from .widgets import PrettyJSONWidget
Expand Down
2 changes: 1 addition & 1 deletion validity/pollers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import CustomPoller, Poller
from .base import BasePoller, CustomPoller
from .cli import NetmikoPoller
from .http import RequestsPoller
from .netconf import ScrapliNetconfPoller
4 changes: 2 additions & 2 deletions validity/pollers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from validity.models import Command, VDevice


class Poller(ABC):
class BasePoller(ABC):
host_param_name: str

def __init__(self, credentials: dict, commands: Collection["Command"]) -> None:
Expand All @@ -28,7 +28,7 @@ def get_credentials(self, device: "VDevice"):
return self.credentials | {self.host_param_name: str(ip.address.ip)}


class ThreadPoller(Poller):
class ThreadPoller(BasePoller):
"""
Polls devices one by one using threads
"""
Expand Down
6 changes: 3 additions & 3 deletions validity/pollers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from validity import di
from validity.settings import PollerInfo
from validity.utils.misc import partialcls
from .base import Poller, ThreadPoller
from .base import BasePoller, ThreadPoller


if TYPE_CHECKING:
Expand All @@ -31,13 +31,13 @@ def __init__(self, pollers_info: Annotated[list[PollerInfo], "pollers_info"]):
class PollerFactory:
def __init__(
self,
poller_map: Annotated[dict[str, type[Poller]], "PollerChoices.classes"],
poller_map: Annotated[dict[str, type[BasePoller]], "PollerChoices.classes"],
max_threads: Annotated[int, "validity_settings.polling_threads"],
) -> None:
self.poller_map = poller_map
self.max_threads = max_threads

def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> Poller:
def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> BasePoller:
if poller_cls := self.poller_map.get(connection_type):
if issubclass(poller_cls, ThreadPoller):
poller_cls = partialcls(poller_cls, thread_workers=self.max_threads)
Expand Down
8 changes: 4 additions & 4 deletions validity/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator

from validity import di
from validity.pollers import Poller
from validity.pollers import BasePoller


class ScriptTimeouts(BaseModel):
Expand All @@ -19,18 +19,18 @@ class ScriptTimeouts(BaseModel):
class PollerInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

klass: type[Poller]
klass: type[BasePoller]
name: str = Field(pattern="[a-z_]+")
verbose_name: str = Field(default="", validate_default=True)
color: str = Field(pattern="[a-z-]+")
command_types: list[Literal["CLI", "netconf", "json_api", "custom"]]

@field_validator("verbose_name")
@classmethod
def validate_verbose_name(cls, value):
def validate_verbose_name(cls, value, info):
if value:
return value
return " ".join(part.title() for part in value.split("_"))
return " ".join(part.title() for part in info.data["name"].split("_"))


class ValiditySettings(BaseModel):
Expand Down
52 changes: 52 additions & 0 deletions validity/tests/test_custom_pollers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from http import HTTPStatus
from typing import Any
from unittest.mock import Mock

import pytest
from factories import CommandFactory, PollerFactory

from validity.dependencies import validity_settings
from validity.forms import PollerForm
from validity.models.polling import Command
from validity.pollers import CustomPoller
from validity.settings import PollerInfo, ValiditySettings


class MyCustomPoller(CustomPoller):
host_param_name = "ip_address"
driver_factory = Mock

def poll_one_command(self, driver: Any, command: Command) -> str:
return "output"


@pytest.fixture
def custom_poller(db, di):
settings = ValiditySettings(
custom_pollers=[PollerInfo(klass=MyCustomPoller, name="cupo", color="red", command_types=["custom"])]
)
override = di.override({validity_settings: lambda: settings})
override.__enter__()
yield PollerFactory(connection_type="cupo")
override.__exit__(None, None, None)


def test_custom_poller_model(custom_poller, di):
poller = PollerFactory(connection_type="cupo")
poller.commands.set([CommandFactory(type="custom")])
backend = poller.get_backend()
assert isinstance(backend, MyCustomPoller)
assert poller.get_connection_type_color() == "red"
poller.validate_commands(poller.commands.all(), di["PollerChoices"].command_types, poller.connection_type)


def test_custom_poller_api(custom_poller, admin_client):
resp = admin_client.get(f"/api/plugins/validity/pollers/{custom_poller.pk}/")
assert resp.status_code == HTTPStatus.OK
assert resp.json()["connection_type"] == "cupo"


def test_custom_poller_form(custom_poller):
form = PollerForm()
form_choices = {choice[0] for choice in form["connection_type"].field.choices}
assert "cupo" in form_choices
15 changes: 15 additions & 0 deletions validity/tests/test_models/test_poller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from factories import PollerFactory

from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller


@pytest.mark.parametrize(
"connection_type, poller_class",
[("netmiko", NetmikoPoller), ("requests", RequestsPoller), ("scrapli_netconf", ScrapliNetconfPoller)],
)
@pytest.mark.django_db
def test_get_backend(connection_type, poller_class):
poller = PollerFactory(connection_type=connection_type)
backend = poller.get_backend()
assert isinstance(backend, poller_class)
19 changes: 18 additions & 1 deletion validity/tests/test_pollers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import pytest

from validity.pollers import NetmikoPoller
from validity.pollers import NetmikoPoller, RequestsPoller
from validity.pollers.factory import PollerChoices
from validity.pollers.http import HttpDriver
from validity.settings import PollerInfo


class TestNetmikoPoller:
Expand Down Expand Up @@ -84,3 +86,18 @@ def test_http_driver():
auth=None,
)
assert result == requests.request.return_value.content.decode.return_value


def test_poller_choices():
poller_choices = PollerChoices(
pollers_info=[
PollerInfo(klass=NetmikoPoller, name="some_poller", color="red", command_types=["CLI"]),
PollerInfo(
klass=RequestsPoller, name="p2", verbose_name="P2", color="green", command_types=["json_api", "custom"]
),
]
)
assert poller_choices.choices == [("some_poller", "Some Poller"), ("p2", "P2")]
assert poller_choices.colors == {"some_poller": "red", "p2": "green"}
assert poller_choices.classes == {"some_poller": NetmikoPoller, "p2": RequestsPoller}
assert poller_choices.command_types == {"some_poller": ["CLI"], "p2": ["json_api", "custom"]}
13 changes: 12 additions & 1 deletion validity/tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from validity.utils.misc import log_exceptions, partialcls, reraise
from validity.utils.misc import LazyIterator, log_exceptions, partialcls, reraise
from validity.utils.version import NetboxVersion


Expand Down Expand Up @@ -89,3 +89,14 @@ def test_log_exceptions():
with log_exceptions(logger, "info", log_traceback=True):
raise ValueError("qwerty")
logger.info.assert_called_once_with(msg="qwerty", exc_info=True)


def test_lazy_iterator():
part1 = [10, 20, 30]
part2 = lambda: [40, 50] # noqa
part3 = Mock(return_value=[60])
part4 = (70,)
iterator = LazyIterator(part1, part2, part3, part4)
part3.assert_not_called()
assert list(iterator) == [10, 20, 30, 40, 50, 60, 70]
assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] # checking iterator is not exhausted

0 comments on commit 905bdd3

Please sign in to comment.