Skip to content

Commit 86c2c48

Browse files
authored
Preserve inputs passed to depends_on and produces decorator. (#42)
1 parent 58358ed commit 86c2c48

File tree

4 files changed

+124
-23
lines changed

4 files changed

+124
-23
lines changed

docs/changes.rst

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ all releases are available on `Anaconda.org <https://anaconda.org/pytask/pytask>
2121
- :gh:`39` releases v0.0.9.
2222
- :gh:`40` cleans up the capture manager and other parts of pytask.
2323
- :gh:`41` shortens the task ids in the error reports for better readability.
24+
- :gh:`42` ensures that lists with one element and dictionaries with only a zero key as
25+
input for ``@pytask.mark.depends_on`` and ``@pytask.mark.produces`` are preserved as
26+
dictionaries inside the function.
2427

2528

2629
0.0.8 - 2020-10-04

src/_pytask/nodes.py

+52-15
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from abc import abstractmethod
88
from pathlib import Path
99
from typing import Any
10+
from typing import Dict
1011
from typing import Iterable
1112
from typing import List
13+
from typing import Tuple
1214
from typing import Union
1315

1416
import attr
@@ -82,17 +84,22 @@ class PythonFunctionTask(MetaTask):
8284
"""List[MetaNode]: A list of products of task."""
8385
markers = attr.ib(factory=list)
8486
"""Optional[List[Mark]]: A list of markers attached to the task function."""
87+
keep_dict = attr.ib(factory=dict)
8588
_report_sections = attr.ib(factory=list)
8689

8790
@classmethod
8891
def from_path_name_function_session(cls, path, name, function, session):
8992
"""Create a task from a path, name, function, and session."""
93+
keep_dictionary = {}
94+
9095
objects = _extract_nodes_from_function_markers(function, depends_on)
91-
nodes = _convert_objects_to_node_dictionary(objects, "depends_on")
96+
nodes, keep_dict = _convert_objects_to_node_dictionary(objects, "depends_on")
97+
keep_dictionary["depends_on"] = keep_dict
9298
dependencies = _collect_nodes(session, path, name, nodes)
9399

94100
objects = _extract_nodes_from_function_markers(function, produces)
95-
nodes = _convert_objects_to_node_dictionary(objects, "produces")
101+
nodes, keep_dict = _convert_objects_to_node_dictionary(objects, "produces")
102+
keep_dictionary["produces"] = keep_dict
96103
products = _collect_nodes(session, path, name, nodes)
97104

98105
markers = [
@@ -109,6 +116,7 @@ def from_path_name_function_session(cls, path, name, function, session):
109116
depends_on=dependencies,
110117
produces=products,
111118
markers=markers,
119+
keep_dict=keep_dictionary,
112120
)
113121

114122
def execute(self):
@@ -124,15 +132,15 @@ def _get_kwargs_from_task_for_function(self):
124132
"""Process dependencies and products to pass them as kwargs to the function."""
125133
func_arg_names = set(inspect.signature(self.function).parameters)
126134
kwargs = {}
127-
for name in ["depends_on", "produces"]:
128-
if name in func_arg_names:
129-
attribute = getattr(self, name)
130-
kwargs[name] = (
135+
for arg_name in ["depends_on", "produces"]:
136+
if arg_name in func_arg_names:
137+
attribute = getattr(self, arg_name)
138+
kwargs[arg_name] = (
131139
attribute[0].value
132-
if len(attribute) == 1 and 0 in attribute
133-
else {
134-
node_name: node.value for node_name, node in attribute.items()
135-
}
140+
if len(attribute) == 1
141+
and 0 in attribute
142+
and not self.keep_dict[arg_name]
143+
else {name: node.value for name, node in attribute.items()}
136144
)
137145

138146
return kwargs
@@ -208,32 +216,49 @@ def _extract_nodes_from_function_markers(function, parser):
208216

209217

210218
def _convert_objects_to_node_dictionary(objects, when):
211-
list_of_tuples = _convert_objects_to_list_of_tuples(objects)
219+
"""Convert objects to node dictionary."""
220+
list_of_tuples, keep_dict = _convert_objects_to_list_of_tuples(objects)
212221
_check_that_names_are_not_used_multiple_times(list_of_tuples, when)
213222
nodes = _convert_nodes_to_dictionary(list_of_tuples)
214-
return nodes
223+
return nodes, keep_dict
215224

216225

217226
def _convert_objects_to_list_of_tuples(objects):
227+
"""Convert objects to list of tuples.
228+
229+
Examples
230+
--------
231+
>>> _convert_objects_to_list_of_tuples([{0: 0}, [4, (3, 2)], ((1, 4),)])
232+
([(0, 0), (4,), (3, 2), (1, 4)], False)
233+
234+
"""
235+
keep_dict = False
236+
218237
out = []
219238
for obj in objects:
220239
if isinstance(obj, dict):
221240
obj = obj.items()
222241

223242
if isinstance(obj, Iterable) and not isinstance(obj, str):
243+
keep_dict = True
224244
for x in obj:
225245
if isinstance(x, Iterable) and not isinstance(x, str):
226246
tuple_x = tuple(x)
227247
if len(tuple_x) in [1, 2]:
228248
out.append(tuple_x)
229249
else:
230-
raise ValueError("ERROR")
250+
raise ValueError(
251+
f"Element {x} can only have two elements at most."
252+
)
231253
else:
232254
out.append((x,))
233255
else:
234256
out.append((obj,))
235257

236-
return out
258+
if len(out) > 1:
259+
keep_dict = False
260+
261+
return out, keep_dict
237262

238263

239264
def _check_that_names_are_not_used_multiple_times(list_of_tuples, when):
@@ -263,7 +288,19 @@ def _check_that_names_are_not_used_multiple_times(list_of_tuples, when):
263288
)
264289

265290

266-
def _convert_nodes_to_dictionary(list_of_tuples):
291+
def _convert_nodes_to_dictionary(
292+
list_of_tuples: List[Tuple[str]],
293+
) -> Dict[str, Union[str, Path]]:
294+
"""Convert nodes to dictionaries.
295+
296+
Examples
297+
--------
298+
>>> _convert_nodes_to_dictionary([(0,), (1,)])
299+
{0: 0, 1: 1}
300+
>>> _convert_nodes_to_dictionary([(1, 0), (1,)])
301+
{1: 0, 0: 1}
302+
303+
"""
267304
nodes = {}
268305
counter = itertools.count()
269306
names = [x[0] for x in list_of_tuples if len(x) == 2]

tests/test_execute.py

+29
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,32 @@ def task_dummy(depends_on, produces):
169169
result = runner.invoke(cli, [tmp_path.as_posix()])
170170

171171
assert result.exit_code == 0
172+
173+
174+
@pytest.mark.parametrize("input_type", ["list", "dict"])
175+
def test_preserve_input_for_dependencies_and_products(tmp_path, input_type):
176+
"""Input type for dependencies and products is preserved."""
177+
path = tmp_path.joinpath("in.txt")
178+
input_ = {0: path.as_posix()} if input_type == "dict" else [path.as_posix()]
179+
path.touch()
180+
181+
path = tmp_path.joinpath("out.txt")
182+
output = {0: path.as_posix()} if input_type == "dict" else [path.as_posix()]
183+
184+
source = f"""
185+
import pytask
186+
from pathlib import Path
187+
188+
@pytask.mark.depends_on({input_})
189+
@pytask.mark.produces({output})
190+
def task_dummy(depends_on, produces):
191+
for nodes in [depends_on, produces]:
192+
assert isinstance(nodes, dict)
193+
assert len(nodes) == 1
194+
assert 0 in nodes
195+
produces[0].touch()
196+
"""
197+
tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
198+
199+
session = main({"paths": tmp_path})
200+
assert session.exit_code == 0

tests/test_nodes.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from _pytask.nodes import _check_that_names_are_not_used_multiple_times
88
from _pytask.nodes import _convert_nodes_to_dictionary
99
from _pytask.nodes import _convert_objects_to_list_of_tuples
10+
from _pytask.nodes import _convert_objects_to_node_dictionary
1011
from _pytask.nodes import _create_task_name
1112
from _pytask.nodes import _extract_nodes_from_function_markers
1213
from _pytask.nodes import _find_closest_ancestor
@@ -108,21 +109,25 @@ def state(self):
108109

109110
@pytest.mark.unit
110111
@pytest.mark.parametrize(
111-
("x", "expected"),
112+
("x", "expected_lot", "expected_kd"),
112113
[
113-
(["string"], [("string",)]),
114-
(("string",), [("string",)]),
115-
(range(2), [(0,), (1,)]),
116-
([{"a": 0, "b": 1}], [("a", 0), ("b", 1)]),
114+
(["string"], [("string",)], False),
115+
(("string",), [("string",)], False),
116+
(range(2), [(0,), (1,)], False),
117+
([{"a": 0, "b": 1}], [("a", 0), ("b", 1)], False),
117118
(
118119
["a", ("b", "c"), {"d": 1, "e": 1}],
119120
[("a",), ("b",), ("c",), ("d", 1), ("e", 1)],
121+
False,
120122
),
123+
([["string"]], [("string",)], True),
124+
([{0: "string"}], [(0, "string")], True),
121125
],
122126
)
123-
def test_convert_objects_to_list_of_tuples(x, expected):
124-
result = _convert_objects_to_list_of_tuples(x)
125-
assert result == expected
127+
def test_convert_objects_to_list_of_tuples(x, expected_lot, expected_kd):
128+
list_of_tuples, keep_dict = _convert_objects_to_list_of_tuples(x)
129+
assert list_of_tuples == expected_lot
130+
assert keep_dict is expected_kd
126131

127132

128133
ERROR = "'@pytask.mark.depends_on' has nodes with the same name:"
@@ -253,3 +258,30 @@ def test_shorten_node_name(node, paths, expectation, expected):
253258
with expectation:
254259
result = shorten_node_name(node, paths)
255260
assert result == expected
261+
262+
263+
@pytest.mark.integration
264+
@pytest.mark.parametrize("when", ["depends_on", "produces"])
265+
@pytest.mark.parametrize(
266+
"objects, expectation, expected_dict, expected_kd",
267+
[
268+
([0, 1], does_not_raise, {0: 0, 1: 1}, False),
269+
([{0: 0}, {1: 1}], does_not_raise, {0: 0, 1: 1}, False),
270+
([{0: 0}], does_not_raise, {0: 0}, True),
271+
([[0]], does_not_raise, {0: 0}, True),
272+
([((0, 0),), ((0, 1),)], ValueError, None, None),
273+
([{0: 0}, {0: 1}], ValueError, None, None),
274+
],
275+
)
276+
def test_convert_objects_to_node_dictionary(
277+
objects, when, expectation, expected_dict, expected_kd
278+
):
279+
expectation = (
280+
pytest.raises(expectation, match=f"'@pytask.mark.{when}' has nodes")
281+
if expectation == ValueError
282+
else expectation()
283+
)
284+
with expectation:
285+
node_dict, keep_dict = _convert_objects_to_node_dictionary(objects, when)
286+
assert node_dict == expected_dict
287+
assert keep_dict is expected_kd

0 commit comments

Comments
 (0)