Skip to content

Commit 413dd2e

Browse files
authored
Add de-/serializer func to PickleNode attributes. (#673)
1 parent 1dd46aa commit 413dd2e

File tree

9 files changed

+85
-15
lines changed

9 files changed

+85
-15
lines changed

docs/source/changes.md

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1313
nodes in v0.6.0.
1414
- {pull}`662` adds the `.pixi` folder to be ignored by default during the collection.
1515
- {pull}`671` enhances the documentation on complex repetitions. Closes {issue}`670`.
16+
- {pull}`673` adds de-/serializer function attributes to the `PickleNode`. Closes
17+
{issue}`669`.
1618

1719
## 0.5.2 - 2024-12-19
1820

docs/source/how_to_guides/writing_custom_nodes.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Here are some explanations.
111111
signature is a hash and a unique identifier for the node. For most nodes it will be a
112112
hash of the path or the name.
113113

114-
- The {func}`classmethod` {meth}`~pytask.PickleNode.from_path` is a convenient method to
114+
- The classmethod {meth}`~pytask.PickleNode.from_path` is a convenient method to
115115
instantiate the class.
116116

117117
- The method {meth}`~pytask.PickleNode.state` yields a value that signals the node's
@@ -129,6 +129,13 @@ Here are some explanations.
129129
- {meth}`~pytask.PickleNode.save` is called when a task function returns and allows you to
130130
save the return values.
131131

132+
## Improvements
133+
134+
Usually, you would like your custom node to work with {class}`pathlib.Path` objects and
135+
{class}`upath.UPath` objects, which allow working with remote filesystems. To simplify
136+
getting the state of the node, you can use the {func}`pytask.get_state_of_path`
137+
function.
138+
132139
## Conclusion
133140

134141
Nodes are an important concept in pytask. They allow pytask to build a DAG and

docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
) -> None:
2929
self.name = name
3030
self.path = path
31-
self.attributes = attributes or {}
31+
self.attributes = attributes if attributes is not None else {}
3232

3333
@property
3434
def signature(self) -> str:

docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
) -> None:
3030
self.name = name
3131
self.path = path
32-
self.attributes = attributes or {}
32+
self.attributes = attributes if attributes is not None else {}
3333

3434
@property
3535
def signature(self) -> str:

pyproject.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ test = [
7171
"syrupy",
7272
"aiohttp", # For HTTPPath tests.
7373
"coiled",
74+
"cloudpickle",
7475
]
7576
typing = ["mypy>=1.9.0,<1.11", "nbqa>=1.8.5"]
7677

@@ -85,9 +86,7 @@ Tracker = "https://github.com/pytask-dev/pytask/issues"
8586
pytask = "pytask:cli"
8687

8788
[tool.uv]
88-
dev-dependencies = [
89-
"tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'",
90-
]
89+
dev-dependencies = ["tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'"]
9190

9291
[build-system]
9392
requires = ["hatchling", "hatch_vcs"]

src/_pytask/nodes.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from attrs import define
1616
from attrs import field
17+
from typing_extensions import deprecated
1718
from upath import UPath
1819
from upath._stat import UPathStatResult
1920

@@ -28,6 +29,9 @@
2829
from _pytask.typing import no_default
2930

3031
if TYPE_CHECKING:
32+
from io import BufferedReader
33+
from io import BufferedWriter
34+
3135
from _pytask.mark import Mark
3236
from _pytask.models import NodeInfo
3337
from _pytask.tree_util import PyTree
@@ -40,6 +44,7 @@
4044
"PythonNode",
4145
"Task",
4246
"TaskWithoutPath",
47+
"get_state_of_path",
4348
]
4449

4550

@@ -145,7 +150,7 @@ def signature(self) -> str:
145150

146151
def state(self) -> str | None:
147152
"""Return the state of the node."""
148-
return _get_state(self.path)
153+
return get_state_of_path(self.path)
149154

150155
def execute(self, **kwargs: Any) -> Any:
151156
"""Execute the task."""
@@ -188,7 +193,7 @@ def state(self) -> str | None:
188193
The state is given by the modification timestamp.
189194
190195
"""
191-
return _get_state(self.path)
196+
return get_state_of_path(self.path)
192197

193198
def load(self, is_product: bool = False) -> Path: # noqa: ARG002
194199
"""Load the value."""
@@ -310,12 +315,18 @@ class PickleNode(PPathNode):
310315
The path to the file.
311316
attributes: dict[Any, Any]
312317
A dictionary to store additional information of the task.
318+
serializer
319+
A function to serialize the object. Defaults to :func:`pickle.dump`.
320+
deserializer
321+
A function to deserialize the object. Defaults to :func:`pickle.load`.
313322
314323
"""
315324

