Skip to content

Commit

Permalink
Store resources in provenance
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery authored and mergify[bot] committed Feb 11, 2025
1 parent e3b2155 commit a826069
Show file tree
Hide file tree
Showing 8 changed files with 468 additions and 217 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"attrs>=19.2.0",
"dask[array]",
"numba",
"psutil>=5.9.0",
]

[project.urls]
Expand All @@ -78,4 +79,4 @@ packages = ["tsinfer"]
include-package-data = true

[tool.pytest.ini_options]
testpaths = ["tests"]
testpaths = ["tests"]
2 changes: 1 addition & 1 deletion requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ sgkit[vcf]
sphinx-book-theme
jupyter-book
sphinx-issues
ipywidgets
ipywidgets
30 changes: 25 additions & 5 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import _tsinfer
import tsinfer
import tsinfer.eval_util as eval_util
import tsinfer.provenance as provenance

IS_WINDOWS = sys.platform == "win32"

Expand Down Expand Up @@ -1440,7 +1441,14 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
ts = tsinfer.match_ancestors_batch_group_finalise(
tmpdir / "work", group_index
)
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
with provenance.TimingAndMemory() as final_timing:
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")

prov = json.loads(ts.provenances()[-1].record)
assert "resources" in prov
# Check that the elapsed time recorded in provenance is longer than the finalise step alone took
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time

ts2 = tsinfer.match_ancestors(samples, ancestors)
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)

Expand Down Expand Up @@ -1542,7 +1550,13 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
work_dir=tmpdir / "working_mat",
partition_index=i,
)
mat_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mat")
with provenance.TimingAndMemory() as final_timing:
mat_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mat")

prov = json.loads(mat_ts_batch.provenances()[-1].record)
assert "resources" in prov
# Check that the elapsed time recorded in provenance is longer than the finalise step alone took
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time

mask_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mask",
Expand All @@ -1564,9 +1578,15 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)

mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
mask_ts.tables.assert_equals(mask_ts_batch.tables, ignore_timestamps=True)
mask_ts_batch.tables.assert_equals(mat_ts_batch.tables, ignore_timestamps=True)
mat_ts.tables.assert_equals(
mask_ts.tables, ignore_timestamps=True, ignore_provenance=True
)
mask_ts.tables.assert_equals(
mask_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
)
mask_ts_batch.tables.assert_equals(
mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
)


class TestAncestorGeneratorsEquivalant:
Expand Down
90 changes: 88 additions & 2 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
Tests for the provenance stored in the output tree sequences.
"""
import json
import math
import time

import pytest
import tskit
Expand Down Expand Up @@ -68,6 +70,51 @@ def test_ancestors_file(self, small_sd_fixture):
self.validate_file(ancestor_data)


class TestResourceMetrics:
    """
    Tests for the ResourceMetrics dataclass.

    Covers construction and dict conversion, and ``combine`` applied to
    lists of two, one and zero metrics.
    """

    def test_create_and_asdict(self):
        """asdict() returns each field under its own name with its value."""
        metrics = provenance.ResourceMetrics(
            elapsed_time=1.5, user_time=1.0, sys_time=0.5, max_memory=1000
        )
        d = metrics.asdict()
        # Values are stored, not computed, so exact equality is appropriate
        assert d == {
            "elapsed_time": 1.5,
            "user_time": 1.0,
            "sys_time": 0.5,
            "max_memory": 1000,
        }

    def test_combine_metrics(self):
        """
        Combining two metrics: the expected times are the sums of the
        inputs and the expected max_memory is the larger input value.
        """
        m1 = provenance.ResourceMetrics(
            elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
        )
        m2 = provenance.ResourceMetrics(
            elapsed_time=2.0, user_time=1.5, sys_time=0.3, max_memory=2000
        )
        combined = provenance.ResourceMetrics.combine([m1, m2])
        # The times are floats that result from arithmetic, so compare with
        # a tolerance rather than relying on exact binary representation
        assert combined.elapsed_time == pytest.approx(3.0)
        assert combined.user_time == pytest.approx(2.0)
        assert combined.sys_time == pytest.approx(0.5)
        # max_memory is an integer here, so exact comparison is fine
        assert combined.max_memory == 2000

    def test_combine_empty_list(self):
        """Combining an empty list raises ValueError."""
        with pytest.raises(ValueError):
            provenance.ResourceMetrics.combine([])

    def test_combine_single_metric(self):
        """Combining a single metric gives back the same field values."""
        m = provenance.ResourceMetrics(
            elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
        )
        combined = provenance.ResourceMetrics.combine([m])
        assert combined.elapsed_time == pytest.approx(1.0)
        assert combined.user_time == pytest.approx(0.5)
        assert combined.sys_time == pytest.approx(0.2)
        assert combined.max_memory == 1000


class TestIncludeProvenance:
"""
Test that we can include or exclude provenances
Expand Down Expand Up @@ -124,6 +171,7 @@ def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert "simplify" not in params
assert "resources" in record

