Skip to content

feat: multi-plugins #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/dev-guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ specify which ``tool`` subtable it would be checking:


available_plugins = [
*plugins.list_from_entry_points(),
*plugins.list_plugins_from_entry_points(),
plugins.PluginWrapper("your-tool", your_plugin),
]
validator = api.Validator(available_plugins)
Expand Down
4 changes: 2 additions & 2 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def __init__(
self._extra_validations = tuple(extra_validations)

if plugins is ALL_PLUGINS:
from .plugins import list_from_entry_points
from .plugins import list_plugins_from_entry_points

plugins = list_from_entry_points()
plugins = list_plugins_from_entry_points()

self._plugins = (*plugins, *extra_plugins)

Expand Down
3 changes: 1 addition & 2 deletions src/validate_pyproject/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from . import _tomllib as tomllib
from .api import Validator
from .errors import ValidationError
from .plugins import PluginWrapper
from .plugins import list_from_entry_points as list_plugins_from_entry_points
from .plugins import PluginWrapper, list_plugins_from_entry_points
from .remote import RemotePlugin, load_store

_logger = logging.getLogger(__package__)
Expand Down
44 changes: 29 additions & 15 deletions src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
else:
Protocol = object

ENTRYPOINT_GROUP = "validate_pyproject.tool_schema"


class PluginProtocol(Protocol):
@property
Expand All @@ -56,7 +54,7 @@ def fragment(self) -> str:

class PluginWrapper:
def __init__(self, tool: str, load_fn: "Plugin"):
self._tool = tool
self._tool, _, self._fragment = tool.partition("#")
self._load_fn = load_fn

@property
Expand All @@ -73,7 +71,7 @@ def schema(self) -> "Schema":

@property
def fragment(self) -> str:
return ""
return self._fragment

@property
def help_text(self) -> str:
Expand All @@ -90,12 +88,13 @@ def __repr__(self) -> str:
_: PluginProtocol = typing.cast(PluginWrapper, None)


def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
def iterate_entry_points(group: str) -> Iterable[EntryPoint]:
"""Produces a generator yielding an EntryPoint object for each plugin registered
via ``setuptools`` `entry point`_ mechanism.

This method can be used in conjunction with :obj:`load_from_entry_point` to filter
the plugins before actually loading them.
This method can be used in conjunction with :obj:`load_from_entry_point` to
filter the plugins before actually loading them. The entry points are not
deduplicated, but they are sorted.
"""
entries = entry_points()
if hasattr(entries, "select"): # pragma: no cover
Expand All @@ -110,8 +109,7 @@ def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
# TODO: Once Python 3.10 becomes the oldest version supported, this fallback and
# conditional statement can be removed.
entries_ = (plugin for plugin in entries.get(group, []))
deduplicated = {e.name: e for e in sorted(entries_, key=lambda e: e.name)}
return list(deduplicated.values())
return sorted(entries_, key=lambda e: e.name)


def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper:
Expand All @@ -123,23 +121,39 @@ def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper:
raise ErrorLoadingPlugin(entry_point=entry_point) from ex


def list_from_entry_points(
group: str = ENTRYPOINT_GROUP,
def load_multi_entry_point(entry_point: EntryPoint) -> List[PluginWrapper]:
"""Carefully load the plugin, raising a meaningful message in case of errors"""
try:
dict_plugins = entry_point.load()
return [PluginWrapper(k, v) for k, v in dict_plugins().items()]
except Exception as ex:
raise ErrorLoadingPlugin(entry_point=entry_point) from ex


def list_plugins_from_entry_points(
filtering: Callable[[EntryPoint], bool] = lambda _: True,
) -> List[PluginWrapper]:
"""Produces a list of plugin objects for each plugin registered
via ``setuptools`` `entry point`_ mechanism.

Args:
group: name of the setuptools' entry point group where plugins is being
registered
filtering: function returning a boolean deciding if the entry point should be
loaded and included (or not) in the final list. A ``True`` return means the
plugin should be included.
"""
return [
load_from_entry_point(e) for e in iterate_entry_points(group) if filtering(e)
eps = [
load_from_entry_point(e)
for e in iterate_entry_points("validate_pyproject.tool_schema")
if filtering(e)
]
eps += [
ep
for e in iterate_entry_points("validate_pyproject.multi_schema")
for ep in load_multi_entry_point(e)
if filtering(e)
]
dedup = {e.tool: e for e in sorted(eps, key=lambda e: e.tool)}
return list(dedup.values())


class ErrorLoadingPlugin(RuntimeError):
Expand Down
3 changes: 1 addition & 2 deletions src/validate_pyproject/pre_compile/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from typing import Any, Dict, List, Mapping, NamedTuple, Sequence

from .. import cli
from ..plugins import PluginWrapper
from ..plugins import list_from_entry_points as list_plugins_from_entry_points
from ..plugins import PluginWrapper, list_plugins_from_entry_points
from ..remote import RemotePlugin, load_store
from . import pre_compile

Expand Down
2 changes: 1 addition & 1 deletion src/validate_pyproject/repo_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def repo_review_checks() -> Dict[str, VPP001]:

def repo_review_families(pyproject: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
has_distutils = "distutils" in pyproject.get("tool", {})
plugin_names = (ep.name for ep in plugins.iterate_entry_points())
plugin_names = (ep.tool for ep in plugins.list_plugins_from_entry_points())
plugin_list = (
f"`[tool.{n}]`" for n in plugin_names if n != "distutils" or has_distutils
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_load_plugin():

class TestRegistry:
def test_with_plugins(self):
plg = plugins.list_from_entry_points()
plg = plugins.list_plugins_from_entry_points()
registry = api.SchemaRegistry(plg)
main_schema = registry[registry.main]
project = main_schema["properties"]["project"]
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_invalid(self):
# ---

def plugin(self, tool):
plg = plugins.list_from_entry_points(filtering=lambda e: e.name == tool)
plg = plugins.list_plugins_from_entry_points(filtering=lambda e: e.name == tool)
return plg[0]

TOOLS = ("distutils", "setuptools")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_custom_plugins(self, capsys):


def parse_args(args):
plg = plugins.list_from_entry_points()
plg = plugins.list_plugins_from_entry_points()
return cli.parse_args(args, plg)


Expand Down
47 changes: 39 additions & 8 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# The original PyScaffold license can be found in 'NOTICE.txt'

import sys
from types import ModuleType
from typing import Any, List

import pytest

from validate_pyproject import plugins
from validate_pyproject.plugins import ENTRYPOINT_GROUP, ErrorLoadingPlugin
from validate_pyproject.plugins import ErrorLoadingPlugin

EXISTING = (
"setuptools",
Expand All @@ -17,16 +19,16 @@

if sys.version_info[:2] >= (3, 8):
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
from importlib.metadata import EntryPoint # pragma: no cover
from importlib import metadata # pragma: no cover
else:
from importlib_metadata import EntryPoint # pragma: no cover
import importlib_metadata as metadata # pragma: no cover


def test_load_from_entry_point__error():
# This module does not exist, so Python will have some trouble loading it
# EntryPoint(name, value, group)
# metadata.EntryPoint(name, value, group)
entry = "mypkg.SOOOOO___fake___:activate"
fake = EntryPoint("fake", entry, ENTRYPOINT_GROUP)
fake = metadata.EntryPoint("fake", entry, "validate_pyproject.tool_schema")
with pytest.raises(ErrorLoadingPlugin):
plugins.load_from_entry_point(fake)

Expand All @@ -36,7 +38,7 @@ def is_entry_point(ep):


def test_iterate_entry_points():
plugin_iter = plugins.iterate_entry_points()
plugin_iter = plugins.iterate_entry_points("validate_pyproject.tool_schema")
assert hasattr(plugin_iter, "__iter__")
pluging_list = list(plugin_iter)
assert all(is_entry_point(e) for e in pluging_list)
Expand All @@ -47,14 +49,14 @@ def test_iterate_entry_points():

def test_list_from_entry_points():
# Should return a list with all the plugins registered in the entrypoints
pluging_list = plugins.list_from_entry_points()
pluging_list = plugins.list_plugins_from_entry_points()
orig_len = len(pluging_list)
plugin_names = " ".join(e.tool for e in pluging_list)
for example in EXISTING:
assert example in plugin_names

# a filtering function can be passed to avoid loading plugins that are not needed
pluging_list = plugins.list_from_entry_points(
pluging_list = plugins.list_plugins_from_entry_points(
filtering=lambda e: e.name != "setuptools"
)
plugin_names = " ".join(e.tool for e in pluging_list)
Expand All @@ -76,3 +78,32 @@ def _fn2(_):

pw = plugins.PluginWrapper("name", _fn2)
assert pw.help_text == "Help for `name`"


def loader(name: str) -> Any:
return {"example": "thing"}


def dynamic_ep():
return {"some#fragment": loader}


class Select(list):
def select(self, group: str) -> List[str]:
return list(self) if group == "validate_pyproject.multi_schema" else []


def test_process_checks(monkeypatch: pytest.MonkeyPatch) -> None:
ep = metadata.EntryPoint(
name="_",
group="validate_pyproject.multi_schema",
value="test_module:dynamic_ep",
)
sys.modules["test_module"] = ModuleType("test_module")
sys.modules["test_module"].dynamic_ep = dynamic_ep # type: ignore[attr-defined]
sys.modules["test_module"].loader = loader # type: ignore[attr-defined]
monkeypatch.setattr(plugins, "entry_points", lambda: Select([ep]))
eps = plugins.list_plugins_from_entry_points()
(ep,) = eps
assert ep.tool == "some"
assert ep.fragment == "fragment"