
Commit ae960a0

feat: add (de)serialization customization (#76)
1 parent 225303b commit ae960a0

10 files changed: +559, -225 lines

.flake8 (+1, -1)

@@ -3,7 +3,7 @@
 # https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-length
 # TODO: https://github.com/PyCQA/flake8/issues/234
 doctests = True
-ignore = B019,DAR103,E203,E501,FS003,S101,W503
+ignore = DAR103,E203,E501,FS003,S101,W503
 max_line_length = 100
 max_complexity = 10

README.md (+86, -38)
@@ -4,7 +4,7 @@

 ## What is graphchain?

-Graphchain is like [joblib.Memory](https://joblib.readthedocs.io/en/latest/memory.html#memory) for dask graphs. [Dask graph computations](http://dask.pydata.org/en/latest/spec.html) are cached to a local or remote location of your choice, specified by a [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html).
+Graphchain is like [joblib.Memory](https://joblib.readthedocs.io/en/latest/memory.html) for dask graphs. [Dask graph computations](https://docs.dask.org/en/latest/spec.html) are cached to a local or remote location of your choice, specified by a [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html).

 When you change your dask graph (by changing a computation's implementation or its inputs), graphchain will take care to only recompute the minimum number of computations necessary to fetch the result. This allows you to iterate quickly over your graph without spending time on recomputing previously computed keys.
@@ -23,7 +23,7 @@ Additionally, the result of a computation is only cached if it is estimated that

 Install graphchain with pip to get started:

-```bash
+```sh
 pip install graphchain
 ```

@@ -35,35 +35,35 @@ import graphchain
 import pandas as pd

 def create_dataframe(num_rows, num_cols):
-    print('Creating DataFrame...')
+    print("Creating DataFrame...")
     return pd.DataFrame(data=[range(num_cols)]*num_rows)

-def complicated_computation(df, num_quantiles):
-    print('Running complicated computation on DataFrame...')
+def expensive_computation(df, num_quantiles):
+    print("Running expensive computation on DataFrame...")
     return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])

-def summarise_dataframes(*dfs):
-    print('Summing DataFrames...')
+def summarize_dataframes(*dfs):
+    print("Summing DataFrames...")
     return sum(df.sum().sum() for df in dfs)

 dsk = {
-    'df_a': (create_dataframe, 10_000, 1000),
-    'df_b': (create_dataframe, 10_000, 1000),
-    'df_c': (complicated_computation, 'df_a', 2048),
-    'df_d': (complicated_computation, 'df_b', 2048),
-    'result': (summarise_dataframes, 'df_c', 'df_d')
+    "df_a": (create_dataframe, 10_000, 1000),
+    "df_b": (create_dataframe, 10_000, 1000),
+    "df_c": (expensive_computation, "df_a", 2048),
+    "df_d": (expensive_computation, "df_b", 2048),
+    "result": (summarize_dataframes, "df_c", "df_d")
 }
 ```

-Using `dask.get` to fetch the `'result'` key takes about 6 seconds:
+Using `dask.get` to fetch the `"result"` key takes about 6 seconds:

 ```python
->>> %time dask.get(dsk, 'result')
+>>> %time dask.get(dsk, "result")

 Creating DataFrame...
-Running complicated computation on DataFrame...
+Running expensive computation on DataFrame...
 Creating DataFrame...
-Running complicated computation on DataFrame...
+Running expensive computation on DataFrame...
 Summing DataFrames...

 CPU times: user 7.39 s, sys: 686 ms, total: 8.08 s
@@ -73,10 +73,10 @@ Wall time: 6.19 s
 On the other hand, using `graphchain.get` for the first time to fetch `'result'` takes only 4 seconds:

 ```python
->>> %time graphchain.get(dsk, 'result')
+>>> %time graphchain.get(dsk, "result")

 Creating DataFrame...
-Running complicated computation on DataFrame...
+Running expensive computation on DataFrame...
 Summing DataFrames...

 CPU times: user 4.7 s, sys: 519 ms, total: 5.22 s
@@ -85,10 +85,10 @@ Wall time: 4.04 s

 The reason `graphchain.get` is faster than `dask.get` is because it can load `df_b` and `df_d` from cache after `df_a` and `df_c` have been computed and cached. Note that graphchain will only cache the result of a computation if loading that computation from cache is estimated to be faster than simply running the computation.

-Running `graphchain.get` a second time to fetch `'result'` will be almost instant since this time the result itself is also available from cache:
+Running `graphchain.get` a second time to fetch `"result"` will be almost instant since this time the result itself is also available from cache:

 ```python
->>> %time graphchain.get(dsk, 'result')
+>>> %time graphchain.get(dsk, "result")

 CPU times: user 4.79 ms, sys: 1.79 ms, total: 6.58 ms
 Wall time: 5.34 ms
@@ -97,15 +97,15 @@ Wall time: 5.34 ms
 Now let's say we want to change how the result is summarised from a sum to an average:

 ```python
-def summarise_dataframes(*dfs):
-    print('Averaging DataFrames...')
+def summarize_dataframes(*dfs):
+    print("Averaging DataFrames...")
     return sum(df.mean().mean() for df in dfs) / len(dfs)
 ```

-If we then ask graphchain to fetch `'result'`, it will detect that only `summarise_dataframes` has changed and therefore only recompute this function with inputs loaded from cache:
+If we then ask graphchain to fetch `"result"`, it will detect that only `summarize_dataframes` has changed and therefore only recompute this function with inputs loaded from cache:

 ```python
->>> %time graphchain.get(dsk, 'result')
+>>> %time graphchain.get(dsk, "result")

 Averaging DataFrames...

@@ -118,49 +118,97 @@ Wall time: 86.6 ms
 Graphchain's cache is by default `./__graphchain_cache__`, but you can ask graphchain to use a cache at any [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html) such as `s3://mybucket/__graphchain_cache__`:

 ```python
-graphchain.get(dsk, 'result', location='s3://mybucket/__graphchain_cache__')
+graphchain.get(dsk, "result", location="s3://mybucket/__graphchain_cache__")
 ```

 ### Excluding keys from being cached

 In some cases you may not want a key to be cached. To avoid writing certain keys to the graphchain cache, you can use the `skip_keys` argument:

 ```python
-graphchain.get(dsk, 'result', skip_keys=['result'])
+graphchain.get(dsk, "result", skip_keys=["result"])
 ```

 ### Using graphchain with dask.delayed

 Alternatively, you can use graphchain together with dask.delayed for easier dask graph creation:

 ```python
+import dask
+import pandas as pd
+
 @dask.delayed
 def create_dataframe(num_rows, num_cols):
-    print('Creating DataFrame...')
+    print("Creating DataFrame...")
     return pd.DataFrame(data=[range(num_cols)]*num_rows)

 @dask.delayed
-def complicated_computation(df, num_quantiles):
-    print('Running complicated computation on DataFrame...')
+def expensive_computation(df, num_quantiles):
+    print("Running expensive computation on DataFrame...")
     return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])

 @dask.delayed
-def summarise_dataframes(*dfs):
-    print('Summing DataFrames...')
+def summarize_dataframes(*dfs):
+    print("Summing DataFrames...")
     return sum(df.sum().sum() for df in dfs)

-df_a = create_dataframe(num_rows=50_000, num_cols=500, seed=42)
-df_b = create_dataframe(num_rows=50_000, num_cols=500, seed=42)
-df_c = complicated_computation(df_a, window=3)
-df_d = complicated_computation(df_b, window=3)
-result = summarise_dataframes(df_c, df_d)
+df_a = create_dataframe(num_rows=10_000, num_cols=1000)
+df_b = create_dataframe(num_rows=10_000, num_cols=1000)
+df_c = expensive_computation(df_a, num_quantiles=2048)
+df_d = expensive_computation(df_b, num_quantiles=2048)
+result = summarize_dataframes(df_c, df_d)
 ```

 After which you can compute `result` by setting the `delayed_optimize` method to `graphchain.optimize`:

 ```python
-with dask.config.set(scheduler='sync', delayed_optimize=graphchain.optimize):
-    result.compute(location='s3://mybucket/__graphchain_cache__')
+import graphchain
+from functools import partial
+
+optimize_s3 = partial(graphchain.optimize, location="s3://mybucket/__graphchain_cache__/")
+
+with dask.config.set(scheduler="sync", delayed_optimize=optimize_s3):
+    print(result.compute())
+```
+
+### Using a custom serializer/deserializer
+
+By default graphchain will cache dask computations with [joblib.dump](https://joblib.readthedocs.io/en/latest/generated/joblib.dump.html) and LZ4 compression. However, you may also supply a custom `serialize` and `deserialize` function that writes and reads computations to and from a [PyFilesystem filesystem](https://docs.pyfilesystem.org/en/latest/introduction.html), respectively. For example, the following snippet shows how to serialize dask DataFrames with [dask.dataframe.to_parquet](https://docs.dask.org/en/stable/generated/dask.dataframe.to_parquet.html), while other objects are serialized with joblib:
+
+```python
+import dask.dataframe
+import graphchain
+import fs.osfs
+import joblib
+import os
+from functools import partial
+from typing import Any
+
+def custom_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None:
+    """Serialize dask DataFrames with to_parquet, and other objects with joblib.dump."""
+    if isinstance(obj, dask.dataframe.DataFrame):
+        obj.to_parquet(os.path.join(fs.root_path, "parquet", key))
+    else:
+        with fs.open(f"{key}.joblib", "wb") as fid:
+            joblib.dump(obj, fid)
+
+def custom_deserialize(fs: fs.osfs.OSFS, key: str) -> Any:
+    """Deserialize dask DataFrames with read_parquet, and other objects with joblib.load."""
+    if fs.exists(f"{key}.joblib"):
+        with fs.open(f"{key}.joblib", "rb") as fid:
+            return joblib.load(fid)
+    else:
+        return dask.dataframe.read_parquet(os.path.join(fs.root_path, "parquet", key))
+
+optimize_parquet = partial(
+    graphchain.optimize,
+    location="./__graphchain_cache__/custom/",
+    serialize=custom_serialize,
+    deserialize=custom_deserialize
+)
+
+with dask.config.set(scheduler="sync", delayed_optimize=optimize_parquet):
+    print(result.compute())
 ```

 ## Contributing
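
Editor's note, not part of the diff above: the `serialize` and `deserialize` hooks introduced by this commit follow the `(obj, fs, key) -> None` and `(fs, key) -> Any` signatures shown in the README example. As a minimal sketch of how the same hooks might special-case another object type, the snippet below stores NumPy arrays with `numpy.save`/`numpy.load` and falls back to joblib for everything else. The function names, the ndarray special case, and the cache location are illustrative assumptions, not part of this commit.

```python
import numpy as np
import joblib
import fs.osfs
from functools import partial
from typing import Any

import graphchain

def numpy_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None:
    """Store NumPy arrays with np.save, and any other object with joblib.dump (illustrative)."""
    if isinstance(obj, np.ndarray):
        with fs.open(f"{key}.npy", "wb") as fid:
            np.save(fid, obj, allow_pickle=False)
    else:
        with fs.open(f"{key}.joblib", "wb") as fid:
            joblib.dump(obj, fid)

def numpy_deserialize(fs: fs.osfs.OSFS, key: str) -> Any:
    """Load NumPy arrays with np.load, and any other object with joblib.load (illustrative)."""
    if fs.exists(f"{key}.npy"):
        with fs.open(f"{key}.npy", "rb") as fid:
            return np.load(fid)
    else:
        with fs.open(f"{key}.joblib", "rb") as fid:
            return joblib.load(fid)

# Plugged into graphchain.optimize the same way as the optimize_parquet example
# in the README diff above; the location subdirectory is an arbitrary choice.
optimize_numpy = partial(
    graphchain.optimize,
    location="./__graphchain_cache__/numpy/",
    serialize=numpy_serialize,
    deserialize=numpy_deserialize,
)
```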
