Skip to content

Commit cb16c92

Browse files
authored
Merge pull request #36 from radix-ai/ls-highlevelgraph
Add support for `HighLevelGraph`s
2 parents 882d355 + 8a1b484 commit cb16c92

File tree

8 files changed

+132
-22
lines changed

8 files changed

+132
-22
lines changed

.circleci/config.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ jobs:
2828
name: Run linters
2929
command: |
3030
source activate graphchain-circleci-env
31-
flake8 graphchain --max-complexity=10 --ignore=W504
32-
pydocstyle graphchain --convention=numpy
33-
mypy graphchain --ignore-missing-imports --strict
31+
flake8 graphchain
32+
pydocstyle graphchain
33+
mypy graphchain
3434
- run:
3535
name: Run tests
3636
command: |
3737
source activate graphchain-circleci-env
38-
pytest -vx --cov=graphchain graphchain
38+
pytest
3939
4040
workflows:
4141
version: 2

docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# The short X.Y version
2828
version = ''
2929
# The full version, including alpha/beta/rc tags
30-
release = '1.0.0'
30+
release = '1.1.0'
3131

3232

3333
# -- General configuration ---------------------------------------------------

environment.circleci.yml

+10-10
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@ channels:
33
- defaults
44
- conda-forge
55
dependencies:
6-
- cloudpickle=0.6
7-
- dask=1.0
6+
- cloudpickle=0.8
7+
- dask=1.2
88
- fs-s3fs=0.1
99
- joblib=0.13
10-
- mypy<0.700
10+
- mypy<0.800
1111
- pydocstyle=3.0
12-
- pytest=4.0
12+
- pytest=4.4
1313
- pytest-cov=2.6
14-
- pytest-xdist=1.25
14+
- pytest-xdist=1.28
1515
- pip:
16-
- flake8~=3.6.0
17-
- flake8-comprehensions~=1.4.1
18-
- flake8-bandit~=2.0.0
19-
- flake8-bugbear~=18.8.0
16+
- flake8~=3.7.7
17+
- flake8-comprehensions~=2.1.0
18+
- flake8-bandit~=2.1.0
19+
- flake8-bugbear~=19.3.0
2020
- flake8-mutable~=1.2.0
21-
- flake8-rst-docstrings~=0.0.8
21+
- flake8-rst-docstrings~=0.0.9
2222
- lz4~=2.1.6

graphchain/core.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import datetime as dt
33
import functools
44
import logging
5-
import pickle
65
import time
6+
from copy import deepcopy
7+
from pickle import HIGHEST_PROTOCOL # noqa: S403
78
from typing import (Any, Callable, Container, Dict, Hashable, Iterable,
89
Optional, Union)
910

@@ -12,9 +13,24 @@
1213
import fs
1314
import fs.base
1415
import joblib
16+
from dask.highlevelgraph import HighLevelGraph
1517

1618
from .utils import get_size, str_to_posix_fully_portable_filename
1719

20+
21+
def hlg_setitem(self: HighLevelGraph, key: Hashable, value: Any) -> None:
22+
"""Set a HighLevelGraph computation."""
23+
for d in self.layers.values():
24+
if key in d:
25+
d[key] = value
26+
break
27+
28+
29+
# Monkey patch HighLevelGraph to add a missing `__setitem__` method.
30+
if not hasattr(HighLevelGraph, '__setitem__'):
31+
HighLevelGraph.__setitem__ = hlg_setitem
32+
33+
1834
logger = logging.getLogger(__name__)
1935

2036

@@ -166,7 +182,7 @@ def time_to_result(self, memoize: bool = True) -> float:
166182
load_time = self.read_time('store') / 2
167183
self._time_to_result = load_time
168184
return load_time
169-
except Exception:
185+
except Exception: # noqa: S110
170186
pass
171187
compute_time = self.read_time('compute')
172188
dependency_time = 0
@@ -232,7 +248,7 @@ def store(self, result: Any) -> None:
232248
start_time = time.perf_counter()
233249
with self.cache_fs.open( # type: ignore
234250
self.cache_filename, 'wb') as fid:
235-
joblib.dump(result, fid, protocol=pickle.HIGHEST_PROTOCOL)
251+
joblib.dump(result, fid, protocol=HIGHEST_PROTOCOL)
236252
store_time = time.perf_counter() - start_time
237253
# Write store time and log operation
238254
self.write_time('store', store_time)
@@ -243,7 +259,7 @@ def store(self, result: Any) -> None:
243259
# Try to delete leftovers if they were created by accident.
244260
try:
245261
self.cache_fs.remove(self.cache_filename) # type: ignore
246-
except Exception:
262+
except Exception: # noqa: S110
247263
pass
248264

