Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix partial class #116

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions validity/models/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ast
from functools import partial
from typing import Any, Callable

from django.core.exceptions import ValidationError
Expand All @@ -9,6 +8,7 @@
from validity.choices import SeverityChoices
from validity.compliance.eval import ExplanationalEval
from validity.managers import ComplianceTestQS
from validity.utils.misc import partialcls
from .base import BaseModel, DataSourceMixin


Expand All @@ -23,7 +23,7 @@ class ComplianceTest(DataSourceMixin, BaseModel):

clone_fields = ("expression", "selectors", "severity", "data_source", "data_file")
text_db_field_name = "expression"
evaluator_cls = partial(ExplanationalEval, load_defaults=True)
evaluator_cls = partialcls(ExplanationalEval, load_defaults=True)

objects = ComplianceTestQS.as_manager()

Expand Down
4 changes: 2 additions & 2 deletions validity/pollers/factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import partial
from typing import TYPE_CHECKING, Annotated, Sequence

from dimi import Singleton

from validity import di
from validity.utils.misc import partialcls
from .base import DevicePoller, ThreadPoller


Expand All @@ -24,6 +24,6 @@ def __init__(
def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> DevicePoller:
if poller_cls := self.poller_map.get(connection_type):
if issubclass(poller_cls, ThreadPoller):
poller_cls = partial(poller_cls, thread_workers=self.max_threads)
poller_cls = partialcls(poller_cls, thread_workers=self.max_threads)
return poller_cls(credentials=credentials, commands=commands)
raise KeyError("No poller exists for this connection type", connection_type)
19 changes: 12 additions & 7 deletions validity/tables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import itertools
from functools import partial
from functools import partialmethod

from dcim.models import Device
from dcim.tables import DeviceTable
Expand Down Expand Up @@ -276,16 +276,21 @@ def get_paginate_by(self, request, max_paginate_by) -> int:
except (KeyError, ValueError):
return max_paginate_by // 2

def configure(self, request, max_paginate_by=None, orphans=None):
def get_page_lengths(self):
return (max_paginate_by // 2, max_paginate_by)
def get_paginator_class(self, max_paginate_by, orphans):
return type(
"CustomPaginator",
(EnhancedPaginator,),
{
"get_page_lengths": lambda self: (max_paginate_by // 2, max_paginate_by),
"__init__": partialmethod(EnhancedPaginator.__init__, orphans=orphans),
},
)

def configure(self, request, max_paginate_by=None, orphans=None):
super().configure(request)
if max_paginate_by and orphans:
paginator_class = type("CustomPaginator", (EnhancedPaginator,), {"get_page_lengths": get_page_lengths})
paginator_class = partial(paginator_class, orphans=orphans)
paginate_by = self.get_paginate_by(request, max_paginate_by)
paginate = {"paginator_class": paginator_class, "per_page": paginate_by}
paginate = {"paginator_class": self.get_paginator_class(max_paginate_by, orphans), "per_page": paginate_by}
RequestConfig(request, paginate).configure(self)


Expand Down
15 changes: 14 additions & 1 deletion validity/tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import operator
from contextlib import nullcontext
from dataclasses import dataclass

import pytest

from validity.utils.misc import reraise
from validity.utils.misc import partialcls, reraise
from validity.utils.version import NetboxVersion


Expand Down Expand Up @@ -64,3 +65,15 @@ def test_netbox_version(obj1, obj2, compare_results):
operators = [operator.lt, operator.le, operator.eq, operator.ge, operator.gt]
for op, expected_result in zip(operators, compare_results):
assert op(obj1, obj2) is expected_result


def test_partialcls():
@dataclass
class A:
a: int
b: int

A2 = partialcls(A, b=10)
assert A2(5) == A(5, 10)
assert A2(a=3, b=4) == A(3, 4)
assert type(A2(1)) is A
14 changes: 14 additions & 0 deletions validity/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,17 @@ def batched(iterable: Iterable, n: int, container: type = list):
if not batch:
return
yield batch


def partialcls(cls, *args, **kwargs):
"""
Returns partial class with args and kwargs applied to __init__.
All original class attributes are preserved. When called, returns original class instance
"""

def __new__(_, *new_args, **new_kwargs):
new_args = args + new_args
new_kwargs = kwargs | new_kwargs
return cls(*new_args, **new_kwargs)

return type(cls.__name__, (cls,), {"__new__": __new__})
4 changes: 2 additions & 2 deletions validity/views/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import Any, Dict

from django.db.models import Model
Expand All @@ -11,6 +10,7 @@
from utilities.views import ViewTab

from validity import filtersets, forms, models, tables
from validity.utils.misc import partialcls


class ObjectPermissionRequiredMixin(_ObjectPermissionRequiredMixin):
Expand Down Expand Up @@ -79,7 +79,7 @@ class TestResultBaseView(ObjectPermissionRequiredMixin, SingleTableMixin, Filter
tab = ViewTab("Test Results", badge=lambda obj: obj.results.count())
model = models.ComplianceTestResult
filterset_class = filtersets.ComplianceTestResultFilterSet
filterform_class = partial(forms.TestResultFilterForm, add_m2m_placeholder=True)
filterform_class = partialcls(forms.TestResultFilterForm, add_m2m_placeholder=True)
table_class = tables.ComplianceResultTable
permission_required = "validity.view_compliancetestresult"
queryset = models.ComplianceTestResult.objects.select_related("test", "device")
Expand Down
Loading