Skip to content

Commit 4a123d8

Browse files
committed
Record more tsinfer parameters in provenance
These are useful when inspecting how the tree sequence was inferred.
1 parent 0a83c7b commit 4a123d8

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

tests/test_provenance.py

+53-6
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,27 @@ def test_no_provenance_match_samples(self, small_sd_fixture):
103103
assert ts.num_provenances == small_sd_fixture.num_provenances
104104

105105
@pytest.mark.parametrize("mmr", [None, 0.1])
106-
def test_provenance_infer(self, small_sd_fixture, mmr):
106+
@pytest.mark.parametrize("pc", [True, False])
107+
@pytest.mark.parametrize("post", [True, False])
108+
@pytest.mark.parametrize("precision", [4, 5])
109+
def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
107110
ts = tsinfer.infer(
108-
small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8
111+
small_sd_fixture,
112+
path_compression=pc,
113+
post_process=post,
114+
precision=precision,
115+
mismatch_ratio=mmr,
116+
recombination_rate=1e-8,
109117
)
110118
assert ts.num_provenances == small_sd_fixture.num_provenances + 1
111119
record = json.loads(ts.provenance(-1).record)
112120
params = record["parameters"]
113121
assert params["command"] == "infer"
122+
assert params["post_process"] == post
123+
assert params["precision"] == precision
114124
assert params["mismatch_ratio"] == mmr
125+
assert params["path_compression"] == pc
126+
assert "simplify" not in params
115127

116128
def test_provenance_generate_ancestors(self, small_sd_fixture):
117129
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
@@ -122,24 +134,42 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
122134
assert params["command"] == "generate_ancestors"
123135

124136
@pytest.mark.parametrize("mmr", [None, 0.1])
125-
def test_provenance_match_ancestors(self, small_sd_fixture, mmr):
137+
@pytest.mark.parametrize("pc", [True, False])
138+
@pytest.mark.parametrize("precision", [4, 5])
139+
def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
126140
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
127141
anc_ts = tsinfer.match_ancestors(
128-
small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8
142+
small_sd_fixture,
143+
ancestors,
144+
mismatch_ratio=mmr,
145+
recombination_rate=1e-8,
146+
path_compression=pc,
147+
precision=precision,
129148
)
130149
assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2
131150
params = json.loads(anc_ts.provenance(-2).record)["parameters"]
132151
assert params["command"] == "generate_ancestors"
133152
params = json.loads(anc_ts.provenance(-1).record)["parameters"]
134153
assert params["command"] == "match_ancestors"
135154
assert params["mismatch_ratio"] == mmr
155+
assert params["path_compression"] == pc
156+
assert params["precision"] == precision
136157

137158
@pytest.mark.parametrize("mmr", [None, 0.1])
138-
def test_provenance_match_samples(self, small_sd_fixture, mmr):
159+
@pytest.mark.parametrize("pc", [True, False])
160+
@pytest.mark.parametrize("post", [True, False])
161+
@pytest.mark.parametrize("precision", [4, 5])
162+
def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, post):
139163
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
140164
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
141165
ts = tsinfer.match_samples(
142-
small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8
166+
small_sd_fixture,
167+
anc_ts,
168+
mismatch_ratio=mmr,
169+
path_compression=pc,
170+
precision=precision,
171+
post_process=post,
172+
recombination_rate=1e-8,
143173
)
144174
assert ts.num_provenances == small_sd_fixture.num_provenances + 3
145175
params = json.loads(ts.provenance(-3).record)["parameters"]
@@ -149,6 +179,23 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr):
149179
params = json.loads(ts.provenance(-1).record)["parameters"]
150180
assert params["command"] == "match_samples"
151181
assert params["mismatch_ratio"] == mmr
182+
assert params["path_compression"] == pc
183+
assert params["precision"] == precision
184+
assert params["post_process"] == post
185+
assert "simplify" not in params # deprecated
186+
187+
@pytest.mark.parametrize("simp", [True, False])
188+
def test_deprecated_simplify(self, small_sd_fixture, simp):
189+
# Included for completeness, but this is deprecated
190+
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
191+
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
192+
ts1 = tsinfer.match_samples(small_sd_fixture, anc_ts, simplify=simp)
193+
ts2 = tsinfer.infer(small_sd_fixture, simplify=simp)
194+
for ts in [ts1, ts2]:
195+
record = json.loads(ts.provenance(-1).record)
196+
params = record["parameters"]
197+
assert params["simplify"] == simp
198+
assert "post_process" not in params
152199

153200

154201
class TestGetProvenance:

tsinfer/inference.py

+17
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ def infer(
380380
record = provenance.get_provenance_dict(
381381
command="infer",
382382
mismatch_ratio=mismatch_ratio,
383+
path_compression=path_compression,
384+
precision=precision,
385+
simplify=simplify,
386+
post_process=post_process,
387+
# TODO: maybe record recombination rate (which could be a RateMap)
383388
)
384389
tables.provenances.add_row(record=json.dumps(record))
385390
inferred_ts = tables.tree_sequence()
@@ -577,6 +582,9 @@ def match_ancestors(
577582
record = provenance.get_provenance_dict(
578583
command="match_ancestors",
579584
mismatch_ratio=mismatch_ratio,
585+
path_compression=path_compression,
586+
precision=precision,
587+
# TODO: maybe record recombination rate (which could be a RateMap)
580588
)
581589
tables.provenances.add_row(record=json.dumps(record))
582590
ts = tables.tree_sequence()
@@ -810,6 +818,8 @@ def match_ancestors_batch_finalise(work_dir):
810818
record = provenance.get_provenance_dict(
811819
command="match_ancestors",
812820
mismatch_ratio=metadata["mismatch_ratio"],
821+
path_compression=metadata["path_compression"],
822+
precision=metadata["precision"],
813823
)
814824
tables.provenances.add_row(record=json.dumps(record))
815825
ts = tables.tree_sequence()
@@ -901,6 +911,8 @@ def augment_ancestors(
901911
record = provenance.get_provenance_dict(
902912
command="augment_ancestors",
903913
mismatch_ratio=mismatch_ratio,
914+
path_compression=path_compression,
915+
precision=precision,
904916
)
905917
tables.provenances.add_row(record=json.dumps(record))
906918
ts = tables.tree_sequence()
@@ -1257,6 +1269,11 @@ def match_samples(
12571269
record = provenance.get_provenance_dict(
12581270
command="match_samples",
12591271
mismatch_ratio=mismatch_ratio,
1272+
path_compression=path_compression,
1273+
precision=precision,
1274+
simplify=simplify,
1275+
post_process=post_process,
1276+
# TODO: maybe record recombination rate (which could be a RateMap)
12601277
)
12611278
tables.provenances.add_row(record=json.dumps(record))
12621279
ts = tables.tree_sequence()

tsinfer/provenance.py

+5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def get_provenance_dict(command=None, **kwargs):
7575
raise ValueError("Command must be provided")
7676
parameters = dict(kwargs)
7777
parameters["command"] = command
78+
if "simplify" in parameters:
79+
if parameters["simplify"] is None:
80+
del parameters["simplify"] # simplify is deprecated version of post_process
81+
else:
82+
del parameters["post_process"]
7883
document = {
7984
"schema_version": "1.0.0",
8085
"software": {"name": "tsinfer", "version": __version__},

0 commit comments

Comments
 (0)