Skip to content

Commit a826069

Browse files
benjefferymergify[bot]
authored andcommitted
Store resources in provenance
1 parent e3b2155 commit a826069

File tree

8 files changed

+468
-217
lines changed

8 files changed

+468
-217
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ dependencies = [
6161
"attrs>=19.2.0",
6262
"dask[array]",
6363
"numba",
64+
"psutil>=5.9.0",
6465
]
6566

6667
[project.urls]
@@ -78,4 +79,4 @@ packages = ["tsinfer"]
7879
include-package-data = true
7980

8081
[tool.pytest.ini_options]
81-
testpaths = ["tests"]
82+
testpaths = ["tests"]

requirements/development.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ sgkit[vcf]
3434
sphinx-book-theme
3535
jupyter-book
3636
sphinx-issues
37-
ipywidgets
37+
ipywidgets

tests/test_inference.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import _tsinfer
4343
import tsinfer
4444
import tsinfer.eval_util as eval_util
45+
import tsinfer.provenance as provenance
4546

4647
IS_WINDOWS = sys.platform == "win32"
4748

@@ -1440,7 +1441,14 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
14401441
ts = tsinfer.match_ancestors_batch_group_finalise(
14411442
tmpdir / "work", group_index
14421443
)
1443-
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
1444+
with provenance.TimingAndMemory() as final_timing:
1445+
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
1446+
1447+
prov = json.loads(ts.provenances()[-1].record)
1448+
assert "resources" in prov
1449+
# Check that the time taken was longer than finalise took
1450+
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time
1451+
14441452
ts2 = tsinfer.match_ancestors(samples, ancestors)
14451453
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
14461454

@@ -1542,7 +1550,13 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15421550
work_dir=tmpdir / "working_mat",
15431551
partition_index=i,
15441552
)
1545-
mat_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mat")
1553+
with provenance.TimingAndMemory() as final_timing:
1554+
mat_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mat")
1555+
1556+
prov = json.loads(mat_ts_batch.provenances()[-1].record)
1557+
assert "resources" in prov
1558+
# Check that the time taken was longer than finalise took
1559+
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time
15461560

15471561
mask_wd = tsinfer.match_samples_batch_init(
15481562
work_dir=tmpdir / "working_mask",
@@ -1564,9 +1578,15 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15641578
mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
15651579
mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)
15661580

1567-
mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
1568-
mask_ts.tables.assert_equals(mask_ts_batch.tables, ignore_timestamps=True)
1569-
mask_ts_batch.tables.assert_equals(mat_ts_batch.tables, ignore_timestamps=True)
1581+
mat_ts.tables.assert_equals(
1582+
mask_ts.tables, ignore_timestamps=True, ignore_provenance=True
1583+
)
1584+
mask_ts.tables.assert_equals(
1585+
mask_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
1586+
)
1587+
mask_ts_batch.tables.assert_equals(
1588+
mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
1589+
)
15701590

15711591

15721592
class TestAncestorGeneratorsEquivalant:

tests/test_provenance.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
Tests for the provenance stored in the output tree sequences.
2121
"""
2222
import json
23+
import math
24+
import time
2325

2426
import pytest
2527
import tskit
@@ -68,6 +70,51 @@ def test_ancestors_file(self, small_sd_fixture):
6870
self.validate_file(ancestor_data)
6971

7072

73+
class TestResourceMetrics:
74+
"""
75+
Tests for the ResourceMetrics dataclass.
76+
"""
77+
78+
def test_create_and_asdict(self):
79+
metrics = provenance.ResourceMetrics(
80+
elapsed_time=1.5, user_time=1.0, sys_time=0.5, max_memory=1000
81+
)
82+
d = metrics.asdict()
83+
assert d == {
84+
"elapsed_time": 1.5,
85+
"user_time": 1.0,
86+
"sys_time": 0.5,
87+
"max_memory": 1000,
88+
}
89+
90+
def test_combine_metrics(self):
91+
m1 = provenance.ResourceMetrics(
92+
elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
93+
)
94+
m2 = provenance.ResourceMetrics(
95+
elapsed_time=2.0, user_time=1.5, sys_time=0.3, max_memory=2000
96+
)
97+
combined = provenance.ResourceMetrics.combine([m1, m2])
98+
assert combined.elapsed_time == 3.0
99+
assert combined.user_time == 2.0
100+
assert combined.sys_time == 0.5
101+
assert combined.max_memory == 2000
102+
103+
def test_combine_empty_list(self):
104+
with pytest.raises(ValueError):
105+
provenance.ResourceMetrics.combine([])
106+
107+
def test_combine_single_metric(self):
108+
m = provenance.ResourceMetrics(
109+
elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
110+
)
111+
combined = provenance.ResourceMetrics.combine([m])
112+
assert combined.elapsed_time == 1.0
113+
assert combined.user_time == 0.5
114+
assert combined.sys_time == 0.2
115+
assert combined.max_memory == 1000
116+
117+
71118
class TestIncludeProvenance:
72119
"""
73120
Test that we can include or exclude provenances
@@ -124,6 +171,7 @@ def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
124171
assert params["mismatch_ratio"] == mmr
125172
assert params["path_compression"] == pc
126173
assert "simplify" not in params
174+
assert "resources" in record
127175