def test_provenance_generate_ancestors(self, small_sd_fixture):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
Expand All @@ -132,6 +180,7 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
timestamp, record = p
params = record["parameters"]
assert params["command"] == "generate_ancestors"
assert "resources" in record

@pytest.mark.parametrize("mmr", [None, 0.1])
@pytest.mark.parametrize("pc", [True, False])
Expand All @@ -154,6 +203,9 @@ def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert params["precision"] == precision
for provenance_index in [-2, -1]:
record = json.loads(anc_ts.provenance(provenance_index).record)
assert "resources" in record

@pytest.mark.parametrize("mmr", [None, 0.1])
@pytest.mark.parametrize("pc", [True, False])
Expand Down Expand Up @@ -183,6 +235,9 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, po
assert params["precision"] == precision
assert params["post_process"] == post
assert "simplify" not in params # deprecated
for provenance_index in [-3, -2, -1]:
record = json.loads(ts.provenance(provenance_index).record)
assert "resources" in record

@pytest.mark.parametrize("simp", [True, False])
def test_deprecated_simplify(self, small_sd_fixture, simp):
Expand All @@ -207,15 +262,46 @@ def test_no_command(self):
with pytest.raises(ValueError):
provenance.get_provenance_dict()

def validate_encoding(self, params):
pdict = provenance.get_provenance_dict("test", **params)
def validate_encoding(self, params, resources=None):
    """
    Build a provenance dict for the command "test" from ``params`` (and
    ``resources`` when given) and check what was encoded.

    The "parameters" entry must hold the command plus exactly ``params``.
    The "resources" entry must equal ``resources`` when one was supplied
    and must be absent otherwise.
    """
    pdict = provenance.get_provenance_dict("test", resources=resources, **params)
    encoded = pdict["parameters"]
    assert encoded["command"] == "test"
    # Drop the command so that what remains can be compared with the input
    del encoded["command"]
    assert encoded == params
    if resources is None:
        assert "resources" not in pdict
        return
    assert "resources" in pdict
    assert pdict["resources"] == resources

def test_empty_params(self):
    """With no extra parameters, only the command is encoded."""
    no_params = {}
    self.validate_encoding(no_params)

def test_non_empty_params(self):
    """Integer and string parameters are encoded unchanged."""
    params = {"a": 1, "b": "b", "c": 12345}
    self.validate_encoding(params)

def test_with_resources(self):
    """A supplied resources dict is stored under the "resources" key."""
    resources = {"elapsed_time": 1.23, "max_memory": 567.89}
    self.validate_encoding({}, resources=resources)


def test_timing_and_memory_context_manager():
    """
    Run a sleep, a CPU loop and a list allocation inside
    provenance.TimingAndMemory and check the metrics it exposes after exit.
    """
    with provenance.TimingAndMemory() as tracker:
        # Sleep, burn some CPU and allocate so each metric has work to record
        time.sleep(0.1)
        for value in range(1_000_000):
            math.sqrt(value)
        _ = [0] * 1_000_000

    recorded = tracker.metrics
    assert recorded is not None
    # The block slept for 0.1s, so more than that must have elapsed
    assert recorded.elapsed_time > 0.1
    # Check we have highres timing
    assert recorded.elapsed_time < 1
    assert recorded.user_time > 0
    assert recorded.sys_time >= 0
    assert recorded.max_memory > 100_000_000

    # While still inside the context, metrics must not be set yet
    with provenance.TimingAndMemory() as pending:
        assert pending.metrics is None
8 changes: 6 additions & 2 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,15 @@ def test_sgkit_subset_equivalence(self, tmp_path, tmpdir):

mat_anc_ts = tsinfer.match_ancestors(mat_sd, ancestors_subset)
mask_anc_ts = tsinfer.match_ancestors(mask_sd, ancestors)
mat_anc_ts.tables.assert_equals(mask_anc_ts.tables, ignore_timestamps=True)
mat_anc_ts.tables.assert_equals(
mask_anc_ts.tables, ignore_timestamps=True, ignore_provenance=True
)

mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)
mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
mat_ts.tables.assert_equals(
mask_ts.tables, ignore_timestamps=True, ignore_provenance=True
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
Expand Down
6 changes: 4 additions & 2 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,13 +691,15 @@ def add_provenance(self, timestamp, record):
self.provenances_timestamp[n] = timestamp
self.provenances_record[n] = record

def record_provenance(self, command=None, **kwargs):
def record_provenance(self, command=None, resources=None, **kwargs):
"""
Records the provenance information for this file using the
tskit provenances schema.
"""
timestamp = datetime.datetime.now().isoformat()
record = provenance.get_provenance_dict(command=command, **kwargs)
record = provenance.get_provenance_dict(
command=command, resources=resources, **kwargs
)
self.add_provenance(timestamp, record)

def clear_provenances(self):
Expand Down
Loading

0 comments on commit a826069

Please sign in to comment.