@@ -3863,15 +3863,26 @@ def load_text(
3863
3863
return tc .tree_sequence ()
3864
3864
3865
3865
3866
- class TreeIterator :
3867
- """
3868
- Simple class providing forward and backward iteration over a tree sequence.
3869
- """
3870
-
3871
- def __init__ (self , tree ):
3872
- self .tree = tree
3873
- self .more_trees = True
3866
+ class ObjectIterator :
3867
+ # Simple class providing forward and backward iteration over a
3868
+ # mutable object with ``next()`` and ``prev()`` methods, e.g.
3869
+ # a Tree or a Variant. ``interval`` allows the bounds of the
3870
+ # iterator to be specified, and should already have
3871
+ # been checked using _check_genomic_range(left, right)
3872
+ # If ``return_copies`` is True, the iterator will return
3873
+ # immutable copies of each object (this is likely to be significantly
3874
+ # less efficient).
3875
+ # It can be useful to define __len__ on one of these iterators,
3876
+ # which e.g. allows progress bars to provide useful feedback.
3877
+
3878
+ def __init__ (self , obj , interval , return_copies = False ):
3879
+ self ._obj = obj
3880
+ self .min_pos = interval [0 ]
3881
+ self .max_pos = interval [1 ]
3882
+ self .return_copies = return_copies
3874
3883
self .forward = True
3884
+ self .started = False
3885
+ self .finished = False
3875
3886
3876
3887
def __iter__ (self ):
3877
3888
return self
@@ -3880,17 +3891,114 @@ def __reversed__(self):
3880
3891
self .forward = False
3881
3892
return self
3882
3893
3894
+ def obj_left (self ):
3895
+ # Used to work out where to stop iterating when going backwards.
3896
+ # Override with code to return the left coordinate of self.obj
3897
+ raise NotImplementedError ()
3898
+
3899
+ def obj_right (self ):
3900
+ # Used to work out when to stop iterating when going forwards.
3901
+ # Override with code to return the right coordinate of self.obj
3902
+ raise NotImplementedError ()
3903
+
3904
+ def seek_to_start (self ):
3905
+ # Override to set the object position to self.min_pos
3906
+ raise NotImplementedError ()
3907
+
3908
+ def seek_to_end (self ):
3909
+ # Override to set the object position just before self.max_pos
3910
+ raise NotImplementedError ()
3911
+
3883
3912
def __next__ (self ):
3884
- if self .forward :
3885
- self .more_trees = self .more_trees and self .tree .next ()
3886
- else :
3887
- self .more_trees = self .more_trees and self .tree .prev ()
3888
- if not self .more_trees :
3913
+ if not self .finished :
3914
+ if not self .started :
3915
+ if self .forward :
3916
+ self .seek_to_start ()
3917
+ else :
3918
+ self .seek_to_end ()
3919
+ self .started = True
3920
+ else :
3921
+ if self .forward :
3922
+ if not self ._obj .next () or self .obj_left () >= self .max_pos :
3923
+ print ("fwd" , self .obj_left (), self .min_pos )
3924
+ self .finished = True
3925
+ else :
3926
+ if not self ._obj .prev () or self .obj_right () < self .min_pos :
3927
+ self .finished = True
3928
+ if self .finished :
3889
3929
raise StopIteration ()
3890
- return self .tree
3930
+ return self ._obj .copy () if self .return_copies else self ._obj
3931
+
3932
+
3933
class TreeIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`trees<Tree>`
    in a :class:`TreeSequence`.
    """

    def obj_left(self):
        """
        Return the left coordinate of the current tree's interval.
        """
        return self._obj.interval.left

    def obj_right(self):
        """
        Return the largest coordinate covered by the current tree.

        The value returned is the float immediately below
        ``interval.right``, not ``interval.right`` itself. The ``right``
        documentation for ``trees()`` excludes a tree whose
        ``interval.left`` equals ``right``, so the coordinate
        ``interval.right`` belongs to the following tree. The base class
        stops backward iteration when ``obj_right() < min_pos``; returning
        ``interval.right`` unchanged would therefore wrongly yield the
        tree that ends exactly at ``min_pos``.
        """
        # Same trick as seek_to_end(): step one float below the exclusive
        # bound to get the last position that is actually covered.
        return np.nextafter(self._obj.interval.right, -np.inf)

    def seek_to_start(self):
        """
        Position the tree on the one covering ``self.min_pos``.
        """
        self._obj.seek(self.min_pos)

    def seek_to_end(self):
        """
        Position the tree on the one covering the coordinate immediately
        below ``self.max_pos`` (the upper bound is exclusive).
        """
        self._obj.seek(np.nextafter(self.max_pos, -np.inf))

    def __len__(self):
        """
        The number of trees over which a newly created iterator will iterate.
        """
        ts = self._obj.tree_sequence
        if self.min_pos == 0 and self.max_pos == ts.sequence_length:
            # a common case: don't incur the cost of searching through the breakpoints
            return ts.num_trees
        breaks = ts.breakpoints(as_array=True)
        # Number of breakpoints <= min_pos: one more than the index of the
        # tree that contains min_pos.
        left_index = breaks.searchsorted(self.min_pos, side="right")
        # Number of breakpoints < max_pos: one more than the index of the
        # last tree that starts before the exclusive upper bound.
        right_index = breaks.searchsorted(self.max_pos, side="left")
        # Convert from a numpy integer so callers get a plain int.
        return int(right_index - left_index + 1)
3963
+
3964
+
3965
class VariantIterator(ObjectIterator):
    """
    An iterator over some or all of the :class:`variants<Variant>`
    in a :class:`TreeSequence`.
    """

    def __init__(self, variant, interval, copy):
        """
        :param variant: The variant object to decode at each site in turn.
        :param interval: The ``(left, right)`` positions bounding the
            iteration; ``right`` is exclusive.
        :param bool copy: If True, return a copy of the variant at each
            step rather than the same, repeatedly updated, object.
        """
        super().__init__(variant, interval, copy)
        ts = variant.tree_sequence
        if interval[0] == 0 and interval[1] == ts.sequence_length:
            # a common case: don't incur the cost of searching through the positions
            first, stop = 0, ts.num_sites
        else:
            first, stop = ts.sites_position.searchsorted(interval)
        # Store plain ints in a list on both paths, so that the attribute
        # has one type and __len__ returns a Python int.
        self.min_max_sites = [int(first), int(stop)]
        if first >= stop:
            # upper bound is exclusive: the site with ID
            # self.min_max_sites[1] is not included, so the range is empty.
            self.finished = True

    def obj_left(self):
        """
        Return the position of the current site.
        """
        return self._obj.site.position

    def obj_right(self):
        """
        Return the position of the current site. A site is a single
        point, so this is the same value as ``obj_left()``.
        """
        return self._obj.site.position

    def seek_to_start(self):
        """
        Decode the variant at the first site in the range.
        """
        self._obj.decode(self.min_max_sites[0])

    def seek_to_end(self):
        """
        Decode the variant at the last site in the range; the stored
        upper site ID is exclusive, hence the ``- 1``.
        """
        self._obj.decode(self.min_max_sites[1] - 1)

    def __len__(self):
        """
        The number of variants (i.e. sites) over which a newly created iterator will
        iterate.
        """
        return self.min_max_sites[1] - self.min_max_sites[0]
3894
4002
3895
4003
3896
4004
class SimpleContainerSequence :
@@ -4077,7 +4185,7 @@ def aslist(self, **kwargs):
4077
4185
:return: A list of the trees in this tree sequence.
4078
4186
:rtype: list
4079
4187
"""
4080
- return [tree . copy () for tree in self .trees (** kwargs )]
4188
+ return [tree for tree in self .trees (copy = True , ** kwargs )]
4081
4189
4082
4190
@classmethod
4083
4191
def load (cls , file_or_path , * , skip_tables = False , skip_reference_sequence = False ):
@@ -4970,6 +5078,9 @@ def trees(
4970
5078
sample_lists = False ,
4971
5079
root_threshold = 1 ,
4972
5080
sample_counts = None ,
5081
+ left = None ,
5082
+ right = None ,
5083
+ copy = None ,
4973
5084
tracked_leaves = None ,
4974
5085
leaf_counts = None ,
4975
5086
leaf_lists = None ,
@@ -5001,28 +5112,39 @@ def trees(
5001
5112
are roots. To efficiently restrict the roots of the tree to
5002
5113
those subtending meaningful topology, set this to 2. This value
5003
5114
is only relevant when trees have multiple roots.
5115
+ :param float left: The left-most coordinate of the region over which
5116
+ to iterate. Default: ``None`` treated as 0.
5117
+ :param float right: The right-most coordinate of the region over which
5118
+ to iterate. Default: ``None`` treated as ``.sequence_length``. This
5119
+ value is exclusive, so that a tree whose ``interval.left`` is exactly
5120
+ equivalent to ``right`` will not be included in the iteration.
5121
+ :param bool copy: Return a immutable copy of each tree. This will be
5122
+ inefficient. Default: ``None`` treated as False.
5004
5123
:param bool sample_counts: Deprecated since 0.2.4.
5005
5124
:return: An iterator over the Trees in this tree sequence.
5006
- :rtype: collections.abc.Iterable, :class:`Tree`
5125
+ :rtype: TreeIterator
5007
5126
"""
5008
5127
# tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
5009
5128
# for tracked_samples, sample_counts and sample_lists respectively.
5010
5129
# These are left over from an older version of the API when leaves
5011
5130
# and samples were synonymous.
5131
+ interval = self ._check_genomic_range (left , right )
5012
5132
if tracked_leaves is not None :
5013
5133
tracked_samples = tracked_leaves
5014
5134
if leaf_counts is not None :
5015
5135
sample_counts = leaf_counts
5016
5136
if leaf_lists is not None :
5017
5137
sample_lists = leaf_lists
5138
+ if copy is None :
5139
+ copy = False
5018
5140
tree = Tree (
5019
5141
self ,
5020
5142
tracked_samples = tracked_samples ,
5021
5143
sample_lists = sample_lists ,
5022
5144
root_threshold = root_threshold ,
5023
5145
sample_counts = sample_counts ,
5024
5146
)
5025
- return TreeIterator (tree )
5147
+ return TreeIterator (tree , interval = interval , return_copies = copy )
5026
5148
5027
5149
def coiterate (self , other , ** kwargs ):
5028
5150
"""
@@ -5309,8 +5431,8 @@ def variants(
5309
5431
:param int right: End with the last site before this position. If ``None``
5310
5432
(default) assume ``right`` is the sequence length, so that the last
5311
5433
variant corresponds to the last site in the tree sequence.
5312
- :return: An iterator over all variants in this tree sequence.
5313
- :rtype: iter(:class:`Variant`)
5434
+ :return: An iterator over the specified variants in this tree sequence.
5435
+ :rtype: VariantIterator
5314
5436
"""
5315
5437
interval = self ._check_genomic_range (left , right )
5316
5438
if impute_missing_data is not None :
@@ -5327,26 +5449,13 @@ def variants(
5327
5449
copy = True
5328
5450
# See comments for the Variant type for discussion on why the
5329
5451
# present form was chosen.
5330
- variant = tskit .Variant (
5452
+ variant_object = tskit .Variant (
5331
5453
self ,
5332
5454
samples = samples ,
5333
5455
isolated_as_missing = isolated_as_missing ,
5334
5456
alleles = alleles ,
5335
5457
)
5336
- if left == 0 and right == self .sequence_length :
5337
- start = 0
5338
- stop = self .num_sites
5339
- else :
5340
- start , stop = np .searchsorted (self .sites_position , interval )
5341
-
5342
- if copy :
5343
- for site_id in range (start , stop ):
5344
- variant .decode (site_id )
5345
- yield variant .copy ()
5346
- else :
5347
- for site_id in range (start , stop ):
5348
- variant .decode (site_id )
5349
- yield variant
5458
+ return VariantIterator (variant_object , interval = interval , copy = copy )
5350
5459
5351
5460
def genotype_matrix (
5352
5461
self ,
0 commit comments