Skip to content

Commit f09d07e

Browse files
committed
Generic iterator class for trees and variants
Also allows left and right to be passed to the trees iterator.
1 parent fd72573 commit f09d07e

File tree

4 files changed

+237
-40
lines changed

4 files changed

+237
-40
lines changed

python/tests/test_genotypes.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,9 @@ def test_simple_case(self, ts_fixture):
661661
ts = ts_fixture
662662
test_variant = tskit.Variant(ts)
663663
test_variant.decode(1)
664-
for v in ts.variants(left=ts.site(1).position, right=ts.site(2).position):
664+
v_iter = ts.variants(left=ts.site(1).position, right=ts.site(2).position)
665+
assert len(v_iter) == 1
666+
for v in v_iter:
665667
# should only decode the first variant
666668
assert v.site.id == 1
667669
assert np.all(v.genotypes == test_variant.genotypes)
@@ -686,7 +688,9 @@ def test_left(self, left, expected):
686688
for x in range(int(tables.sequence_length)):
687689
tables.sites.add_row(position=x, ancestral_state="A")
688690
ts = tables.tree_sequence()
689-
positions = [var.site.position for var in ts.variants(left=left)]
691+
v_iter = ts.variants(left=left)
692+
assert len(v_iter) == len(expected)
693+
positions = [var.site.position for var in v_iter]
690694
assert positions == expected
691695

692696
@pytest.mark.parametrize(
@@ -706,7 +710,9 @@ def test_right(self, right, expected):
706710
for x in range(int(tables.sequence_length)):
707711
tables.sites.add_row(position=x, ancestral_state="A")
708712
ts = tables.tree_sequence()
709-
positions = [var.site.position for var in ts.variants(right=right)]
713+
v_iter = ts.variants(right=right)
714+
assert len(v_iter) == len(expected)
715+
positions = [var.site.position for var in v_iter]
710716
assert positions == expected
711717

712718
@pytest.mark.parametrize("bad_left", [-1, 10, 100, np.nan, np.inf, -np.inf])

python/tests/test_highlevel.py

+54-2
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,48 @@ def test_trees_interface(self):
17241724
assert t.get_num_tracked_samples(0) == 0
17251725
assert list(t.samples(0)) == [0]
17261726

1727+
def test_trees_bad_left_right(self):
1728+
ts = tskit.Tree.generate_balanced(10, span=1).tree_sequence
1729+
with pytest.raises(ValueError):
1730+
ts.trees(left=0.5, right=0.5)
1731+
with pytest.raises(ValueError):
1732+
ts.trees(left=0.5, right=0.4)
1733+
with pytest.raises(ValueError):
1734+
ts.trees(left=0.5, right=1.1)
1735+
with pytest.raises(ValueError):
1736+
ts.trees(left=-0.1, right=0.1)
1737+
with pytest.raises(ValueError):
1738+
ts.trees(left=1, right=1.5)
1739+
1740+
def test_trees_left_right_one_tree(self):
1741+
ts = tskit.Tree.generate_balanced(10).tree_sequence
1742+
tree_iterator = ts.trees(left=0.5, right=0.6)
1743+
assert len(tree_iterator) == 1
1744+
trees = [tree.copy() for tree in tree_iterator]
1745+
assert len(trees) == 1
1746+
tree_iterator = reversed(ts.trees(left=0.5, right=0.6))
1747+
assert len(tree_iterator) == 1
1748+
assert trees[0] == ts.first()
1749+
trees = [tree.copy() for tree in tree_iterator]
1750+
assert len(trees) == 1
1751+
assert trees[0] == ts.first()
1752+
1753+
@pytest.mark.parametrize(
1754+
"interval", [(0, 0.5), (0.4, 0.6), (0.5, np.nextafter(0.5, 1)), (0.5, 1)]
1755+
)
1756+
def test_trees_left_right_many_trees(self, interval):
1757+
ts = msprime.simulate(5, recombination_rate=10, random_seed=1)
1758+
assert ts.num_trees > 10
1759+
tree_iter = ts.trees(left=interval[0], right=interval[1])
1760+
expected_length = len(tree_iter)
1761+
n_trees = 0
1762+
for tree in ts.trees():
1763+
# check if the tree is within the interval
1764+
if tree.interval[1] > interval[0] and tree.interval[0] < interval[1]:
1765+
n_trees += 1
1766+
assert tree.interval == next(tree_iter).interval
1767+
assert n_trees == expected_length
1768+
17271769
@pytest.mark.parametrize("ts", get_example_tree_sequences())
17281770
def test_get_pairwise_diversity(self, ts):
17291771
with pytest.raises(ValueError, match="at least one element"):
@@ -2994,8 +3036,18 @@ def test_trees_params(self):
29943036
)
29953037
# Skip the first param, which is `tree_sequence` and `self` respectively
29963038
tree_class_params = tree_class_params[1:]
2997-
# The trees iterator has some extra (deprecated) aliases
2998-
trees_iter_params = trees_iter_params[1:-3]
3039+
# The trees iterator has some extra (deprecated) aliases at the end
3040+
num_deprecated = 3
3041+
trees_iter_params = trees_iter_params[1:-num_deprecated]
3042+
3043+
# The trees iterator also has left/right/copy params which aren't in __init__()
3044+
assert trees_iter_params[-1][0] == "copy"
3045+
trees_iter_params = trees_iter_params[:-1]
3046+
assert trees_iter_params[-1][0] == "right"
3047+
trees_iter_params = trees_iter_params[:-1]
3048+
assert trees_iter_params[-1][0] == "left"
3049+
trees_iter_params = trees_iter_params[:-1]
3050+
29993051
assert trees_iter_params == tree_class_params
30003052

