35
35
import msprime
36
36
import numpy as np
37
37
import pytest
38
+ import sgkit
38
39
import tskit
39
40
import tsutil
41
+ import xarray as xr
40
42
from tskit import MetadataSchema
41
43
42
44
import _tsinfer
@@ -1401,16 +1403,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir):
1401
1403
tmpdir / "ancestors.zarr" ,
1402
1404
1000 ,
1403
1405
)
1404
- tsinfer .match_ancestors_batch_groups (
1405
- tmpdir / "work" , 0 , len (metadata ["ancestor_grouping" ]) // 2 , 2
1406
- )
1406
+ num_groupings = len (metadata ["ancestor_grouping" ])
1407
+ tsinfer .match_ancestors_batch_groups (tmpdir / "work" , 0 , num_groupings // 2 , 2 )
1407
1408
tsinfer .match_ancestors_batch_groups (
1408
1409
tmpdir / "work" ,
1409
- len ( metadata [ "ancestor_grouping" ]) // 2 ,
1410
- len ( metadata [ "ancestor_grouping" ]) ,
1410
+ num_groupings // 2 ,
1411
+ num_groupings ,
1411
1412
2 ,
1412
1413
)
1413
- # TODO Check which ones written to disk
1414
+ assert (tmpdir / "work" / f"ancestors_{ (num_groupings // 2 )- 1 } .trees" ).exists ()
1415
+ assert (tmpdir / "work" / f"ancestors_{ num_groupings - 1 } .trees" ).exists ()
1414
1416
ts = tsinfer .match_ancestors_batch_finalise (tmpdir / "work" )
1415
1417
ts2 = tsinfer .match_ancestors (samples , ancestors )
1416
1418
ts .tables .assert_equals (ts2 .tables , ignore_provenance = True )
@@ -1438,6 +1440,11 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
1438
1440
tsinfer .match_ancestors_batch_group_partition (
1439
1441
tmpdir / "work" , group_index , p_index
1440
1442
)
1443
+ with pytest .raises (ValueError , match = "out of range" ):
1444
+ tsinfer .match_ancestors_batch_group_partition (
1445
+ tmpdir / "work" , group_index , p_index + 1000
1446
+ )
1447
+
1441
1448
ts = tsinfer .match_ancestors_batch_group_finalise (
1442
1449
tmpdir / "work" , group_index
1443
1450
)
@@ -1523,6 +1530,34 @@ def test_errors(self, tmp_path, tmpdir):
1523
1530
with pytest .raises (ValueError , match = "sequence length is different" ):
1524
1531
tsinfer .match_ancestors_batch_groups (tmpdir / "work" , 2 , 3 )
1525
1532
1533
+ def test_low_min_work_per_job (self , tmp_path , tmpdir ):
1534
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
1535
+ samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
1536
+ _ = tsinfer .generate_ancestors (samples , path = str (tmpdir / "ancestors.zarr" ))
1537
+ metadata = tsinfer .match_ancestors_batch_init (
1538
+ tmpdir / "work" ,
1539
+ zarr_path ,
1540
+ "variant_ancestral_allele" ,
1541
+ tmpdir / "ancestors.zarr" ,
1542
+ min_work_per_job = 1 ,
1543
+ max_num_partitions = 2 ,
1544
+ )
1545
+ for group in metadata ["ancestor_grouping" ]:
1546
+ assert group ["partitions" ] is None or len (group ["partitions" ]) <= 2
1547
+
1548
+ metadata = tsinfer .match_ancestors_batch_init (
1549
+ tmpdir / "work2" ,
1550
+ zarr_path ,
1551
+ "variant_ancestral_allele" ,
1552
+ tmpdir / "ancestors.zarr" ,
1553
+ min_work_per_job = 1 ,
1554
+ max_num_partitions = 20000 ,
1555
+ )
1556
+ for group in metadata ["ancestor_grouping" ]:
1557
+ if group ["partitions" ] is not None :
1558
+ for partition in group ["partitions" ]:
1559
+ assert len (partition ) == 1
1560
+
1526
1561
1527
1562
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
1528
1563
class TestBatchSampleMatching :
@@ -1543,8 +1578,8 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
1543
1578
ancestral_state = "variant_ancestral_allele" ,
1544
1579
ancestor_ts_path = tmpdir / "mat_anc.trees" ,
1545
1580
min_work_per_job = 1 ,
1546
- max_num_partitions = 10 ,
1547
1581
)
1582
+ assert mat_wd .num_partitions == mat_sd .num_samples
1548
1583
for i in range (mat_wd .num_partitions ):
1549
1584
tsinfer .match_samples_batch_partition (
1550
1585
work_dir = tmpdir / "working_mat" ,
@@ -1564,7 +1599,6 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
1564
1599
ancestral_state = "variant_ancestral_allele" ,
1565
1600
ancestor_ts_path = tmpdir / "mask_anc.trees" ,
1566
1601
min_work_per_job = 1 ,
1567
- max_num_partitions = 10 ,
1568
1602
site_mask = "variant_mask_foobar" ,
1569
1603
sample_mask = "samples_mask_foobar" ,
1570
1604
)
@@ -1588,6 +1622,82 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
1588
1622
mat_ts_batch .tables , ignore_timestamps = True , ignore_provenance = True
1589
1623
)
1590
1624
1625
+ def test_force_sample_times (self , tmp_path , tmpdir ):
1626
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
1627
+ ds = sgkit .load_dataset (zarr_path )
1628
+ array = [0.0001 ] * ts .num_individuals
1629
+ ds .update (
1630
+ {
1631
+ "individuals_time" : xr .DataArray (
1632
+ data = array , dims = ["sample" ], name = "individuals_time"
1633
+ )
1634
+ }
1635
+ )
1636
+ sgkit .save_dataset (
1637
+ ds .drop_vars (set (ds .data_vars ) - {"individuals_time" }), zarr_path , mode = "a"
1638
+ )
1639
+ samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
1640
+ anc = tsinfer .generate_ancestors (samples , path = str (tmpdir / "ancestors.zarr" ))
1641
+ anc_ts = tsinfer .match_ancestors (samples , anc )
1642
+ anc_ts .dump (tmpdir / "anc.trees" )
1643
+
1644
+ wd = tsinfer .match_samples_batch_init (
1645
+ work_dir = tmpdir / "working" ,
1646
+ sample_data_path = samples .path ,
1647
+ ancestral_state = "variant_ancestral_allele" ,
1648
+ ancestor_ts_path = tmpdir / "anc.trees" ,
1649
+ min_work_per_job = 1e6 ,
1650
+ force_sample_times = True ,
1651
+ )
1652
+ for i in range (wd .num_partitions ):
1653
+ tsinfer .match_samples_batch_partition (
1654
+ work_dir = tmpdir / "working" ,
1655
+ partition_index = i ,
1656
+ )
1657
+ ts_batch = tsinfer .match_samples_batch_finalise (tmpdir / "working" )
1658
+ ts = tsinfer .match_samples (samples , anc_ts , force_sample_times = True )
1659
+ ts .tables .assert_equals (ts_batch .tables , ignore_provenance = True )
1660
+
1661
+ def test_array_args (self , tmp_path , tmpdir ):
1662
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
1663
+ sample_mask = np .zeros (ts .num_individuals , dtype = bool )
1664
+ sample_mask [42 ] = True
1665
+ site_mask = np .zeros (ts .num_sites , dtype = bool )
1666
+ site_mask [42 ] = True
1667
+ rng = np .random .RandomState (42 )
1668
+ sites_time = rng .uniform (0 , 1 , ts .num_sites - 1 )
1669
+ samples = tsinfer .VariantData (
1670
+ zarr_path ,
1671
+ "variant_ancestral_allele" ,
1672
+ sample_mask = sample_mask ,
1673
+ site_mask = site_mask ,
1674
+ sites_time = sites_time ,
1675
+ )
1676
+ anc = tsinfer .generate_ancestors (samples , path = str (tmpdir / "ancestors.zarr" ))
1677
+ anc_ts = tsinfer .match_ancestors (samples , anc )
1678
+ anc_ts .dump (tmpdir / "anc.trees" )
1679
+
1680
+ wd = tsinfer .match_samples_batch_init (
1681
+ work_dir = tmpdir / "working" ,
1682
+ sample_data_path = samples .path ,
1683
+ sample_mask = sample_mask ,
1684
+ site_mask = site_mask ,
1685
+ ancestral_state = "variant_ancestral_allele" ,
1686
+ ancestor_ts_path = tmpdir / "anc.trees" ,
1687
+ min_work_per_job = 1e6 ,
1688
+ )
1689
+ for i in range (wd .num_partitions ):
1690
+ tsinfer .match_samples_batch_partition (
1691
+ work_dir = tmpdir / "working" ,
1692
+ partition_index = i ,
1693
+ )
1694
+ ts_batch = tsinfer .match_samples_batch_finalise (tmpdir / "working" )
1695
+ ts = tsinfer .match_samples (
1696
+ samples ,
1697
+ anc_ts ,
1698
+ )
1699
+ ts .tables .assert_equals (ts_batch .tables , ignore_provenance = True )
1700
+
1591
1701
1592
1702
class TestAncestorGeneratorsEquivalant :
1593
1703
"""
0 commit comments