128176
def test_provenance_generate_ancestors(self, small_sd_fixture):
129177
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
@@ -132,6 +180,7 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
132180
timestamp, record = p
133181
params = record["parameters"]
134182
assert params["command"] == "generate_ancestors"
183+
assert "resources" in record
135184

136185
@pytest.mark.parametrize("mmr", [None, 0.1])
137186
@pytest.mark.parametrize("pc", [True, False])
@@ -154,6 +203,9 @@ def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
154203
assert params["mismatch_ratio"] == mmr
155204
assert params["path_compression"] == pc
156205
assert params["precision"] == precision
206+
for provenance_index in [-2, -1]:
207+
record = json.loads(anc_ts.provenance(provenance_index).record)
208+
assert "resources" in record
157209

158210
@pytest.mark.parametrize("mmr", [None, 0.1])
159211
@pytest.mark.parametrize("pc", [True, False])
@@ -183,6 +235,9 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, po
183235
assert params["precision"] == precision
184236
assert params["post_process"] == post
185237
assert "simplify" not in params # deprecated
238+
for provenance_index in [-3, -2, -1]:
239+
record = json.loads(ts.provenance(provenance_index).record)
240+
assert "resources" in record
186241

187242
@pytest.mark.parametrize("simp", [True, False])
188243
def test_deprecated_simplify(self, small_sd_fixture, simp):
@@ -207,15 +262,46 @@ def test_no_command(self):
207262
with pytest.raises(ValueError):
208263
provenance.get_provenance_dict()
209264

210-
def validate_encoding(self, params):
211-
pdict = provenance.get_provenance_dict("test", **params)
265+
def validate_encoding(self, params, resources=None):
266+
pdict = provenance.get_provenance_dict("test", resources=resources, **params)
212267
encoded = pdict["parameters"]
213268
assert encoded["command"] == "test"
214269
del encoded["command"]
215270
assert encoded == params
271+
if resources is not None:
272+
assert "resources" in pdict
273+
assert pdict["resources"] == resources
274+
else:
275+
assert "resources" not in pdict
216276

217277
def test_empty_params(self):
218278
self.validate_encoding({})
219279

220280
def test_non_empty_params(self):
221281
self.validate_encoding({"a": 1, "b": "b", "c": 12345})
282+
283+
def test_with_resources(self):
284+
self.validate_encoding(
285+
{}, resources={"elapsed_time": 1.23, "max_memory": 567.89}
286+
)
287+
288+
289+
def test_timing_and_memory_context_manager():
290+
with provenance.TimingAndMemory() as timing:
291+
# Do some work to ensure measurable changes
292+
time.sleep(0.1)
293+
for i in range(1000000):
294+
math.sqrt(i)
295+
_ = [0] * 1000000
296+
297+
assert timing.metrics is not None
298+
assert timing.metrics.elapsed_time > 0.1
299+
# Check we have highres timing
300+
assert timing.metrics.elapsed_time < 1
301+
assert timing.metrics.user_time > 0
302+
assert timing.metrics.sys_time >= 0
303+
assert timing.metrics.max_memory > 100_000_000
304+
305+
# Test metrics are not available during context
306+
with provenance.TimingAndMemory() as timing2:
307+
assert timing2.metrics is None

tests/test_variantdata.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,15 @@ def test_sgkit_subset_equivalence(self, tmp_path, tmpdir):
587587

588588
mat_anc_ts = tsinfer.match_ancestors(mat_sd, ancestors_subset)
589589
mask_anc_ts = tsinfer.match_ancestors(mask_sd, ancestors)
590-
mat_anc_ts.tables.assert_equals(mask_anc_ts.tables, ignore_timestamps=True)
590+
mat_anc_ts.tables.assert_equals(
591+
mask_anc_ts.tables, ignore_timestamps=True, ignore_provenance=True
592+
)
591593

592594
mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)
593595
mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
594-
mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
596+
mat_ts.tables.assert_equals(
597+
mask_ts.tables, ignore_timestamps=True, ignore_provenance=True
598+
)
595599

596600

597601
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")

tsinfer/formats.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,13 +691,15 @@ def add_provenance(self, timestamp, record):
691691
self.provenances_timestamp[n] = timestamp
692692
self.provenances_record[n] = record
693693

694-
def record_provenance(self, command=None, **kwargs):
694+
def record_provenance(self, command=None, resources=None, **kwargs):
695695
"""
696696
Records the provenance information for this file using the
697697
tskit provenances schema.
698698
"""
699699
timestamp = datetime.datetime.now().isoformat()
700-
record = provenance.get_provenance_dict(command=command, **kwargs)
700+
record = provenance.get_provenance_dict(
701+
command=command, resources=resources, **kwargs
702+
)
701703
self.add_provenance(timestamp, record)
702704

703705
def clear_provenances(self):

0 commit comments

Comments
 (0)