Skip to content

feat: multi-plugins with extra schemas #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 13, 2025
36 changes: 36 additions & 0 deletions docs/dev-guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,42 @@ Also notice plugins are activated in a specific order, using Python's built-in
``sorted`` function.


Providing multiple schemas
--------------------------

A second system is available for providing multiple schemas in a single plugin.
This is useful when a single plugin is responsible for multiple subtables
under the ``tool`` table, or if you need to provide multiple schemas for
a single subtable.

To use this system, the plugin function, which does not take any arguments,
should return a dictionary with two keys: ``tools``, which is a dictionary
mapping tool names to schemas, and optionally ``schemas``, which is a list of
schemas that are not associated with any specific tool, but are loaded via
references (``$ref``) from the tool schemas.

When using a :pep:`621`-compliant backend, the following can be added to your
``pyproject.toml`` file:

.. code-block:: toml

# in pyproject.toml
[project.entry-points."validate_pyproject.multi_schema"]
arbitrary = "your_package.your_module:your_plugin"

An example of the plugin structure needed for this system is shown below:

.. code-block:: python

def your_plugin() -> dict:
return {
"tools": {"my-tool": my_schema},
"schemas": [my_extra_schema],
}

Fragments for schemas are also supported with this system; use ``#`` to split
the tool name and fragment path in the dictionary key.

.. _entry-point: https://setuptools.pypa.io/en/stable/userguide/entry_point.html#entry-points
.. _JSON Schema: https://json-schema.org/
.. _Python package: https://packaging.python.org/
Expand Down
21 changes: 12 additions & 9 deletions src/validate_pyproject/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import _tomllib as tomllib
from .api import Validator
from .errors import ValidationError
from .plugins import PluginWrapper
from .plugins import PluginProtocol, PluginWrapper
from .plugins import list_from_entry_points as list_plugins_from_entry_points
from .remote import RemotePlugin, load_store

Expand Down Expand Up @@ -124,7 +124,7 @@ class CliParams(NamedTuple):
dump_json: bool = False


def __meta__(plugins: Sequence[PluginWrapper]) -> Dict[str, dict]:
def __meta__(plugins: Sequence[PluginProtocol]) -> Dict[str, dict]:
"""'Hyper parameters' to instruct :mod:`argparse` how to create the CLI"""
meta = {k: v.copy() for k, v in META.items()}
meta["enable"]["choices"] = {p.tool for p in plugins}
Expand All @@ -135,9 +135,9 @@ def __meta__(plugins: Sequence[PluginWrapper]) -> Dict[str, dict]:
@critical_logging()
def parse_args(
args: Sequence[str],
plugins: Sequence[PluginWrapper],
plugins: Sequence[PluginProtocol],
description: str = "Validate a given TOML file",
get_parser_spec: Callable[[Sequence[PluginWrapper]], Dict[str, dict]] = __meta__,
get_parser_spec: Callable[[Sequence[PluginProtocol]], Dict[str, dict]] = __meta__,
params_class: Type[T] = CliParams, # type: ignore[assignment]
) -> T:
"""Parse command line parameters
Expand Down Expand Up @@ -167,11 +167,14 @@ def parse_args(
return params_class(**params) # type: ignore[call-overload, no-any-return]


Plugins = TypeVar("Plugins", bound=PluginProtocol)


def select_plugins(
plugins: Sequence[PluginWrapper],
plugins: Sequence[Plugins],
enabled: Sequence[str] = (),
disabled: Sequence[str] = (),
) -> List[PluginWrapper]:
) -> List[Plugins]:
available = list(plugins)
if enabled:
available = [p for p in available if p.tool in enabled]
Expand Down Expand Up @@ -219,7 +222,7 @@ def run(args: Sequence[str] = ()) -> int:
(for example ``["--verbose", "setup.cfg"]``).
"""
args = args or sys.argv[1:]
plugins: List[PluginWrapper] = list_plugins_from_entry_points()
plugins = list_plugins_from_entry_points()
params: CliParams = parse_args(args, plugins)
setup_logging(params.loglevel)
tool_plugins = [RemotePlugin.from_str(t) for t in params.tool]
Expand Down Expand Up @@ -263,7 +266,7 @@ def _split_lines(self, text: str, width: int) -> List[str]:
return list(chain.from_iterable(wrap(x, width) for x in text.splitlines()))


