@@ -3863,15 +3863,26 @@ def load_text(
3863
3863
return tc .tree_sequence ()
3864
3864
3865
3865
3866
- class TreeIterator :
3867
- """
3868
- Simple class providing forward and backward iteration over a tree sequence.
3869
- """
3870
-
3871
- def __init__ (self , tree ):
3872
- self .tree = tree
3873
- self .more_trees = True
3866
class ObjectIterator:
    # Forward/backward iterator over a single mutable cursor-like object that
    # exposes ``next()`` and ``prev()`` methods (e.g. a Tree or a Variant).
    #
    # ``interval`` gives the (min_pos, max_pos) bounds of the iteration; it is
    # expected to have been validated already with
    # _check_genomic_range(left, right).
    # When ``return_copies`` is True each step returns ``self._obj.copy()``
    # rather than the shared object itself (likely to be significantly less
    # efficient).
    # Subclasses may find it useful to define __len__, which e.g. lets
    # progress bars give useful feedback.

    def __init__(self, obj, interval, return_copies=False):
        self._obj = obj
        self.min_pos = interval[0]
        self.max_pos = interval[1]
        self.return_copies = return_copies
        # Direction of travel; flipped by __reversed__.
        self.forward = True
        # ``started`` becomes True once the object has been positioned by the
        # first __next__ call; ``finished`` latches once iteration is over.
        self.started = False
        self.finished = False

    def __iter__(self):
        return self

    def __reversed__(self):
        self.forward = False
        return self

    def obj_left(self):
        # Used to work out where to stop iterating when going forwards.
        # Override with code to return the left coordinate of self._obj
        raise NotImplementedError()

    def obj_right(self):
        # Used to work out when to stop iterating when going backwards.
        # Override with code to return the right coordinate of self._obj
        raise NotImplementedError()

    def seek_to_start(self):
        # Override to set the object position to self.min_pos
        raise NotImplementedError()

    def seek_to_end(self):
        # Override to set the object position just before self.max_pos
        raise NotImplementedError()

    def __next__(self):
        if self.finished:
            raise StopIteration()

        if not self.started:
            # First call: position the object at the appropriate end of the
            # interval instead of stepping it.
            if self.forward:
                self.seek_to_start()
            else:
                self.seek_to_end()
            self.started = True
        elif self.forward:
            # Stop when the object cannot advance or has moved to/past max_pos.
            advanced = self._obj.next()
            if not advanced or self.obj_left() >= self.max_pos:
                self.finished = True
        else:
            # Stop when the object cannot retreat or has moved before min_pos.
            retreated = self._obj.prev()
            if not retreated or self.obj_right() < self.min_pos:
                self.finished = True

        if self.finished:
            raise StopIteration()
        if self.return_copies:
            return self._obj.copy()
        return self._obj
3930
+
3931
+
3932
class TreeIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`trees<Tree>`
    in a :class:`TreeSequence`.
    """

    def obj_left(self):
        # Left (inclusive) coordinate of the current tree.
        return self._obj.interval.left

    def obj_right(self):
        # A tree's interval is half-open, so ``interval.right`` itself is not
        # covered by the tree. Report the largest coordinate that *is* covered
        # so that, when iterating backwards, a tree whose right edge coincides
        # exactly with ``min_pos`` is recognised as lying outside the range
        # (the base class stops only when ``obj_right() < min_pos``). This
        # keeps reversed iteration consistent with ``__len__`` and with
        # forward iteration.
        return np.nextafter(self._obj.interval.right, -np.inf)

    def seek_to_start(self):
        # Position on the tree covering the left bound.
        self._obj.seek(self.min_pos)

    def seek_to_end(self):
        # ``max_pos`` is exclusive: seek to the largest coordinate below it.
        self._obj.seek(np.nextafter(self.max_pos, -np.inf))

    def __len__(self):
        """
        The number of trees over which a newly created iterator will iterate.
        """
        ts = self._obj.tree_sequence
        if self.min_pos == 0 and self.max_pos == ts.sequence_length:
            # a common case: don't incur the cost of searching through the breakpoints
            return ts.num_trees
        breaks = ts.breakpoints(as_array=True)
        # Index of the first breakpoint strictly greater than min_pos, and of
        # the first breakpoint greater than or equal to max_pos; the trees
        # overlapping [min_pos, max_pos) are those between the two, inclusive.
        left_index = breaks.searchsorted(self.min_pos, side="right")
        right_index = breaks.searchsorted(self.max_pos, side="left")
        return right_index - left_index + 1
3962
+
3963
+
3964
class VariantIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`variants<Variant>`
    in a :class:`TreeSequence`.
    """

    def __init__(self, variant, interval, copy):
        super().__init__(variant, interval, copy)
        ts = variant.tree_sequence
        covers_everything = interval[0] == 0 and interval[1] == ts.sequence_length
        if covers_everything:
            # a common case: don't incur the cost of searching through the positions
            self.min_max_sites = [0, ts.num_sites]
        else:
            # Site ID bounds [first, stop) for the requested coordinates.
            self.min_max_sites = ts.sites_position.searchsorted(interval)
        first_site, stop_site = self.min_max_sites
        if first_site >= stop_site:
            # upper bound is exclusive: the site at self.min_max_sites[1] is
            # not included, so there is nothing to iterate over
            self.finished = True

    def obj_left(self):
        # A variant sits at a single position, so left and right coincide.
        current_site = self._obj.site
        return current_site.position

    def obj_right(self):
        current_site = self._obj.site
        return current_site.position

    def seek_to_start(self):
        first_site = self.min_max_sites[0]
        self._obj.decode(first_site)

    def seek_to_end(self):
        # The stop index is exclusive, so the last site is one before it.
        last_site = self.min_max_sites[1] - 1
        self._obj.decode(last_site)

    def __len__(self):
        """
        The number of variants (i.e. sites) over which a newly created iterator will
        iterate.
        """
        first_site, stop_site = self.min_max_sites
        return stop_site - first_site
3894
4001
3895
4002
3896
4003
class SimpleContainerSequence :
@@ -4077,7 +4184,7 @@ def aslist(self, **kwargs):
4077
4184
:return: A list of the trees in this tree sequence.
4078
4185
:rtype: list
4079
4186
"""
4080
- return [tree . copy () for tree in self .trees (** kwargs )]
4187
+ return [tree for tree in self .trees (copy = True , ** kwargs )]
4081
4188
4082
4189
@classmethod
4083
4190
def load (cls , file_or_path , * , skip_tables = False , skip_reference_sequence = False ):
@@ -4970,6 +5077,9 @@ def trees(
4970
5077
sample_lists = False ,
4971
5078
root_threshold = 1 ,
4972
5079
sample_counts = None ,
5080
+ left = None ,
5081
+ right = None ,
5082
+ copy = None ,
4973
5083
tracked_leaves = None ,
4974
5084
leaf_counts = None ,
4975
5085
leaf_lists = None ,
@@ -5001,28 +5111,39 @@ def trees(
5001
5111
are roots. To efficiently restrict the roots of the tree to
5002
5112
those subtending meaningful topology, set this to 2. This value
5003
5113
is only relevant when trees have multiple roots.
5114
+ :param float left: The left-most coordinate of the region over which
5115
+ to iterate. Default: ``None`` treated as 0.
5116
+ :param float right: The right-most coordinate of the region over which
5117
+ to iterate. Default: ``None`` treated as ``.sequence_length``. This
5118
+ value is exclusive, so that a tree whose ``interval.left`` is exactly
5119
+ equivalent to ``right`` will not be included in the iteration.
5120
+ :param bool copy: Return an immutable copy of each tree. This will be
5121
+ inefficient. Default: ``None`` treated as False.
5004
5122
:param bool sample_counts: Deprecated since 0.2.4.
5005
5123
:return: An iterator over the Trees in this tree sequence.
5006
- :rtype: collections.abc.Iterable, :class:`Tree`
5124
+ :rtype: TreeIterator
5007
5125
"""
5008
5126
# tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
5009
5127
# for tracked_samples, sample_counts and sample_lists respectively.
5010
5128
# These are left over from an older version of the API when leaves
5011
5129
# and samples were synonymous.
5130
+ interval = self ._check_genomic_range (left , right )
5012
5131
if tracked_leaves is not None :
5013
5132
tracked_samples = tracked_leaves
5014
5133
if leaf_counts is not None :
5015
5134
sample_counts = leaf_counts
5016
5135
if leaf_lists is not None :
5017
5136
sample_lists = leaf_lists
5137
+ if copy is None :
5138
+ copy = False
5018
5139
tree = Tree (
5019
5140
self ,
5020
5141
tracked_samples = tracked_samples ,
5021
5142
sample_lists = sample_lists ,
5022
5143
root_threshold = root_threshold ,
5023
5144
sample_counts = sample_counts ,
5024
5145
)
5025
- return TreeIterator (tree )
5146
+ return TreeIterator (tree , interval = interval , return_copies = copy )
5026
5147
5027
5148
def coiterate (self , other , ** kwargs ):
5028
5149
"""
@@ -5309,8 +5430,8 @@ def variants(
5309
5430
:param int right: End with the last site before this position. If ``None``
5310
5431
(default) assume ``right`` is the sequence length, so that the last
5311
5432
variant corresponds to the last site in the tree sequence.
5312
- :return: An iterator over all variants in this tree sequence.
5313
- :rtype: iter(:class:`Variant`)
5433
+ :return: An iterator over the specified variants in this tree sequence.
5434
+ :rtype: VariantIterator
5314
5435
"""
5315
5436
interval = self ._check_genomic_range (left , right )
5316
5437
if impute_missing_data is not None :
@@ -5327,26 +5448,13 @@ def variants(
5327
5448
copy = True
5328
5449
# See comments for the Variant type for discussion on why the
5329
5450
# present form was chosen.
5330
- variant = tskit .Variant (
5451
+ variant_object = tskit .Variant (
5331
5452
self ,
5332
5453
samples = samples ,
5333
5454
isolated_as_missing = isolated_as_missing ,
5334
5455
alleles = alleles ,
5335
5456
)
5336
- if left == 0 and right == self .sequence_length :
5337
- start = 0
5338
- stop = self .num_sites
5339
- else :
5340
- start , stop = np .searchsorted (self .sites_position , interval )
5341
-
5342
- if copy :
5343
- for site_id in range (start , stop ):
5344
- variant .decode (site_id )
5345
- yield variant .copy ()
5346
- else :
5347
- for site_id in range (start , stop ):
5348
- variant .decode (site_id )
5349
- yield variant
5457
+ return VariantIterator (variant_object , interval = interval , copy = copy )
5350
5458
5351
5459
def genotype_matrix (
5352
5460
self ,
0 commit comments