30013053

python/tskit/genotypes.py

+30
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,36 @@ def decode(self, site_id) -> None:
233233
"""
234234
self._ll_variant.decode(site_id)
235235

236+
def next(self): # noqa A002
237+
"""
238+
Decode the variant at the next site, returning True if successful, False
239+
if the variant is already at the last site. If the variant has not yet been
240+
decoded, decode the variant at the first site.
241+
"""
242+
if self._ll_variant.site_id == self.tree_sequence.num_sites - 1:
243+
# TODO: should also set the variant to the null state
244+
return False
245+
if self._ll_variant.site_id == tskit.NULL:
246+
self.decode(0)
247+
else:
248+
self.decode(self._ll_variant.site_id + 1)
249+
return True
250+
251+
def prev(self):
252+
"""
253+
Decode the variant at the previous site, returning True if successful, False
254+
if the variant is already at the first site. If the variant has not yet been
255+
decoded at any site, decode the variant at the last site.
256+
"""
257+
if self._ll_variant.site_id == 0:
258+
# TODO: should also set the variant to the null state
259+
return False
260+
if self._ll_variant.site_id == tskit.NULL:
261+
self.decode(self.tree_sequence.num_sites - 1)
262+
else:
263+
self.decode(self._ll_variant.site_id - 1)
264+
return True
265+
236266
def copy(self) -> Variant:
237267
"""
238268
Create a copy of this Variant. Note that calling :meth:`decode` on the

python/tskit/trees.py