def plugins_help(plugins: Sequence[PluginWrapper]) -> str:
def plugins_help(plugins: Sequence[PluginProtocol]) -> str:
    """Build the help listing for ``plugins``, one formatted entry per line."""
    entries = map(_format_plugin_help, plugins)
    return "\n".join(entries)


Expand All @@ -273,7 +276,7 @@ def _flatten_str(text: str) -> str:
return (text[0].lower() + text[1:]).strip()


def _format_plugin_help(plugin: PluginWrapper) -> str:
def _format_plugin_help(plugin: PluginProtocol) -> str:
    """Render one bullet for ``plugin``: its tool name plus optional help text."""
    description = plugin.help_text
    if description:
        suffix = f": {_flatten_str(description)}"
    else:
        # No help available: show only the tool name.
        suffix = ""
    return f"* {plugin.tool!r}{suffix}"
Expand Down
77 changes: 62 additions & 15 deletions src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from importlib.metadata import EntryPoint, entry_points
from string import Template
from textwrap import dedent
from typing import Any, Callable, Iterable, List, Optional, Protocol
from typing import Any, Callable, Generator, Iterable, List, Optional, Protocol, Union

from .. import __version__
from ..types import Plugin, Schema

ENTRYPOINT_GROUP = "validate_pyproject.tool_schema"


class PluginProtocol(Protocol):
@property
Expand Down Expand Up @@ -66,16 +64,49 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.tool!r}, {self.id})"


class StoredPlugin:
    """Plugin that wraps an already-built schema instead of a loader function.

    Args:
        tool: tool name, optionally followed by ``#`` and a fragment
            (for example ``"my-tool#frag"``). The text before the first ``#``
            becomes :attr:`tool`, the rest becomes :attr:`fragment`. An empty
            string produces a plugin with an empty tool name.
        schema: the schema mapping served by this plugin.
    """

    def __init__(self, tool: str, schema: "Schema"):
        self._tool, _, self._fragment = tool.partition("#")
        self._schema = schema

    @property
    def id(self) -> str:
        """Identifier of the schema.

        JSON Schema names its identifier keyword ``$id``, so that key is
        preferred. The bare ``id`` key is still honoured as a fallback, and
        ``"MISSING ID"`` is returned when neither key is present.
        """
        schema = self._schema
        if "$id" in schema:
            return schema["$id"]
        return schema.get("id", "MISSING ID")

    @property
    def tool(self) -> str:
        """Tool name (the part of the constructor argument before ``#``)."""
        return self._tool

    @property
    def schema(self) -> "Schema":
        """The wrapped schema, exactly as passed to the constructor."""
        return self._schema

    @property
    def fragment(self) -> str:
        """Fragment following ``#`` in the constructor argument ("" if none)."""
        return self._fragment

    @property
    def help_text(self) -> str:
        """The schema's ``description`` entry, or "" when it has none."""
        return self.schema.get("description", "")

    def __repr__(self) -> str:
        args = [repr(self.tool), self.id]
        if self.fragment:
            args.append(f"fragment={self.fragment!r}")
        return f"{self.__class__.__name__}({', '.join(args)}, <schema: {self.id}>)"


if typing.TYPE_CHECKING:
_: PluginProtocol = typing.cast(PluginWrapper, None)


def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
def iterate_entry_points(group: str) -> Iterable[EntryPoint]:
"""Produces a generator yielding an EntryPoint object for each plugin registered
via ``setuptools`` `entry point`_ mechanism.

This method can be used in conjunction with :obj:`load_from_entry_point` to filter
the plugins before actually loading them.
the plugins before actually loading them. The entry points are not
deduplicated, but they are sorted.
"""
entries = entry_points()
if hasattr(entries, "select"): # pragma: no cover
Expand All @@ -90,10 +121,7 @@ def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
# TODO: Once Python 3.10 becomes the oldest version supported, this fallback and
# conditional statement can be removed.
entries_ = (plugin for plugin in entries.get(group, []))
deduplicated = {
e.name: e for e in sorted(entries_, key=lambda e: (e.name, e.value))
}
return list(deduplicated.values())
return sorted(entries_, key=lambda e: e.name)