316325
path: Path
317326
name: str = ""
318327
attributes: dict[Any, Any] = field(factory=dict)
328+
serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump)
329+
deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load)
319330

320331
@property
321332
def signature(self) -> str:
@@ -332,17 +343,17 @@ def from_path(cls, path: Path) -> PickleNode:
332343
return cls(name=path.as_posix(), path=path)
333344

334345
def state(self) -> str | None:
335-
return _get_state(self.path)
346+
return get_state_of_path(self.path)
336347

337348
def load(self, is_product: bool = False) -> Any:
338349
if is_product:
339350
return self
340351
with self.path.open("rb") as f:
341-
return pickle.load(f) # noqa: S301
352+
return self.deserializer(f)
342353

343354
def save(self, value: Any) -> None:
344355
with self.path.open("wb") as f:
345-
pickle.dump(value, f)
356+
self.serializer(value, f)
346357

347358

348359
@define(kw_only=True)
@@ -387,7 +398,7 @@ def collect(self) -> list[Path]:
387398
return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr]
388399

389400

390-
def _get_state(path: Path) -> str | None:
401+
def get_state_of_path(path: Path) -> str | None:
391402
"""Get state of a path.
392403
393404
A simple function to handle local and remote files.
@@ -411,3 +422,13 @@ def _get_state(path: Path) -> str | None:
411422
return stat.as_info().get("ETag", "0")
412423
msg = "Unknown stat object."
413424
raise NotImplementedError(msg)
425+
426+
427+
@deprecated("Use 'pytask.get_state_of_path' instead.")
428+
def _get_state(path: Path) -> str | None:
429+
"""Get state of a path.
430+
431+
A simple function to handle local and remote files.
432+
433+
"""
434+
return get_state_of_path(path)

src/pytask/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from _pytask.nodes import PickleNode
5454
from _pytask.nodes import PythonNode
5555
from _pytask.nodes import Task
56-
from _pytask.nodes import TaskWithoutPath
56+
from _pytask.nodes import TaskWithoutPath, get_state_of_path
5757
from _pytask.outcomes import CollectionOutcome
5858
from _pytask.outcomes import Exit
5959
from _pytask.outcomes import ExitCode
@@ -146,6 +146,7 @@
146146
"get_all_marks",
147147
"get_marks",
148148
"get_plugin_manager",
149+
"get_state_of_path",
149150
"has_mark",
150151
"hash_value",
151152
"hookimpl",

tests/test_execute.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -985,8 +985,7 @@ def test_download_file(runner, tmp_path):
985985
from upath import UPath
986986
987987
url = UPath(
988-
"https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf"
989-
"7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt"
988+
"https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt"
990989
)
991990
992991
def task_download_file(path: UPath = url) -> Annotated[str, Path("data.csv")]:

tests/test_nodes.py

+41
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pickle
44
from pathlib import Path
55

6+
import cloudpickle
67
import pytest
78

89
from pytask import NodeInfo
@@ -126,3 +127,43 @@ def test_hash_of_pickle_node(tmp_path, value, exists, expected):
126127
)
127128
def test_comply_with_protocol(node, protocol, expected):
128129
assert isinstance(node, protocol) is expected
130+
131+
132+
@pytest.mark.unit
133+
def test_custom_serializer_deserializer_pickle_node(tmp_path):
134+
"""Test that PickleNode correctly uses cloudpickle for de-/serialization."""
135+
136+
# Define custom serializer and deserializer using cloudpickle
137+
def custom_serializer(obj, file):
138+
# Custom serialization logic that adds a wrapper around the data
139+
cloudpickle.dump({"custom_prefix": obj}, file)
140+
141+
def custom_deserializer(file):
142+
# Custom deserialization logic that unwraps the data
143+
data = cloudpickle.load(file)
144+
return data["custom_prefix"]
145+
146+
# Create test data and path
147+
test_data = {"key": "value"}
148+
path = tmp_path.joinpath("custom.pkl")
149+
150+
# Create PickleNode with custom serializer and deserializer
151+
node = PickleNode(
152+
name="test",
153+
path=path,
154+
serializer=custom_serializer,
155+
deserializer=custom_deserializer,
156+
)
157+
158+
# Test saving with custom serializer
159+
node.save(test_data)
160+
161+
# Verify custom serialization was used by directly reading the file
162+
with path.open("rb") as f:
163+
raw_data = cloudpickle.load(f)
164+
assert "custom_prefix" in raw_data
165+
assert raw_data["custom_prefix"] == test_data
166+
167+
# Test loading with custom deserializer
168+
loaded_data = node.load(is_product=False)
169+
assert loaded_data == test_data

0 commit comments

Comments
 (0)