+144-35
Original file line numberDiff line numberDiff line change
@@ -3863,15 +3863,26 @@ def load_text(
38633863
return tc.tree_sequence()
38643864

38653865

3866-
class TreeIterator:
3867-
"""
3868-
Simple class providing forward and backward iteration over a tree sequence.
3869-
"""
3870-
3871-
def __init__(self, tree):
3872-
self.tree = tree
3873-
self.more_trees = True
3866+
class ObjectIterator:
3867+
# Simple class providing forward and backward iteration over a
3868+
# mutable object with ``next()`` and ``prev()`` methods, e.g.
3869+
# a Tree or a Variant. ``interval`` allows the bounds of the
3870+
# iterator to be specified, and should already have
3871+
# been checked using _check_genomic_range(left, right)
3872+
# If ``return_copies`` is True, the iterator will return
3873+
# immutable copies of each object (this is likely to be significantly
3874+
# less efficient).
3875+
# It can be useful to define __len__ on one of these iterators,
3876+
# which e.g. allows progress bars to provide useful feedback.
3877+
3878+
def __init__(self, obj, interval, return_copies=False):
3879+
self._obj = obj
3880+
self.min_pos = interval[0]
3881+
self.max_pos = interval[1]
3882+
self.return_copies = return_copies
38743883
self.forward = True
3884+
self.started = False
3885+
self.finished = False
38753886

38763887
def __iter__(self):
38773888
return self
@@ -3880,17 +3891,114 @@ def __reversed__(self):
38803891
self.forward = False
38813892
return self
38823893

3894+
def obj_left(self):
3895+
# Used to work out where to stop iterating when going forwards.
3896+
# Override with code to return the left coordinate of self._obj
3897+
raise NotImplementedError()
3898+
3899+
def obj_right(self):
3900+
# Used to work out when to stop iterating when going backwards.
3901+
# Override with code to return the right coordinate of self._obj
3902+
raise NotImplementedError()
3903+
3904+
def seek_to_start(self):
3905+
# Override to set the object position to self.min_pos
3906+
raise NotImplementedError()
3907+
3908+
def seek_to_end(self):
3909+
# Override to set the object position just before self.max_pos
3910+
raise NotImplementedError()
3911+
38833912
def __next__(self):
    """
    Move the wrapped object one step and return it.

    The first call positions the object using ``seek_to_start()`` when
    iterating forwards, or ``seek_to_end()`` when iterating in reverse,
    and returns it without a bounds check. Each later call steps the
    object with its ``next()`` / ``prev()`` method and stops as soon as
    that method returns a false value or the object has left the
    ``[min_pos, max_pos)`` range.

    :return: ``self._obj`` itself, or ``self._obj.copy()`` when
        ``return_copies`` is true.
    :raises StopIteration: when the iterator is (or has just become)
        finished.
    """
    if self.finished:
        raise StopIteration()

    if not self.started:
        # First call: jump straight to the appropriate end of the range.
        if self.forward:
            self.seek_to_start()
        else:
            self.seek_to_end()
        self.started = True
    elif self.forward:
        # max_pos is exclusive: an object starting at max_pos is outside.
        if not self._obj.next() or self.obj_left() >= self.max_pos:
            self.finished = True
    else:
        # NOTE(review): this comparison only excludes objects lying strictly
        # left of min_pos, so obj_right() needs to be the last coordinate the
        # object covers. A subclass returning an exclusive right edge would
        # yield one extra object when that edge equals min_pos -- confirm.
        if not self._obj.prev() or self.obj_right() < self.min_pos:
            self.finished = True

    if self.finished:
        raise StopIteration()
    return self._obj.copy() if self.return_copies else self._obj
3931+
3932+
3933+
class TreeIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`trees<Tree>`
    in a :class:`TreeSequence`.
    """

    def obj_left(self):
        # Left (inclusive) coordinate of the current tree.
        return self._obj.interval.left

    def obj_right(self):
        # The right edge of a tree's interval is exclusive, whereas the
        # reverse-iteration check in ObjectIterator.__next__ only stops when
        # ``obj_right() < min_pos``. Returning the raw right edge would let a
        # tree that ends exactly at ``min_pos`` (and so does not overlap the
        # requested range) be yielded, disagreeing with ``__len__``. Return
        # the last coordinate the tree actually covers instead.
        return np.nextafter(self._obj.interval.right, -np.inf)

    def seek_to_start(self):
        # Position on the tree covering the inclusive lower bound.
        self._obj.seek(self.min_pos)

    def seek_to_end(self):
        # ``max_pos`` is exclusive, so seek to the largest coordinate below it.
        self._obj.seek(np.nextafter(self.max_pos, -np.inf))

    def __len__(self):
        """
        The number of trees over which a newly created iterator will iterate.
        """
        ts = self._obj.tree_sequence
        if self.min_pos == 0 and self.max_pos == ts.sequence_length:
            # a common case: don't incur the cost of searching through the breakpoints
            return ts.num_trees
        breaks = ts.breakpoints(as_array=True)
        # Index of the first breakpoint strictly greater than min_pos, i.e. the
        # right edge of the first tree in range.
        left_index = breaks.searchsorted(self.min_pos, side="right")
        # Index of the first breakpoint >= max_pos, i.e. the right edge of the
        # last tree in range.
        right_index = breaks.searchsorted(self.max_pos, side="left")
        return right_index - left_index + 1
3963+
3964+
3965+
class VariantIterator(ObjectIterator):
3966+
"""
3967+
An iterator over some or all of the :class:`variants<Variant>`
3968+
in a :class:`TreeSequence`.
3969+
"""
3970+
3971+
def __init__(self, variant, interval, copy):
3972+
super().__init__(variant, interval, copy)
3973+
if interval[0] == 0 and interval[1] == variant.tree_sequence.sequence_length:
3974+
# a common case: don't incur the cost of searching through the positions
3975+
self.min_max_sites = [0, variant.tree_sequence.num_sites]
3976+
else:
3977+
self.min_max_sites = variant.tree_sequence.sites_position.searchsorted(
3978+
interval
3979+
)
3980+
if self.min_max_sites[0] >= self.min_max_sites[1]:
3981+
# upper bound is exclusive: we don't include the site at self.min_max_sites[1]
3982+
self.finished = True
3983+
3984+
def obj_left(self):
3985+
return self._obj.site.position
3986+
3987+
def obj_right(self):
3988+
return self._obj.site.position
3989+
3990+
def seek_to_start(self):
3991+
self._obj.decode(self.min_max_sites[0])
3992+
3993+
def seek_to_end(self):
3994+
self._obj.decode(self.min_max_sites[1] - 1)
3995+
3996+
def __len__(self):
3997+
"""
3998+
The number of variants (i.e. sites) over which a newly created iterator will
3999+
iterate.
4000+
"""
4001+
return self.min_max_sites[1] - self.min_max_sites[0]
38944002

38954003

38964004
class SimpleContainerSequence:
@@ -4077,7 +4185,7 @@ def aslist(self, **kwargs):
40774185
:return: A list of the trees in this tree sequence.
40784186
:rtype: list
40794187
"""
4080-
return [tree.copy() for tree in self.trees(**kwargs)]
4188+
return [tree for tree in self.trees(copy=True, **kwargs)]
40814189

40824190
@classmethod
40834191
def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False):
@@ -4970,6 +5078,9 @@ def trees(
49705078
sample_lists=False,
49715079
root_threshold=1,
49725080
sample_counts=None,
5081+
left=None,
5082+
right=None,
5083+
copy=None,
49735084
tracked_leaves=None,
49745085
leaf_counts=None,
49755086
leaf_lists=None,
@@ -5001,28 +5112,39 @@ def trees(
50015112
are roots. To efficiently restrict the roots of the tree to
50025113
those subtending meaningful topology, set this to 2. This value
50035114
is only relevant when trees have multiple roots.
5115+
:param float left: The left-most coordinate of the region over which
5116+
to iterate. Default: ``None`` treated as 0.
5117+
:param float right: The right-most coordinate of the region over which
5118+
to iterate. Default: ``None`` treated as ``.sequence_length``. This
5119+
value is exclusive, so that a tree whose ``interval.left`` is exactly
5120+
equivalent to ``right`` will not be included in the iteration.
5121+
:param bool copy: Return an immutable copy of each tree. This will be
5122+
inefficient. Default: ``None`` treated as False.
50045123
:param bool sample_counts: Deprecated since 0.2.4.
50055124
:return: An iterator over the Trees in this tree sequence.
5006-
:rtype: collections.abc.Iterable, :class:`Tree`
5125+
:rtype: TreeIterator
50075126
"""
50085127
# tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
50095128
# for tracked_samples, sample_counts and sample_lists respectively.
50105129
# These are left over from an older version of the API when leaves
50115130
# and samples were synonymous.
5131+
interval = self._check_genomic_range(left, right)
50125132
if tracked_leaves is not None:
50135133
tracked_samples = tracked_leaves
50145134
if leaf_counts is not None:
50155135
sample_counts = leaf_counts
50165136
if leaf_lists is not None:
50175137
sample_lists = leaf_lists
5138+
if copy is None:
5139+
copy = False
50185140
tree = Tree(
50195141
self,
50205142
tracked_samples=tracked_samples,
50215143
sample_lists=sample_lists,
50225144
root_threshold=root_threshold,
50235145
sample_counts=sample_counts,
50245146
)
5025-
return TreeIterator(tree)
5147+
return TreeIterator(tree, interval=interval, return_copies=copy)
50265148

50275149
def coiterate(self, other, **kwargs):
50285150
"""
@@ -5309,8 +5431,8 @@ def variants(
53095431
:param int right: End with the last site before this position. If ``None``
53105432
(default) assume ``right`` is the sequence length, so that the last
53115433
variant corresponds to the last site in the tree sequence.
5312-
:return: An iterator over all variants in this tree sequence.
5313-
:rtype: iter(:class:`Variant`)
5434+
:return: An iterator over the specified variants in this tree sequence.
5435+
:rtype: VariantIterator
53145436
"""
53155437
interval = self._check_genomic_range(left, right)
53165438
if impute_missing_data is not None:
@@ -5327,26 +5449,13 @@ def variants(
53275449
copy = True
53285450
# See comments for the Variant type for discussion on why the
53295451
# present form was chosen.
5330-
variant = tskit.Variant(
5452+
variant_object = tskit.Variant(
53315453
self,
53325454
samples=samples,
53335455
isolated_as_missing=isolated_as_missing,
53345456
alleles=alleles,
53355457
)
5336-
if left == 0 and right == self.sequence_length:
5337-
start = 0
5338-
stop = self.num_sites
5339-
else:
5340-
start, stop = np.searchsorted(self.sites_position, interval)
5341-
5342-
if copy:
5343-
for site_id in range(start, stop):
5344-
variant.decode(site_id)
5345-
yield variant.copy()
5346-
else:
5347-
for site_id in range(start, stop):
5348-
variant.decode(site_id)
5349-
yield variant
5458+
return VariantIterator(variant_object, interval=interval, copy=copy)
53505459

53515460
def genotype_matrix(
53525461
self,

0 commit comments

Comments
 (0)