diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexer.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexer.java index 19b92501a37..ea88f52aba5 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexer.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexer.java @@ -1,8 +1,10 @@ package ai.timefold.solver.core.impl.bavet.common.index; +import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.NavigableMap; import java.util.NoSuchElementException; @@ -194,15 +196,13 @@ private Iterator createRandomIterator(Object queryCompositeKey, RandomGenerat } } } - default -> { - if (filter == null) { - yield new RandomIterator(queryCompositeKey, - indexer -> indexer.randomIterator(queryCompositeKey, workingRandom)); - } else { - yield new RandomIterator(queryCompositeKey, - indexer -> indexer.randomIterator(queryCompositeKey, workingRandom, filter)); - } - } + default -> + // Always draw from the unfiltered leaf iterators, weighting each bucket by its full size, + // and apply the filter (if any) during selection. This keeps the bucket weights exact even + // for the filtered path: rejected tuples are removed as they are drawn, so every surviving + // element stays equally likely to be visited next. + new RandomIterator(queryCompositeKey, workingRandom, + indexer -> indexer.randomIterator(queryCompositeKey, workingRandom), filter); }; } @@ -274,18 +274,122 @@ public T next() { } } - private final class RandomIterator extends DefaultIterator { + /** + * Iterates the in-range leaf indexers so that the selection is fair across all elements: + * each leaf indexer is drawn with a probability proportional to the number of elements it + * contributes to the query range, so every element has the same chance of being visited next. + * Without this weighting, an element in a small leaf indexer would be over-represented + * relative to an element in a large one. + */ + private final class RandomIterator implements Iterator { + + private final RandomGenerator workingRandom; + private final @Nullable Predicate filter; + private final List buckets = new ArrayList<>(); + private int remainingTotal = 0; - public RandomIterator(Object queryCompositeKey, Function, Iterator> downstreamIteratorFunction) { - super(queryCompositeKey, downstreamIteratorFunction); + private boolean hasNextComputed = false; + private @Nullable T next = null; + private @Nullable Bucket nextBucket = null; + private @Nullable Bucket lastReturnedBucket = null; + + private RandomIterator(Object queryCompositeKey, RandomGenerator workingRandom, + Function, Iterator> downstreamIteratorFunction, @Nullable Predicate filter) { + this.workingRandom = workingRandom; + this.filter = filter; + var indexKey = keyUnpacker.apply(queryCompositeKey); + for (var entry : comparisonMap.entrySet()) { + if (boundaryReached(entry.getKey(), indexKey)) { + // Boundary reached; the remaining leaf indexers are out of range. + break; + } + var downstreamIndexer = entry.getValue(); + var size = downstreamIndexer.size(queryCompositeKey); + if (size <= 0) { + continue; + } + buckets.add(new Bucket(downstreamIteratorFunction.apply(downstreamIndexer), size)); + remainingTotal += size; + } + } + + @Override + public boolean hasNext() { + if (hasNextComputed) { + return true; + } + while (remainingTotal > 0) { + var bucket = pickBucket(); + if (!bucket.iterator.hasNext()) { + // The leaf indexer has no more elements; drop it from the draw. + remainingTotal -= bucket.remaining; + bucket.remaining = 0; + continue; + } + var candidate = bucket.iterator.next(); + if (filter == null || filter.test(candidate)) { + next = candidate; + nextBucket = bucket; + hasNextComputed = true; + return true; + } + // Rejected by the filter; remove it so it is never drawn again and the weights stay exact. + bucket.iterator.remove(); + bucket.remaining--; + remainingTotal--; + } + next = null; + nextBucket = null; + return false; + } + + private Bucket pickBucket() { + var threshold = workingRandom.nextInt(remainingTotal); + var cumulative = 0; + for (var bucket : buckets) { + cumulative += bucket.remaining; + if (threshold < cumulative) { + return bucket; + } + } + throw new IllegalStateException( + "Impossible state: no leaf indexer selected for threshold (%d).".formatted(threshold)); + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + var result = next; + lastReturnedBucket = nextBucket; + hasNextComputed = false; + next = null; + nextBucket = null; + return result; } @Override public void remove() { - if (downstreamIterator == null) { + if (lastReturnedBucket == null) { throw new IllegalStateException("next() must be called before remove()."); } - downstreamIterator.remove(); + lastReturnedBucket.iterator.remove(); + lastReturnedBucket.remaining--; + remainingTotal--; + lastReturnedBucket = null; + } + + } + + private final class Bucket { + + private final Iterator iterator; + private int remaining; + + private Bucket(Iterator iterator, int remaining) { + this.iterator = iterator; + this.remaining = remaining; } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexerTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexerTest.java index a6cb163d412..40daa539a30 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexerTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/index/ComparisonIndexerTest.java @@ -1,9 +1,17 @@ package ai.timefold.solver.core.impl.bavet.common.index; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.Set; import ai.timefold.solver.core.api.score.stream.Joiners; import ai.timefold.solver.core.impl.bavet.bi.joiner.DefaultBiJoiner; +import ai.timefold.solver.core.impl.bavet.common.joiner.JoinerType; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; import org.junit.jupiter.api.Test; @@ -89,6 +97,155 @@ void putRemoveSize() { assertThat(forEachToTuples(indexer, 60)).isEmpty(); } + @Test + void randomIteratorIsFairAcrossLeafIndexersOfDifferentSizes() { + var indexer = newRandomAccessLessThanIndexer(); + for (var i = 0; i < 100; i++) { + indexer.put(10, newTuple("age10-" + i)); + } + var age20 = newTuple("age20"); + indexer.put(20, age20); + var age30 = newTuple("age30"); + indexer.put(30, age30); + var age40 = newTuple("age40"); + indexer.put(40, age40); + + var random = new Random(0); + var iterations = 200_000; + var selectionCountMap = new HashMap, Integer>(); + for (var i = 0; i < iterations; i++) { + var iterator = indexer.randomIterator(100, random); + assertThat(iterator).hasNext(); + selectionCountMap.merge(iterator.next(), 1, Integer::sum); + } + + var expectedTinyCount = iterations / 103; + assertThat(selectionCountMap.getOrDefault(age20, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2); + assertThat(selectionCountMap.getOrDefault(age30, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2); + assertThat(selectionCountMap.getOrDefault(age40, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2); + var tinyBucketSelectionCount = selectionCountMap.getOrDefault(age20, 0) + + selectionCountMap.getOrDefault(age30, 0) + + selectionCountMap.getOrDefault(age40, 0); + var bigBucketShare = (iterations - tinyBucketSelectionCount) / (double) iterations; + assertThat(bigBucketShare).isBetween(0.93, 0.99); + } + + @Test + void randomIteratorSingleBucketDelegatesToLeaf() { + var indexer = newRandomAccessLessThanIndexer(); + var age10 = newTuple("age10"); + indexer.put(10, age10); + var age10b = newTuple("age10b"); + indexer.put(10, age10b); + var age10c = newTuple("age10c"); + indexer.put(10, age10c); + + var iterator = indexer.randomIterator(100, new Random(0)); + var resultList = new ArrayList>(); + while (iterator.hasNext()) { + resultList.add(iterator.next()); + iterator.remove(); + } + + assertThat(resultList).containsExactlyInAnyOrder(age10, age10b, age10c); + assertThat(iterator.hasNext()).isFalse(); + } + + @Test + void randomIteratorSkipsOutOfRangeBuckets() { + var indexer = newRandomAccessLessThanIndexer(); + indexer.put(10, newTuple("age10")); + indexer.put(20, newTuple("age20")); + + var iterator = indexer.randomIterator(5, new Random(0)); + + assertThat(iterator.hasNext()).isFalse(); + assertThatThrownBy(iterator::next) + .isInstanceOf(NoSuchElementException.class); + } + + @Test + void randomIteratorRemoveBeforeNextThrows() { + var indexer = newRandomAccessLessThanIndexer(); + indexer.put(10, newTuple("age10")); + indexer.put(20, newTuple("age20")); + + var iterator = indexer.randomIterator(100, new Random(0)); + + assertThatThrownBy(iterator::remove) + .isInstanceOf(IllegalStateException.class); + } + + @Test + void randomIteratorWithFilterRespectsPredicateAndIsComplete() { + var indexer = newRandomAccessLessThanIndexer(); + var age10 = newTuple("age10"); + indexer.put(10, age10); + var age10b = newTuple("age10b"); + indexer.put(10, age10b); + var age10c = newTuple("age10c"); + indexer.put(10, age10c); + var age10d = newTuple("age10d"); + indexer.put(10, age10d); + var age10e = newTuple("age10e"); + indexer.put(10, age10e); + var age20 = newTuple("age20"); + indexer.put(20, age20); + var age20b = newTuple("age20b"); + indexer.put(20, age20b); + var age20c = newTuple("age20c"); + indexer.put(20, age20c); + var age20d = newTuple("age20d"); + indexer.put(20, age20d); + var age20e = newTuple("age20e"); + indexer.put(20, age20e); + + var allowedSet = Set.of(age10, age10d, age20b, age20e); + var iterator = indexer.randomIterator(100, new Random(0), allowedSet::contains); + var resultList = new ArrayList>(); + while (iterator.hasNext()) { + resultList.add(iterator.next()); + iterator.remove(); + } + + assertThat(resultList).containsExactlyInAnyOrder(age10, age10d, age20b, age20e); + } + + @Test + void randomIteratorWithFilterIsFairOverSurvivingElements() { + var indexer = newRandomAccessLessThanIndexer(); + // Big bucket: 100 tuples, but only one survives the filter. + var survivor10 = newTuple("survivor10"); + indexer.put(10, survivor10); + for (var i = 0; i < 99; i++) { + indexer.put(10, newTuple("reject10-" + i)); + } + // Tiny bucket: a single surviving tuple. + var survivor20 = newTuple("survivor20"); + indexer.put(20, survivor20); + + var allowedSet = Set.of(survivor10, survivor20); + var random = new Random(0); + var iterations = 100_000; + var survivor10Count = 0; + for (var i = 0; i < iterations; i++) { + var iterator = indexer.randomIterator(100, random, allowedSet::contains); + assertThat(iterator).hasNext(); + if (iterator.next().equals(survivor10)) { + survivor10Count++; + } + } + // Despite the big bucket holding 100x more (unfiltered) tuples, both survivors are about equally likely. + // Under the old per-bucket weighting survivor10 would be picked roughly 100/101 of the time. + var survivor10Share = survivor10Count / (double) iterations; + assertThat(survivor10Share).isBetween(0.45, 0.55); + } + + private static ComparisonIndexer, Integer> newRandomAccessLessThanIndexer() { + return new ComparisonIndexer<>(JoinerType.LESS_THAN, KeyUnpacker. single(), + RandomAccessLeafIndexer::new); + } + private static UniTuple newTuple(String factA) { return UniTuple.of(factA, 0); }