249265
def patch_computation_in_graph(self) -> None:
@@ -343,7 +359,7 @@ def optimize(
343359
.. [2] http://dask.pydata.org/en/latest/optimize.html
344360
"""
345361
# Verify that the graph is a DAG.
346-
dsk = dsk.copy()
362+
dsk = deepcopy(dsk)
347363
assert dask.core.isdag(dsk, list(dsk.keys()))
348364
# Open or create the cache FS.
349365
# TODO(lsorber): lazily evaluate this for compatibility with `distributed`?

graphchain/tests/test_graphchain.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def test_first_run(
167167
storage.close()
168168

169169

170-
def NO_test_single_run_s3(
170+
@pytest.mark.skip(reason='Need AWS credentials to test') # type: ignore
171+
def test_single_run_s3(
171172
dask_graph: Dict[Hashable, Any],
172173
optimizer_s3: Tuple[
173174
str,
graphchain/tests/test_highlevelgraph.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Test module for the dask HighLevelGraphs."""
2+
import dask
3+
import pandas as pd
4+
import pytest
5+
from dask.highlevelgraph import HighLevelGraph
6+
7+
from ..core import optimize
8+
9+
10+
@pytest.fixture(scope="function") # type: ignore
11+
def dask_highlevelgraph() -> HighLevelGraph:
12+
"""Generate an example dask HighLevelGraph."""
13+
@dask.delayed(pure=True) # type: ignore
14+
def create_dataframe(num_rows: int, num_cols: int) -> pd.DataFrame:
15+
print('Creating DataFrame...')
16+
return pd.DataFrame(data=[range(num_cols)] * num_rows)
17+
18+
@dask.delayed(pure=True) # type: ignore
19+
def create_dataframe2(num_rows: int, num_cols: int) -> pd.DataFrame:
20+
print('Creating DataFrame...')
21+
return pd.DataFrame(data=[range(num_cols)] * num_rows)
22+
23+
@dask.delayed(pure=True) # type: ignore
24+
def complicated_computation(df: pd.DataFrame, num_quantiles: int) \
25+
-> pd.DataFrame:
26+
print('Running complicated computation on DataFrame...')
27+
return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])
28+
29+
@dask.delayed(pure=True) # type: ignore
30+
def summarise_dataframes(*dfs: pd.DataFrame) -> float:
31+
print('Summing DataFrames...')
32+
return sum(df.sum().sum() for df in dfs)
33+
34+
df_a = create_dataframe(1000, 1000)
35+
df_b = create_dataframe2(1000, 1000)
36+
df_c = complicated_computation(df_a, 2048)
37+
df_d = complicated_computation(df_b, 2048)
38+
result = summarise_dataframes(df_c, df_d)
39+
return result
40+
41+
42+
def test_highleveldag(dask_highlevelgraph: HighLevelGraph) -> None:
43+
"""Test that the graph can be traversed and its result is correct."""
44+
with dask.config.set(scheduler='sync'):
45+
result = dask_highlevelgraph.compute()
46+
assert result == 2045952000.0
47+
48+
49+
def test_highlevelgraph(dask_highlevelgraph: HighLevelGraph) -> None:
50+
"""Test that the graph can be traversed and its result is correct."""
51+
with dask.config.set(scheduler='sync', delayed_optimize=optimize):
52+
result = dask_highlevelgraph.compute()
53+
assert result == 2045952000.0

setup.cfg

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; http://flake8.pycqa.org/en/latest/user/configuration.html#project-configuration
2+
[flake8]
3+
max_complexity = 10
4+
doctests = True
5+
ignore =
6+
# S101 Use of assert detected.
7+
# Motivation: asserts are useful to test invariants.
8+
S101,
9+
# Line breaks before (W503) or after (W504) binary operator.
10+
# Motivation: At least one must be ignored. This project enforces W503 [1].
11+
# [1] https://github.com/PyCQA/pycodestyle/issues/498.
12+
W504,
13+
# Failed to parse __all__ entry.
14+
# Motivation: flake8-rst-docstrings cannot parse dynamically generated
15+
# __all__ variables.
16+
RST902
17+
18+
; https://mypy.readthedocs.io/en/latest/config_file.html
19+
[mypy]
20+
ignore_missing_imports = True
21+
warn_unused_configs = True
22+
disallow_subclassing_any = True
23+
disallow_untyped_calls = True
24+
disallow_untyped_defs = True
25+
disallow_incomplete_defs = True
26+
check_untyped_defs = True
27+
disallow_untyped_decorators = True
28+
no_implicit_optional = True
29+
warn_redundant_casts = True
30+
warn_unused_ignores = True
31+
warn_return_any = True
32+
disallow_any_generics = True
33+
34+
; http://www.pydocstyle.org/en/latest/usage.html#configuration-files
35+
[pydocstyle]
36+
convention = numpy
37+
38+
; https://docs.pytest.org/en/latest/customize.html#adding-default-options
39+
[tool:pytest]
40+
addopts = --verbose --exitfirst --doctest-modules --log-level DEBUG --cov=graphchain

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name='graphchain',
13-
version='1.0.0',
13+
version='1.1.0',
1414
description='An efficient cache for the execution of dask graphs',
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)