Skip to content

Commit 6b9099a

Browse files
hyanwongmergify[bot]
authored andcommitted
Better formatted error message for ancestral alleles
1 parent 7db1d38 commit 6b9099a

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

tests/test_sgkit.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ def test_ancestral_missingness(tmp_path):
527527
ancestral_allele = ds.variant_ancestral_allele.values
528528
ancestral_allele[0] = "N"
529529
ancestral_allele[11] = "-"
530+
ancestral_allele[12] = "💩"
530531
ancestral_allele[15] = "💩"
531532
ds = ds.drop_vars(["variant_ancestral_allele"])
532533
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
@@ -538,19 +539,16 @@ def test_ancestral_missingness(tmp_path):
538539
)
539540
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
540541
sd = tsinfer.SgkitSampleData(str(zarr_path) + ".tmp")
541-
with pytest.warns(UserWarning, match="The following alleles were not found"):
542+
with pytest.warns(
543+
UserWarning,
544+
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
545+
):
542546
inf_ts = tsinfer.infer(sd)
543-
for i, (
544-
inf_var,
545-
var,
546-
) in enumerate(zip(inf_ts.variants(), ts.variants())):
547-
assert inf_var.site.ancestral_state == var.site.ancestral_state or i in [
548-
0,
549-
11,
550-
15,
551-
]
552-
if i in [0, 11, 15]:
547+
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
548+
if i in [0, 11, 12, 15]:
553549
assert inf_var.site.metadata == {"inference_type": "parsimony"}
550+
else:
551+
assert inf_var.site.ancestral_state == var.site.ancestral_state
554552

555553

556554
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")

tsinfer/formats.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -2405,11 +2405,20 @@ def sites_ancestral_allele(self):
24052405
except IndexError:
24062406
unknown_alleles[allele] += 1
24072407
ret[i] = allele_index
2408-
if sum(unknown_alleles.values()) > 0:
2408+
tot = sum(unknown_alleles.values())
2409+
if tot > 0:
2410+
num_sites = len(string_allele)
2411+
frac_bad = tot / num_sites
2412+
frac_bad_per_type = [v / num_sites for v in unknown_alleles.values()]
2413+
summarise_unknown = [
2414+
f"'{k}': {v} ({frac * 100:.2f}% of sites)" # Summarise per allele type
2415+
for (k, v), frac in zip(unknown_alleles.items(), frac_bad_per_type)
2416+
]
24092417
warnings.warn(
2410-
"The following alleles were not found in the variant_allele array "
2411-
"and will be treated as unknown:\n"
2412-
f"{unknown_alleles}"
2418+
"An ancestral allele was not found in the variant_allele array for "
2419+
+ f"the {tot} sites ({frac_bad * 100 :.2f}%) listed below. "
2420+
+ "They will be treated as of unknown ancestral state:\n "
2421+
+ "\n ".join(summarise_unknown)
24132422
)
24142423
return ret
24152424

0 commit comments

Comments
 (0)