7
7
from typing import Any , Callable , Optional
8
8
9
9
import click
10
+ import yaml
10
11
11
12
from warnet .constants import (
12
13
HOOK_NAME_KEY ,
16
17
WARNET_USER_DIR_ENV_VAR ,
17
18
)
18
19
20
+
21
+ class PluginError (Exception ):
22
+ pass
23
+
24
+
19
25
hook_registry : set [Callable [..., Any ]] = set ()
20
26
imported_modules = {}
21
27
22
28
23
29
@click .group (name = "plugin" )
24
30
def plugin ():
31
+ """Control plugins"""
25
32
pass
26
33
27
34
@@ -33,6 +40,29 @@ def ls():
33
40
click .secho ("Could not determine the plugin directory location." )
34
41
click .secho ("Consider setting environment variable containing your project directory:" )
35
42
click .secho (f"export { WARNET_USER_DIR_ENV_VAR } =/home/user/path/to/project/" , fg = "yellow" )
43
+ sys .exit (1 )
44
+
45
+ for plugin , status in get_plugins_with_status (plugin_dir ):
46
+ if status :
47
+ click .secho (f"{ plugin .stem :<20} enabled" , fg = "green" )
48
+ else :
49
+ click .secho (f"{ plugin .stem :<20} disabled" , fg = "yellow" )
50
+
51
+
52
+ @plugin .command ()
53
+ @click .argument ("plugin" , type = str )
54
+ @click .argument ("function" , type = str )
55
+ def run (plugin : str , function : str ):
56
+ module = imported_modules .get (f"plugins.{ plugin } " )
57
+ if hasattr (module , function ):
58
+ func = getattr (module , function )
59
+ if callable (func ):
60
+ result = func ()
61
+ print (result )
62
+ else :
63
+ click .secho (f"{ function } in { module } is not callable." )
64
+ else :
65
+ click .secho (f"Could not find { function } in { module } " )
36
66
37
67
38
68
def api (func : Callable [..., Any ]) -> Callable [..., Any ]:
@@ -134,6 +164,11 @@ def load_user_modules() -> bool:
134
164
if not plugin_dir or not plugin_dir .is_dir ():
135
165
return was_successful_load
136
166
167
+ enabled_plugins = [plugin for plugin , enabled in get_plugins_with_status (plugin_dir ) if enabled ]
168
+
169
+ if not enabled_plugins :
170
+ return was_successful_load
171
+
137
172
# Temporarily add the directory to sys.path for imports
138
173
sys .path .insert (0 , str (plugin_dir ))
139
174
@@ -146,15 +181,16 @@ def load_user_modules() -> bool:
146
181
sys .modules [HOOKS_API_STEM ] = hooks_module
147
182
hooks_spec .loader .exec_module (hooks_module )
148
183
149
- for file in plugin_dir .glob ("*.py" ):
150
- if file .stem not in ("__init__" , HOOKS_API_STEM ):
151
- module_name = f"{ PLUGINS_LABEL } .{ file .stem } "
152
- spec = importlib .util .spec_from_file_location (module_name , file )
153
- module = importlib .util .module_from_spec (spec )
154
- imported_modules [module_name ] = module
155
- sys .modules [module_name ] = module
156
- spec .loader .exec_module (module )
157
- was_successful_load = True
184
+ for plugin_path in enabled_plugins :
185
+ for file in plugin_path .glob ("*.py" ):
186
+ if file .stem not in ("__init__" , HOOKS_API_STEM ):
187
+ module_name = f"{ PLUGINS_LABEL } .{ file .stem } "
188
+ spec = importlib .util .spec_from_file_location (module_name , file )
189
+ module = importlib .util .module_from_spec (spec )
190
+ imported_modules [module_name ] = module
191
+ sys .modules [module_name ] = module
192
+ spec .loader .exec_module (module )
193
+ was_successful_load = True
158
194
159
195
# Remove the added path from sys.path
160
196
sys .path .pop (0 )
@@ -190,3 +226,34 @@ def get_version(package_name: str) -> str:
190
226
except PackageNotFoundError :
191
227
print (f"Package not found: { package_name } " )
192
228
sys .exit (1 )
229
+
230
+
231
+ def open_yaml (path : Path ) -> dict :
232
+ try :
233
+ with open (path ) as file :
234
+ return yaml .safe_load (file )
235
+ except FileNotFoundError as e :
236
+ raise PluginError (f"YAML file { path } not found." ) from e
237
+ except yaml .YAMLError as e :
238
+ raise PluginError (f"Error parsing yaml: { e } " ) from e
239
+
240
+
241
+ def check_if_plugin_enabled (path : Path ) -> bool :
242
+ enabled = None
243
+ try :
244
+ plugin_dict = open_yaml (path / Path ("plugin.yaml" ))
245
+ enabled = plugin_dict .get ("enabled" )
246
+ except PluginError as e :
247
+ click .secho (e )
248
+
249
+ return bool (enabled )
250
+
251
+
252
+ def get_plugins_with_status (plugin_dir : Path ) -> list [tuple [Path , bool ]]:
253
+ candidates = [
254
+ Path (os .path .join (plugin_dir , name ))
255
+ for name in os .listdir (plugin_dir )
256
+ if os .path .isdir (os .path .join (plugin_dir , name ))
257
+ ]
258
+ plugins = [plugin_dir for plugin_dir in candidates if any (plugin_dir .glob ("plugin.yaml" ))]
259
+ return [(plugin , check_if_plugin_enabled (plugin )) for plugin in plugins ]
0 commit comments