def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper:
Expand All @@ -105,23 +133,42 @@ def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper:
raise ErrorLoadingPlugin(entry_point=entry_point) from ex


def load_from_multi_entry_point(
entry_point: EntryPoint,
) -> Generator[StoredPlugin, None, None]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it easier to read Iterator[T] instead of Generator[T, None, None] and understand what the function does (we don't have to remember which argument order for the return, yield and send).

Since it does not affect too much the typechecking, I tend to stick with the simpler approach. Is there a benefit in using Generator?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a habit I have due to a problem in typing: things like contextlib.contextmanager takes an Iterator because people usually type these as iterators. This is incorrect, as contextmanager calls .close, which is part of the Generator protocol, not the Iterator protocol. So mypy can't find a mistake like passing an Iterator instead of a Generator to things like contextmanager. The reason they gave for not fixing it was that everyone types these as Iterator. So I never type it as Iterator. :) Python 3.13 also added single argument Generator[T], largely to help with this I believe.

"""Carefully load the plugin, raising a meaningful message in case of errors"""
try:
fn = entry_point.load()
output = fn()
except Exception as ex:
raise ErrorLoadingPlugin(entry_point=entry_point) from ex

for tool, schema in output.get("tools", {}).items():
yield StoredPlugin(tool, schema)
for schema in output.get("schemas", []):
yield StoredPlugin("", schema)


def list_from_entry_points(
group: str = ENTRYPOINT_GROUP,
filtering: Callable[[EntryPoint], bool] = lambda _: True,
) -> List[PluginWrapper]:
) -> List[Union[PluginWrapper, StoredPlugin]]:
"""Produces a list of plugin objects for each plugin registered
via ``setuptools`` `entry point`_ mechanism.

Args:
group: name of the setuptools' entry point group where plugins is being
registered
filtering: function returning a boolean deciding if the entry point should be
loaded and included (or not) in the final list. A ``True`` return means the
plugin should be included.
"""
return [
load_from_entry_point(e) for e in iterate_entry_points(group) if filtering(e)
eps: List[Union[PluginWrapper, StoredPlugin]] = [
load_from_entry_point(e)
for e in iterate_entry_points("validate_pyproject.tool_schema")
if filtering(e)
]
for e in iterate_entry_points("validate_pyproject.multi_schema"):
eps.extend(load_from_multi_entry_point(e))
dedup = {(e.tool if e.tool else e.id): e for e in sorted(eps, key=lambda e: e.tool)}
return list(dedup.values())


class ErrorLoadingPlugin(RuntimeError):
Expand Down
4 changes: 3 additions & 1 deletion src/validate_pyproject/pre_compile/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ class CliParams(NamedTuple):
store: str = ""


def parser_spec(plugins: Sequence[PluginWrapper]) -> Dict[str, dict]:
def parser_spec(
plugins: Sequence[PluginProtocol],
) -> Dict[str, dict]:
common = ("version", "enable", "disable", "verbose", "very_verbose")
cli_spec = cli.__meta__(plugins)
meta = {k: v.copy() for k, v in META.items()}
Expand Down
8 changes: 4 additions & 4 deletions src/validate_pyproject/repo_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def repo_review_checks() -> Dict[str, VPP001]:

def repo_review_families(pyproject: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
has_distutils = "distutils" in pyproject.get("tool", {})
plugin_names = (ep.name for ep in plugins.iterate_entry_points())
plugin_list = (
f"`[tool.{n}]`" for n in plugin_names if n != "distutils" or has_distutils
plugin_list = plugins.list_from_entry_points(
lambda e: e.name != "distutils" or has_distutils
)
descr = f"Checks `[build-system]`, `[project]`, {', '.join(plugin_list)}"
plugin_names = (f"`[tool.{n.tool}]`" for n in plugin_list if n.tool)
descr = f"Checks `[build-system]`, `[project]`, {', '.join(plugin_names)}"
return {"validate-pyproject": {"name": "Validate-PyProject", "description": descr}}
75 changes: 71 additions & 4 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# The code in this module is mostly borrowed/adapted from PyScaffold and was originally
# published under the MIT license
# The original PyScaffold license can be found in 'NOTICE.txt'
from importlib.metadata import EntryPoint # pragma: no cover
import importlib.metadata
import sys
from types import ModuleType
from typing import List

import pytest

from validate_pyproject import plugins
from validate_pyproject.plugins import ENTRYPOINT_GROUP, ErrorLoadingPlugin
from validate_pyproject.plugins import ErrorLoadingPlugin

EXISTING = (
"setuptools",
Expand All @@ -18,7 +21,9 @@ def test_load_from_entry_point__error():
# This module does not exist, so Python will have some trouble loading it
# EntryPoint(name, value, group)
entry = "mypkg.SOOOOO___fake___:activate"
fake = EntryPoint("fake", entry, ENTRYPOINT_GROUP)
fake = importlib.metadata.EntryPoint(
"fake", entry, "validate_pyproject.tool_schema"
)
with pytest.raises(ErrorLoadingPlugin):
plugins.load_from_entry_point(fake)

Expand All @@ -28,7 +33,7 @@ def is_entry_point(ep):


def test_iterate_entry_points():
plugin_iter = plugins.iterate_entry_points()
plugin_iter = plugins.iterate_entry_points("validate_pyproject.tool_schema")
assert hasattr(plugin_iter, "__iter__")
pluging_list = list(plugin_iter)
assert all(is_entry_point(e) for e in pluging_list)
Expand Down Expand Up @@ -68,3 +73,65 @@ def _fn2(_):

pw = plugins.PluginWrapper("name", _fn2)
assert pw.help_text == "Help for `name`"


class TestStoredPlugin:
    def test_empty_help_text(self):
        """``help_text`` mirrors the schema's ``description`` ("" when absent)."""
        # No description in the schema -> empty help text.
        pw = plugins.StoredPlugin("name", {})
        assert pw.help_text == ""

        # A description in the schema is returned verbatim.
        pw = plugins.StoredPlugin("name", {"description": "Help for me"})
        assert pw.help_text == "Help for me"


def fake_multi_iterate_entry_points(name: str) -> List[importlib.metadata.EntryPoint]:
    """Fake entry-point lookup used to stub ``plugins.iterate_entry_points``.

    Yields a single entry point (``test_module:f``) for the multi-schema group
    and nothing for any other group name.
    """
    multi_group = "validate_pyproject.multi_schema"
    if name != multi_group:
        return []
    fake_entry = importlib.metadata.EntryPoint(
        name="_", value="test_module:f", group=multi_group
    )
    return [fake_entry]


def test_multi_plugins(monkeypatch):
    """A multi-schema plugin contributes its tool schema and its extra schemas."""
    s1 = {"id": "example1"}
    s2 = {"id": "example2"}
    s3 = {"id": "example3"}

    fake_module = ModuleType("test_module")
    fake_module.f = lambda: {  # type: ignore[attr-defined]
        "tools": {"example#frag": s1},
        "schemas": [s2, s3],
    }
    # Register through monkeypatch so ``sys.modules`` is restored after the test
    # instead of leaking the fake module into the rest of the session.
    monkeypatch.setitem(sys.modules, "test_module", fake_module)
    monkeypatch.setattr(
        plugins, "iterate_entry_points", fake_multi_iterate_entry_points
    )

    lst = plugins.list_from_entry_points()
    assert len(lst) == 3
    assert all(e.id.startswith("example") for e in lst)

    # Only the entry declared under "tools" carries a tool name.
    (fragmented,) = (e for e in lst if e.tool)
    assert fragmented.tool == "example"
    assert fragmented.fragment == "frag"
    assert fragmented.schema == s1


def test_broken_multi_plugin(monkeypatch):
    """A multi-schema plugin that raises is reported as ``ErrorLoadingPlugin``."""

    def broken_ep():
        raise RuntimeError("Broken")

    fake_module = ModuleType("test_module")
    fake_module.f = broken_ep  # type: ignore[attr-defined]
    # Register through monkeypatch so ``sys.modules`` is restored after the test
    # instead of leaking the fake module into the rest of the session.
    monkeypatch.setitem(sys.modules, "test_module", fake_module)
    monkeypatch.setattr(
        plugins, "iterate_entry_points", fake_multi_iterate_entry_points
    )
    with pytest.raises(ErrorLoadingPlugin):
        plugins.list_from_entry_points()