diff --git a/tsinfer/inference.py b/tsinfer/inference.py index d053053c..3ba67b94 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1644,6 +1644,19 @@ def group_by_linesweep(self): epoch_end = np.hstack([breaks + 1, [self.num_ancestors]]) time_slices = np.vstack([epoch_start, epoch_end]).T epoch_sizes = time_slices[:, 1] - time_slices[:, 0] + # Find the epoch where the sum of ancestors has reached 1M as a cutoff + if np.sum(epoch_sizes) <= 1e6: + over_1M_epoch = len(time_slices) + over_1M_epoch_first_ancestor = self.num_ancestors + else: + over_1M_epoch = np.where(np.cumsum(epoch_sizes) > 1e6)[0][0] + over_1M_epoch_first_ancestor = time_slices[over_1M_epoch, 0] + logger.info( + f"1M ancestors reached at {over_1M_epoch} epoch and ancestor " + f"{over_1M_epoch_first_ancestor}" + ) + + # Find the first epoch with more than a 500 times the median epoch size median_size = np.median(epoch_sizes) cutoff = 500 * median_size # Zero out the first half so that an initial large epoch doesn't @@ -1653,13 +1666,17 @@ def group_by_linesweep(self): # the median epoch size. For a large set of human genomes the median epoch # size is around 10, so we'll stop grouping by linesweep at 5000. if np.max(epoch_sizes) <= cutoff: - large_epoch = len(time_slices) - large_epoch_first_ancestor = self.num_ancestors + large_epoch = over_1M_epoch + large_epoch_first_ancestor = over_1M_epoch_first_ancestor + logger.info("No large epochs found, using count cutoff") else: large_epoch = np.where(epoch_sizes > cutoff)[0][0] large_epoch_first_ancestor = time_slices[large_epoch, 0] + logger.info( + f"Large epoch found at {large_epoch} with {epoch_sizes[large_epoch]} " + f"ancestors and ancestor {large_epoch_first_ancestor}" + ) logger.info(f"{len(time_slices)} epochs with {median_size} median size.") - logger.info(f"First large (>{cutoff}) epoch is {large_epoch}") logger.info(f"Grouping {large_epoch_first_ancestor} ancestors by linesweep") ancestor_grouping = ancestors.group_ancestors_by_linesweep( start[:large_epoch_first_ancestor], @@ -1672,6 +1689,11 @@ def group_by_linesweep(self): ancestor_grouping[next_epoch] = np.arange(*time_slices[epoch]) next_epoch += 1 + # Assert that every ancestor appears once in ancestor grouping + assert ( + len(set(np.hstack(list(ancestor_grouping.values())))) == self.num_ancestors + ) + # Remove the "virtual root" ancestor try: assert 0 in ancestor_grouping[0]