Commit 1aa0233

Merge pull request #998 from benjeffery/batch-coverage
Improve batch-match coverage
2 parents d2048c3 + a063456

File tree

3 files changed: +153 −48 lines changed


docs/large_scale.md

Lines changed: 1 addition & 2 deletions

@@ -195,8 +195,7 @@ of ancestors. There are three API methods that work together to enable distribut
 3. {meth}`match_samples_batch_finalise`
 
 {meth}`match_samples_batch_init` should be called to set up the batch matching and to determine the
-groupings of samples. Similar to {meth}`match_ancestors_batch_init` is has a `min_work_per_job` and
-`max_num_partitions` arguments to control the level of parallelism. The method writes a file
+groupings of samples. Similar to {meth}`match_ancestors_batch_init` it has a `min_work_per_job` argument to control the level of parallelism. The method writes a file
 `metadata.json` to the directory `work_dir` that contains a JSON encoded dictionary with
 configuration for later steps. This is also returned by the call. The `num_partitions` key in
 this dictionary is the number of times {meth}`match_samples_batch_partition` will need
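For reference, the three methods above chain together as in the following sketch. The paths and the ancestral-state array name are illustrative assumptions; the call pattern mirrors the tests in this commit, where the init call returns a work-descriptor object whose num_partitions attribute gives the number of partition jobs to run.

    import tsinfer

    # Step 1: group the samples and write metadata.json to the work directory.
    wd = tsinfer.match_samples_batch_init(
        work_dir="working",
        sample_data_path="samples.zarr",
        ancestral_state="variant_ancestral_allele",
        ancestor_ts_path="ancestors.trees",
        min_work_per_job=1e6,
    )
    # Step 2: each partition is independent, so these calls can be dispatched
    # as separate cluster jobs.
    for i in range(wd.num_partitions):
        tsinfer.match_samples_batch_partition(work_dir="working", partition_index=i)
    # Step 3: combine the partition results into the final tree sequence.
    ts = tsinfer.match_samples_batch_finalise("working")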

tests/test_inference.py

Lines changed: 118 additions & 8 deletions
@@ -35,8 +35,10 @@
 import msprime
 import numpy as np
 import pytest
+import sgkit
 import tskit
 import tsutil
+import xarray as xr
 from tskit import MetadataSchema
 
 import _tsinfer
@@ -1401,16 +1403,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir):
             tmpdir / "ancestors.zarr",
             1000,
         )
-        tsinfer.match_ancestors_batch_groups(
-            tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2
-        )
+        num_groupings = len(metadata["ancestor_grouping"])
+        tsinfer.match_ancestors_batch_groups(tmpdir / "work", 0, num_groupings // 2, 2)
         tsinfer.match_ancestors_batch_groups(
             tmpdir / "work",
-            len(metadata["ancestor_grouping"]) // 2,
-            len(metadata["ancestor_grouping"]),
+            num_groupings // 2,
+            num_groupings,
             2,
         )
-        # TODO Check which ones written to disk
+        assert (tmpdir / "work" / f"ancestors_{(num_groupings//2)-1}.trees").exists()
+        assert (tmpdir / "work" / f"ancestors_{num_groupings-1}.trees").exists()
         ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
         ts2 = tsinfer.match_ancestors(samples, ancestors)
         ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
@@ -1438,6 +1440,11 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
                 tsinfer.match_ancestors_batch_group_partition(
                     tmpdir / "work", group_index, p_index
                 )
+            with pytest.raises(ValueError, match="out of range"):
+                tsinfer.match_ancestors_batch_group_partition(
+                    tmpdir / "work", group_index, p_index + 1000
+                )
+
             ts = tsinfer.match_ancestors_batch_group_finalise(
                 tmpdir / "work", group_index
             )
@@ -1523,6 +1530,34 @@ def test_errors(self, tmp_path, tmpdir):
         with pytest.raises(ValueError, match="sequence length is different"):
             tsinfer.match_ancestors_batch_groups(tmpdir / "work", 2, 3)
 
+    def test_low_min_work_per_job(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
+        _ = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
+        metadata = tsinfer.match_ancestors_batch_init(
+            tmpdir / "work",
+            zarr_path,
+            "variant_ancestral_allele",
+            tmpdir / "ancestors.zarr",
+            min_work_per_job=1,
+            max_num_partitions=2,
+        )
+        for group in metadata["ancestor_grouping"]:
+            assert group["partitions"] is None or len(group["partitions"]) <= 2
+
+        metadata = tsinfer.match_ancestors_batch_init(
+            tmpdir / "work2",
+            zarr_path,
+            "variant_ancestral_allele",
+            tmpdir / "ancestors.zarr",
+            min_work_per_job=1,
+            max_num_partitions=20000,
+        )
+        for group in metadata["ancestor_grouping"]:
+            if group["partitions"] is not None:
+                for partition in group["partitions"]:
+                    assert len(partition) == 1
+
 
 @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
 class TestBatchSampleMatching:
@@ -1543,8 +1578,8 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
             ancestral_state="variant_ancestral_allele",
             ancestor_ts_path=tmpdir / "mat_anc.trees",
             min_work_per_job=1,
-            max_num_partitions=10,
         )
+        assert mat_wd.num_partitions == mat_sd.num_samples
         for i in range(mat_wd.num_partitions):
             tsinfer.match_samples_batch_partition(
                 work_dir=tmpdir / "working_mat",
@@ -1564,7 +1599,6 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
             ancestral_state="variant_ancestral_allele",
             ancestor_ts_path=tmpdir / "mask_anc.trees",
             min_work_per_job=1,
-            max_num_partitions=10,
             site_mask="variant_mask_foobar",
             sample_mask="samples_mask_foobar",
         )
@@ -1588,6 +1622,82 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
             mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
         )
 
+    def test_force_sample_times(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        ds = sgkit.load_dataset(zarr_path)
+        array = [0.0001] * ts.num_individuals
+        ds.update(
+            {
+                "individuals_time": xr.DataArray(
+                    data=array, dims=["sample"], name="individuals_time"
+                )
+            }
+        )
+        sgkit.save_dataset(
+            ds.drop_vars(set(ds.data_vars) - {"individuals_time"}), zarr_path, mode="a"
+        )
+        samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
+        anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
+        anc_ts = tsinfer.match_ancestors(samples, anc)
+        anc_ts.dump(tmpdir / "anc.trees")
+
+        wd = tsinfer.match_samples_batch_init(
+            work_dir=tmpdir / "working",
+            sample_data_path=samples.path,
+            ancestral_state="variant_ancestral_allele",
+            ancestor_ts_path=tmpdir / "anc.trees",
+            min_work_per_job=1e6,
+            force_sample_times=True,
+        )
+        for i in range(wd.num_partitions):
+            tsinfer.match_samples_batch_partition(
+                work_dir=tmpdir / "working",
+                partition_index=i,
+            )
+        ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
+        ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True)
+        ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)
+
+    def test_array_args(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        sample_mask = np.zeros(ts.num_individuals, dtype=bool)
+        sample_mask[42] = True
+        site_mask = np.zeros(ts.num_sites, dtype=bool)
+        site_mask[42] = True
+        rng = np.random.RandomState(42)
+        sites_time = rng.uniform(0, 1, ts.num_sites - 1)
+        samples = tsinfer.VariantData(
+            zarr_path,
+            "variant_ancestral_allele",
+            sample_mask=sample_mask,
+            site_mask=site_mask,
+            sites_time=sites_time,
+        )
+        anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
+        anc_ts = tsinfer.match_ancestors(samples, anc)
+        anc_ts.dump(tmpdir / "anc.trees")
+
+        wd = tsinfer.match_samples_batch_init(
+            work_dir=tmpdir / "working",
+            sample_data_path=samples.path,
+            sample_mask=sample_mask,
+            site_mask=site_mask,
+            ancestral_state="variant_ancestral_allele",
+            ancestor_ts_path=tmpdir / "anc.trees",
+            min_work_per_job=1e6,
+        )
+        for i in range(wd.num_partitions):
+            tsinfer.match_samples_batch_partition(
+                work_dir=tmpdir / "working",
+                partition_index=i,
+            )
+        ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
+        ts = tsinfer.match_samples(
+            samples,
+            anc_ts,
+        )
+        ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)
+
 
 class TestAncestorGeneratorsEquivalant:
     """

tsinfer/inference.py

Lines changed: 34 additions & 38 deletions
@@ -28,6 +28,7 @@
 import json
 import logging
 import math
+import operator
 import os
 import pathlib
 import pickle
@@ -714,34 +715,35 @@ def match_ancestors_batch_init(
     for group_index, group_ancestors in matcher.group_by_linesweep().items():
         # Make ancestor_ids JSON serialisable
         group_ancestors = list(map(int, group_ancestors))
-        partitions = []
-        current_partition = []
-        current_partition_work = 0
-        # TODO: Can do better here by packing ancestors
-        # into as equal sized partitions as possible
+        # The first group is trivial so never partition
         if group_index == 0:
-            partitions.append(group_ancestors)
+            partitions = [group_ancestors]
         else:
             total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
-            min_work_per_job_group = min_work_per_job
-            if total_work / max_num_partitions > min_work_per_job:
-                min_work_per_job_group = total_work / max_num_partitions
-            for ancestor in group_ancestors:
-                if (
-                    current_partition_work + ancestor_lengths[ancestor]
-                    > min_work_per_job_group
-                ):
-                    partitions.append(current_partition)
-                    current_partition = [ancestor]
-                    current_partition_work = ancestor_lengths[ancestor]
-                else:
-                    current_partition.append(ancestor)
-                    current_partition_work += ancestor_lengths[ancestor]
-            partitions.append(current_partition)
+            partition_count = math.ceil(total_work / min_work_per_job)
+            if partition_count > max_num_partitions:
+                partition_count = max_num_partitions
+
+            # Partition into roughly equal sized bins (by work)
+            sorted_ancestors = sorted(
+                group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True
+            )
+
+            # Use greedy bin packing - place each ancestor in the bin with
+            # lowest total length
+            heap = [(0, []) for _ in range(partition_count)]
+            for ancestor in sorted_ancestors:
+                sum_len, partition = heapq.heappop(heap)
+                partition.append(ancestor)
+                sum_len += ancestor_lengths[ancestor]
+                heapq.heappush(heap, (sum_len, partition))
+            partitions = [
+                sorted(partition) for sum_len, partition in heap if sum_len > 0
+            ]
         if len(partitions) > 1:
             group_dir = work_dir / f"group_{group_index}"
             group_dir.mkdir()
-        # TODO: Should be a dataclass
+
         group = {
             "ancestors": group_ancestors,
             "partitions": partitions if len(partitions) > 1 else None,
@@ -902,7 +904,7 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
     )
     logger.info(f"Dumping to {partition_path}")
     with open(partition_path, "wb") as f:
-        pickle.dump((start_time, timing.metrics, results), f)
+        pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f)
 
 
 def match_ancestors_batch_group_finalise(work_dir, group_index):
@@ -935,17 +937,18 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
     )
     start_times = []
    timings = []
-    results = []
+    results = {}
     for partition_index in range(len(group["partitions"])):
         partition_path = os.path.join(
             work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
         )
         with open(partition_path, "rb") as f:
-            start_time, part_timing, result = pickle.load(f)
+            start_time, part_timing, ancestors, result = pickle.load(f)
             start_times.append(start_time)
-            results.extend(result)
+            for ancestor, r in zip(ancestors, result):
+                results[ancestor] = r
             timings.append(part_timing)
-
+    results = list(map(operator.itemgetter(1), sorted(results.items())))
     ts = matcher.finalise_group(group, results, group_index)
     path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
     ts.dump(path)
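A note on the finalise change: with bin-packed partitions the per-partition results no longer arrive in global ancestor order, so each partition now pickles its ancestor ids alongside its results, and finalise keys the results by id before flattening with a sort. An illustrative reassembly with made-up ids and result values:

    import operator

    partitions = [[0, 3], [1, 2, 4]]  # ancestor ids matched by each partition
    partition_results = [["r0", "r3"], ["r1", "r2", "r4"]]

    results = {}
    for ancestors, result in zip(partitions, partition_results):
        for ancestor, r in zip(ancestors, result):
            results[ancestor] = r
    # sorted() orders items by ancestor id; itemgetter(1) keeps just the result.
    results = list(map(operator.itemgetter(1), sorted(results.items())))
    print(results)  # ['r0', 'r1', 'r2', 'r3', 'r4']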
@@ -1186,7 +1189,6 @@ def match_samples_batch_init(
     ancestor_ts_path,
     min_work_per_job,
     *,
-    max_num_partitions=None,
     sample_mask=None,
     site_mask=None,
     recombination_rate=None,
@@ -1206,7 +1208,7 @@ def match_samples_batch_init(
 ):
     """
     match_samples_batch_init(work_dir, sample_data_path, ancestral_state,
-    ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None,
+    ancestor_ts_path, min_work_per_job, \\*,
     sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
     path_compression=True, indexes=None, post_process=None, force_sample_times=False)
@@ -1237,9 +1239,6 @@ def match_samples_batch_init(
         genotypes) to allocate to a single parallel job. If the amount of work in
         a group of samples exceeds this level it will be broken up into parallel
         partitions, subject to the constraint of `max_num_partitions`.
-    :param int max_num_partitions: The maximum number of partitions to split a
-        group of samples into. Useful for limiting the number of jobs in a
-        workflow to avoid job overhead. Defaults to 1000.
     :param Union(array, str) sample_mask: A numpy array of booleans specifying
         which samples to mask out (exclude) from the dataset. Alternatively, a
         string can be provided, giving the name of an array in the input dataset
@@ -1277,9 +1276,6 @@ def match_samples_batch_init(
     :return: A dictionary of the job metadata, as written to `metadata.json` in
         `work_dir`.
     """
-    if max_num_partitions is None:
-        max_num_partitions = 1000
-
     # Convert working_dir to pathlib.Path
     work_dir = pathlib.Path(work_dir)

@@ -1329,9 +1325,9 @@ def match_samples_batch_init(
     sample_times = sample_times.tolist()
     wd.sample_indexes = sample_indexes
     wd.sample_times = sample_times
-    num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
-    if num_samples_per_partition == 0:
-        num_samples_per_partition = 1
+    num_samples_per_partition = max(
+        1, math.ceil(min_work_per_job // variant_data.num_sites)
+    )
     wd.num_samples_per_partition = num_samples_per_partition
     wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
     wd_path = work_dir / "metadata.json"
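A quick worked example of the partition arithmetic above, with illustrative numbers (min_work_per_job is measured in genotypes, i.e. samples times sites):

    import math

    min_work_per_job = 1e6
    num_sites = 45_000
    num_samples = 300

    # At least one sample per partition, even when a single sample's sites
    # already exceed min_work_per_job.
    num_samples_per_partition = max(1, math.ceil(min_work_per_job // num_sites))
    num_partitions = math.ceil(num_samples / num_samples_per_partition)
    print(num_samples_per_partition, num_partitions)  # 22 14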
