@@ -3863,15 +3863,26 @@ def load_text(
38633863 return tc .tree_sequence ()
38643864
38653865
3866- class TreeIterator :
3867- """
3868- Simple class providing forward and backward iteration over a tree sequence.
3869- """
3870-
3871- def __init__ (self , tree ):
3872- self .tree = tree
3873- self .more_trees = True
3866+ class ObjectIterator :
3867+ # Simple class providing forward and backward iteration over a
3868+ # mutable object with ``next()`` and ``prev()`` methods, e.g.
3869+ # a Tree or a Variant. ``interval`` allows the bounds of the
3870+ # iterator to be specified, and should already have
3871+ # been checked using _check_genomic_range(left, right)
3872+ # If ``return_copies`` is True, the iterator will return
3873+ # immutable copies of each object (this is likely to be significantly
3874+ # less efficient).
3875+ # It can be useful to define __len__ on one of these iterators,
3876+ # which e.g. allows progress bars to provide useful feedback.
3877+
3878+ def __init__ (self , obj , interval , return_copies = False ):
3879+ self ._obj = obj
3880+ self .min_pos = interval [0 ]
3881+ self .max_pos = interval [1 ]
3882+ self .return_copies = return_copies
38743883 self .forward = True
3884+ self .started = False
3885+ self .finished = False
38753886
38763887 def __iter__ (self ):
38773888 return self
@@ -3880,17 +3891,114 @@ def __reversed__(self):
38803891 self .forward = False
38813892 return self
38823893
3894+ def obj_left (self ):
3895+ # Used to work out where to stop iterating when going backwards.
3896+ # Override with code to return the left coordinate of self.obj
3897+ raise NotImplementedError ()
3898+
3899+ def obj_right (self ):
3900+ # Used to work out when to stop iterating when going forwards.
3901+ # Override with code to return the right coordinate of self.obj
3902+ raise NotImplementedError ()
3903+
3904+ def seek_to_start (self ):
3905+ # Override to set the object position to self.min_pos
3906+ raise NotImplementedError ()
3907+
3908+ def seek_to_end (self ):
3909+ # Override to set the object position just before self.max_pos
3910+ raise NotImplementedError ()
3911+
38833912 def __next__ (self ):
3884- if self .forward :
3885- self .more_trees = self .more_trees and self .tree .next ()
3886- else :
3887- self .more_trees = self .more_trees and self .tree .prev ()
3888- if not self .more_trees :
3913+ if not self .finished :
3914+ if not self .started :
3915+ if self .forward :
3916+ self .seek_to_start ()
3917+ else :
3918+ self .seek_to_end ()
3919+ self .started = True
3920+ else :
3921+ if self .forward :
3922+ if not self ._obj .next () or self .obj_left () >= self .max_pos :
3923+ print ("fwd" , self .obj_left (), self .min_pos )
3924+ self .finished = True
3925+ else :
3926+ if not self ._obj .prev () or self .obj_right () < self .min_pos :
3927+ self .finished = True
3928+ if self .finished :
38893929 raise StopIteration ()
3890- return self .tree
3930+ return self ._obj .copy () if self .return_copies else self ._obj
3931+
3932+
3933+ class TreeIterator (ObjectIterator ):
3934+ """
3935+ An iterator over some or all of the :class:`trees<Tree>`
3936+ in a :class:`TreeSequence`.
3937+ """
3938+
3939+ def obj_left (self ):
3940+ return self ._obj .interval .left
3941+
3942+ def obj_right (self ):
3943+ return self ._obj .interval .right
3944+
3945+ def seek_to_start (self ):
3946+ self ._obj .seek (self .min_pos )
3947+
3948+ def seek_to_end (self ):
3949+ self ._obj .seek (np .nextafter (self .max_pos , - np .inf ))
38913950
38923951 def __len__ (self ):
3893- return self .tree .tree_sequence .num_trees
3952+ """
3953+ The number of trees over which a newly created iterator will iterate.
3954+ """
3955+ ts = self ._obj .tree_sequence
3956+ if self .min_pos == 0 and self .max_pos == ts .sequence_length :
3957+ # a common case: don't incur the cost of searching through the breakpoints
3958+ return ts .num_trees
3959+ breaks = ts .breakpoints (as_array = True )
3960+ left_index = breaks .searchsorted (self .min_pos , side = "right" )
3961+ right_index = breaks .searchsorted (self .max_pos , side = "left" )
3962+ return right_index - left_index + 1
3963+
3964+
3965+ class VariantIterator (ObjectIterator ):
3966+ """
3967+ An iterator over some or all of the :class:`variants<Variant>`
3968+ in a :class:`TreeSequence`.
3969+ """
3970+
3971+ def __init__ (self , variant , interval , copy ):
3972+ super ().__init__ (variant , interval , copy )
3973+ if interval [0 ] == 0 and interval [1 ] == variant .tree_sequence .sequence_length :
3974+ # a common case: don't incur the cost of searching through the positions
3975+ self .min_max_sites = [0 , variant .tree_sequence .num_sites ]
3976+ else :
3977+ self .min_max_sites = variant .tree_sequence .sites_position .searchsorted (
3978+ interval
3979+ )
3980+ if self .min_max_sites [0 ] >= self .min_max_sites [1 ]:
3981+ # upper bound is exclusive: we don't include the site at self.bound[1]
3982+ self .finished = True
3983+
3984+ def obj_left (self ):
3985+ return self ._obj .site .position
3986+
3987+ def obj_right (self ):
3988+ return self ._obj .site .position
3989+
3990+ def seek_to_start (self ):
3991+ self ._obj .decode (self .min_max_sites [0 ])
3992+
3993+ def seek_to_end (self ):
3994+ self ._obj .decode (self .min_max_sites [1 ] - 1 )
3995+
3996+ def __len__ (self ):
3997+ """
3998+ The number of variants (i.e. sites) over which a newly created iterator will
3999+ iterate.
4000+ """
4001+ return self .min_max_sites [1 ] - self .min_max_sites [0 ]
38944002
38954003
38964004class SimpleContainerSequence :
@@ -4077,7 +4185,7 @@ def aslist(self, **kwargs):
40774185 :return: A list of the trees in this tree sequence.
40784186 :rtype: list
40794187 """
4080- return [tree . copy () for tree in self .trees (** kwargs )]
4188+ return [tree for tree in self .trees (copy = True , ** kwargs )]
40814189
40824190 @classmethod
40834191 def load (cls , file_or_path , * , skip_tables = False , skip_reference_sequence = False ):
@@ -4970,6 +5078,9 @@ def trees(
49705078 sample_lists = False ,
49715079 root_threshold = 1 ,
49725080 sample_counts = None ,
5081+ left = None ,
5082+ right = None ,
5083+ copy = None ,
49735084 tracked_leaves = None ,
49745085 leaf_counts = None ,
49755086 leaf_lists = None ,
@@ -5001,28 +5112,39 @@ def trees(
50015112 are roots. To efficiently restrict the roots of the tree to
50025113 those subtending meaningful topology, set this to 2. This value
50035114 is only relevant when trees have multiple roots.
5115+ :param float left: The left-most coordinate of the region over which
5116+ to iterate. Default: ``None`` treated as 0.
5117+ :param float right: The right-most coordinate of the region over which
5118+ to iterate. Default: ``None`` treated as ``.sequence_length``. This
5119+ value is exclusive, so that a tree whose ``interval.left`` is exactly
5120+ equivalent to ``right`` will not be included in the iteration.
5121+ :param bool copy: Return a immutable copy of each tree. This will be
5122+ inefficient. Default: ``None`` treated as False.
50045123 :param bool sample_counts: Deprecated since 0.2.4.
50055124 :return: An iterator over the Trees in this tree sequence.
5006- :rtype: collections.abc.Iterable, :class:`Tree`
5125+ :rtype: TreeIterator
50075126 """
50085127 # tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
50095128 # for tracked_samples, sample_counts and sample_lists respectively.
50105129 # These are left over from an older version of the API when leaves
50115130 # and samples were synonymous.
5131+ interval = self ._check_genomic_range (left , right )
50125132 if tracked_leaves is not None :
50135133 tracked_samples = tracked_leaves
50145134 if leaf_counts is not None :
50155135 sample_counts = leaf_counts
50165136 if leaf_lists is not None :
50175137 sample_lists = leaf_lists
5138+ if copy is None :
5139+ copy = False
50185140 tree = Tree (
50195141 self ,
50205142 tracked_samples = tracked_samples ,
50215143 sample_lists = sample_lists ,
50225144 root_threshold = root_threshold ,
50235145 sample_counts = sample_counts ,
50245146 )
5025- return TreeIterator (tree )
5147+ return TreeIterator (tree , interval = interval , return_copies = copy )
50265148
50275149 def coiterate (self , other , ** kwargs ):
50285150 """
@@ -5309,8 +5431,8 @@ def variants(
53095431 :param int right: End with the last site before this position. If ``None``
53105432 (default) assume ``right`` is the sequence length, so that the last
53115433 variant corresponds to the last site in the tree sequence.
5312- :return: An iterator over all variants in this tree sequence.
5313- :rtype: iter(:class:`Variant`)
5434+ :return: An iterator over the specified variants in this tree sequence.
5435+ :rtype: VariantIterator
53145436 """
53155437 interval = self ._check_genomic_range (left , right )
53165438 if impute_missing_data is not None :
@@ -5327,26 +5449,13 @@ def variants(
53275449 copy = True
53285450 # See comments for the Variant type for discussion on why the
53295451 # present form was chosen.
5330- variant = tskit .Variant (
5452+ variant_object = tskit .Variant (
53315453 self ,
53325454 samples = samples ,
53335455 isolated_as_missing = isolated_as_missing ,
53345456 alleles = alleles ,
53355457 )
5336- if left == 0 and right == self .sequence_length :
5337- start = 0
5338- stop = self .num_sites
5339- else :
5340- start , stop = np .searchsorted (self .sites_position , interval )
5341-
5342- if copy :
5343- for site_id in range (start , stop ):
5344- variant .decode (site_id )
5345- yield variant .copy ()
5346- else :
5347- for site_id in range (start , stop ):
5348- variant .decode (site_id )
5349- yield variant
5458+ return VariantIterator (variant_object , interval = interval , copy = copy )
53505459
53515460 def genotype_matrix (
53525461 self ,
0 commit comments