Skip to content

Commit a1fc9a1

Browse files
committed
Add kwargs support for workflow invocations
1 parent f90a338 commit a1fc9a1

File tree

5 files changed

+95
-19
lines changed

5 files changed

+95
-19
lines changed

src/ert/config/ert_script.py

+50-11
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import warnings
99
from abc import abstractmethod
1010
from collections.abc import Callable
11-
from types import MappingProxyType, ModuleType
11+
from types import ModuleType
1212
from typing import TYPE_CHECKING, Any, TypeAlias
1313

1414
from typing_extensions import deprecated
@@ -115,24 +115,48 @@ def initializeAndRun(
115115
argument_types: list[type[Any]],
116116
argument_values: list[str],
117117
fixtures: dict[str, Any] | None = None,
118+
**kwargs: dict[str, Any],
118119
) -> Any:
119120
fixtures = {} if fixtures is None else fixtures
120-
arguments = []
121+
workflow_args = []
121122
for index, arg_value in enumerate(argument_values):
122123
arg_type = argument_types[index] if index < len(argument_types) else str
123124

124125
if arg_value is not None:
125-
arguments.append(arg_type(arg_value))
126+
workflow_args.append(arg_type(arg_value))
126127
else:
127-
arguments.append(None)
128-
fixtures["workflow_args"] = arguments
128+
workflow_args.append(None)
129+
fixtures["workflow_args"] = workflow_args
130+
131+
fixture_args = []
132+
all_func_args = inspect.signature(self.run).parameters
133+
is_using_wf_args_fixture = "workflow_args" in all_func_args
134+
129135
try:
130-
func_args = inspect.signature(self.run).parameters
136+
if not is_using_wf_args_fixture:
137+
fixture_or_kw_arguments = list(all_func_args)[len(workflow_args) :]
138+
else:
139+
fixture_or_kw_arguments = list(all_func_args)
140+
141+
func_args = {k: all_func_args[k] for k in fixture_or_kw_arguments}
142+
143+
kwargs_defaults = {
144+
k: v.default
145+
for k, v in func_args.items()
146+
if k not in fixtures
147+
and v.kind != v.VAR_POSITIONAL
148+
and not str(v).startswith("*")
149+
and v.default != v.empty
150+
}
151+
use_kwargs = {
152+
k: (kwargs or {}).get(k, default_value)
153+
for k, default_value in ({**kwargs_defaults, **kwargs}).items()
154+
}
131155
# If the user has specified *args, we skip injecting fixtures, and just
132156
# pass the user configured arguments
133157
if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()):
134158
try:
135-
arguments = self.insert_fixtures(func_args, fixtures)
159+
fixture_args = self.insert_fixtures(func_args, fixtures, use_kwargs)
136160
except ValueError as e:
137161
# This is here for backwards compatibility, the user does not have *argv
138162
# but positional arguments. Can not be mixed with using fixtures.
@@ -143,7 +167,20 @@ def initializeAndRun(
143167
self._ert = fixtures.get("ert_config")
144168
self._ensemble = fixtures.get("ensemble")
145169
self._storage = fixtures.get("storage")
146-
return self.run(*arguments)
170+
171+
positional_args = (
172+
fixture_args
173+
if is_using_wf_args_fixture
174+
else [*workflow_args, *fixture_args]
175+
)
176+
if not positional_args and not use_kwargs:
177+
return self.run()
178+
elif positional_args and not use_kwargs:
179+
return self.run(*positional_args)
180+
elif not positional_args and use_kwargs:
181+
return self.run(**use_kwargs)
182+
else:
183+
return self.run(*positional_args, **use_kwargs)
147184
except AttributeError as e:
148185
error_msg = str(e)
149186
if not hasattr(self, "run"):
@@ -169,20 +206,22 @@ def initializeAndRun(
169206

170207
def insert_fixtures(
171208
self,
172-
func_args: MappingProxyType[str, inspect.Parameter],
209+
func_args: dict[str, inspect.Parameter],
173210
fixtures: dict[str, Fixtures],
211+
kwargs: dict[str, Any],
174212
) -> list[Any]:
175213
arguments = []
176214
errors = []
177215
for val in func_args:
178216
if val in fixtures:
179217
arguments.append(fixtures[val])
180-
else:
218+
elif val not in kwargs:
181219
errors.append(val)
182220
if errors:
221+
kwargs_str = ",".join(f"{k}='{v}'" for k, v in kwargs.items())
183222
raise ValueError(
184223
f"Plugin: {self.__class__.__name__} misconfigured, arguments: {errors} "
185-
f"not found in fixtures: {list(fixtures)}"
224+
f"not found in fixtures: {list(fixtures)} or kwargs {kwargs_str}"
186225
)
187226
return arguments
188227

src/ert/gui/tools/plugins/plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def getArguments(self, fixtures: dict[str, Any]) -> list[Any]:
4242
script = self.__loadPlugin()
4343
fixtures["parent"] = self.__parent_window
4444
func_args = inspect.signature(script.getArguments).parameters
45-
arguments = script.insert_fixtures(func_args, fixtures)
45+
arguments = script.insert_fixtures(dict(func_args), fixtures, {})
4646

4747
# Part of deprecation
4848
script._ert = fixtures.get("ert_config")

src/ert/libres_facade.py

+2
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def run_ertscript( # type: ignore
210210
storage: Storage,
211211
ensemble: Ensemble,
212212
*args: Any,
213+
**kwargs: dict[str, Any],
213214
) -> Any:
214215
warnings.warn(
215216
"run_ertscript is deprecated, use the workflow runner",
@@ -225,6 +226,7 @@ def run_ertscript( # type: ignore
225226
"storage": storage,
226227
"config_file": self.config.user_config_file,
227228
},
229+
**kwargs,
228230
)
229231

230232
@classmethod

src/ert/workflow_runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def run(
5757
else:
5858
raise UserWarning("Unknown script type!")
5959
result = self.__script.initializeAndRun( # type: ignore
60-
self.job.argument_types(), arguments, fixtures=fixtures
60+
self.job.argument_types(), arguments, fixtures
6161
)
6262
self.__running = False
6363

tests/ert/unit_tests/config/test_ert_plugin.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,38 @@ def run(self, arg_1, ert_script, fixture_2, arg_2="something"):
8989
fixture2_mock = MagicMock()
9090
with caplog.at_level(logging.WARNING):
9191
plugin.initializeAndRun(
92-
[], [1, 2], {"ert_script": fixture_mock, "fixture_2": fixture2_mock}
92+
[], [], {"ert_script": fixture_mock, "fixture_2": fixture2_mock}
9393
)
9494

9595
assert plugin.hasFailed()
9696
log = "\n".join(caplog.messages)
9797
assert "FixturePlugin misconfigured" in log
98-
assert "['arg_1', 'arg_2'] not found in fixtures" in log
98+
assert ("arguments: ['arg_1'] not found in fixtures") in log
99+
100+
101+
def test_plugin_with_mixed_arguments(caplog):
102+
fixture_mock = MagicMock()
103+
fixture2_mock = MagicMock()
104+
105+
class FixturePlugin(ErtPlugin):
106+
def run(self, arg_0, arg_1, ert_script, fixture_2, arg_2="something"):
107+
nonlocal fixture_mock
108+
nonlocal fixture2_mock
109+
110+
assert arg_0 == "1"
111+
assert arg_1 == "2"
112+
assert ert_script == fixture_mock
113+
assert fixture_2 == fixture2_mock
114+
assert arg_2 == "something else"
115+
116+
plugin = FixturePlugin()
117+
118+
plugin.initializeAndRun(
119+
[],
120+
[1, 2],
121+
{"ert_script": fixture_mock, "fixture_2": fixture2_mock},
122+
arg_2="something_else",
123+
)
99124

100125

101126
def test_plugin_with_fixtures_and_enough_arguments():
@@ -111,17 +136,27 @@ def run(self, workflow_args, ert_script):
111136
)
112137

113138

139+
def test_plugin_with_fixtures_and_enough_arguments_positional():
140+
class FixturePlugin(ErtPlugin):
141+
def run(self, a, b, c, ert_script):
142+
return ([a, b, c], ert_script)
143+
144+
plugin = FixturePlugin()
145+
fixture_mock = MagicMock()
146+
assert plugin.initializeAndRun([], [1, 2, 3], {"ert_script": fixture_mock}) == (
147+
["1", "2", "3"],
148+
fixture_mock,
149+
)
150+
151+
114152
def test_plugin_with_default_arguments(capsys):
115153
class FixturePlugin(ErtPlugin):
116154
def run(self, ert_script=None):
117155
return ert_script
118156

119157
plugin = FixturePlugin()
120158
fixture_mock = MagicMock()
121-
assert (
122-
plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock})
123-
== fixture_mock
124-
)
159+
assert plugin.initializeAndRun([], [], {"ert_script": fixture_mock}) == fixture_mock
125160

126161

127162
def test_plugin_with_args():

0 commit comments

Comments
 (0)