Skip to content

Commit 28d389e

Browse files
committed
update hooks logic
1 parent fc253dc commit 28d389e

File tree

1 file changed

+76
-9
lines changed

1 file changed

+76
-9
lines changed

src/warnet/hooks.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Callable, Optional
88

99
import click
10+
import yaml
1011

1112
from warnet.constants import (
1213
HOOK_NAME_KEY,
@@ -16,12 +17,18 @@
1617
WARNET_USER_DIR_ENV_VAR,
1718
)
1819

20+
21+
class PluginError(Exception):
22+
pass
23+
24+
1925
hook_registry: set[Callable[..., Any]] = set()
2026
imported_modules = {}
2127

2228

2329
@click.group(name="plugin")
2430
def plugin():
31+
"""Control plugins"""
2532
pass
2633

2734

@@ -33,6 +40,29 @@ def ls():
3340
click.secho("Could not determine the plugin directory location.")
3441
click.secho("Consider setting environment variable containing your project directory:")
3542
click.secho(f"export {WARNET_USER_DIR_ENV_VAR}=/home/user/path/to/project/", fg="yellow")
43+
sys.exit(1)
44+
45+
for plugin, status in get_plugins_with_status(plugin_dir):
46+
if status:
47+
click.secho(f"{plugin.stem:<20} enabled", fg="green")
48+
else:
49+
click.secho(f"{plugin.stem:<20} disabled", fg="yellow")
50+
51+
52+
@plugin.command()
53+
@click.argument("plugin", type=str)
54+
@click.argument("function", type=str)
55+
def run(plugin: str, function: str):
56+
module = imported_modules.get(f"plugins.{plugin}")
57+
if hasattr(module, function):
58+
func = getattr(module, function)
59+
if callable(func):
60+
result = func()
61+
print(result)
62+
else:
63+
click.secho(f"{function} in {module} is not callable.")
64+
else:
65+
click.secho(f"Could not find {function} in {module}")
3666

3767

3868
def api(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -134,6 +164,11 @@ def load_user_modules() -> bool:
134164
if not plugin_dir or not plugin_dir.is_dir():
135165
return was_successful_load
136166

167+
enabled_plugins = [plugin for plugin, enabled in get_plugins_with_status(plugin_dir) if enabled]
168+
169+
if not enabled_plugins:
170+
return was_successful_load
171+
137172
# Temporarily add the directory to sys.path for imports
138173
sys.path.insert(0, str(plugin_dir))
139174

@@ -146,15 +181,16 @@ def load_user_modules() -> bool:
146181
sys.modules[HOOKS_API_STEM] = hooks_module
147182
hooks_spec.loader.exec_module(hooks_module)
148183

149-
for file in plugin_dir.glob("*.py"):
150-
if file.stem not in ("__init__", HOOKS_API_STEM):
151-
module_name = f"{PLUGINS_LABEL}.{file.stem}"
152-
spec = importlib.util.spec_from_file_location(module_name, file)
153-
module = importlib.util.module_from_spec(spec)
154-
imported_modules[module_name] = module
155-
sys.modules[module_name] = module
156-
spec.loader.exec_module(module)
157-
was_successful_load = True
184+
for plugin_path in enabled_plugins:
185+
for file in plugin_path.glob("*.py"):
186+
if file.stem not in ("__init__", HOOKS_API_STEM):
187+
module_name = f"{PLUGINS_LABEL}.{file.stem}"
188+
spec = importlib.util.spec_from_file_location(module_name, file)
189+
module = importlib.util.module_from_spec(spec)
190+
imported_modules[module_name] = module
191+
sys.modules[module_name] = module
192+
spec.loader.exec_module(module)
193+
was_successful_load = True
158194

159195
# Remove the added path from sys.path
160196
sys.path.pop(0)
@@ -190,3 +226,34 @@ def get_version(package_name: str) -> str:
190226
except PackageNotFoundError:
191227
print(f"Package not found: {package_name}")
192228
sys.exit(1)
229+
230+
231+
def open_yaml(path: Path) -> dict:
232+
try:
233+
with open(path) as file:
234+
return yaml.safe_load(file)
235+
except FileNotFoundError as e:
236+
raise PluginError(f"YAML file {path} not found.") from e
237+
except yaml.YAMLError as e:
238+
raise PluginError(f"Error parsing yaml: {e}") from e
239+
240+
241+
def check_if_plugin_enabled(path: Path) -> bool:
242+
enabled = None
243+
try:
244+
plugin_dict = open_yaml(path / Path("plugin.yaml"))
245+
enabled = plugin_dict.get("enabled")
246+
except PluginError as e:
247+
click.secho(e)
248+
249+
return bool(enabled)
250+
251+
252+
def get_plugins_with_status(plugin_dir: Path) -> list[tuple[Path, bool]]:
253+
candidates = [
254+
Path(os.path.join(plugin_dir, name))
255+
for name in os.listdir(plugin_dir)
256+
if os.path.isdir(os.path.join(plugin_dir, name))
257+
]
258+
plugins = [plugin_dir for plugin_dir in candidates if any(plugin_dir.glob("plugin.yaml"))]
259+
return [(plugin, check_if_plugin_enabled(plugin)) for plugin in plugins]

0 commit comments

Comments
 (0)