Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package ai.timefold.solver.core.impl.bavet.common.index;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.NoSuchElementException;
Expand Down Expand Up @@ -194,15 +196,13 @@ private Iterator<T> createRandomIterator(Object queryCompositeKey, RandomGenerat
}
}
}
default -> {
if (filter == null) {
yield new RandomIterator(queryCompositeKey,
indexer -> indexer.randomIterator(queryCompositeKey, workingRandom));
} else {
yield new RandomIterator(queryCompositeKey,
indexer -> indexer.randomIterator(queryCompositeKey, workingRandom, filter));
}
}
default ->
// Always draw from the unfiltered leaf iterators, weighting each bucket by its full size,
// and apply the filter (if any) during selection. This keeps the bucket weights exact even
// for the filtered path: rejected tuples are removed as they are drawn, so every surviving
// element stays equally likely to be visited next.
new RandomIterator(queryCompositeKey, workingRandom,
indexer -> indexer.randomIterator(queryCompositeKey, workingRandom), filter);
};
}

Expand Down Expand Up @@ -274,18 +274,122 @@ public T next() {
}
}

private final class RandomIterator extends DefaultIterator {
/**
* Iterates the in-range leaf indexers so that the selection is fair across all elements:
* each leaf indexer is drawn with a probability proportional to the number of elements it
* contributes to the query range, so every element has the same chance of being visited next.
* Without this weighting, an element in a small leaf indexer would be over-represented
* relative to an element in a large one.
*/
private final class RandomIterator implements Iterator<T> {

private final RandomGenerator workingRandom;
private final @Nullable Predicate<T> filter;
private final List<Bucket> buckets = new ArrayList<>();
private int remainingTotal = 0;

public RandomIterator(Object queryCompositeKey, Function<Indexer<T>, Iterator<T>> downstreamIteratorFunction) {
super(queryCompositeKey, downstreamIteratorFunction);
private boolean hasNextComputed = false;
private @Nullable T next = null;
private @Nullable Bucket nextBucket = null;
private @Nullable Bucket lastReturnedBucket = null;

private RandomIterator(Object queryCompositeKey, RandomGenerator workingRandom,
Function<Indexer<T>, Iterator<T>> downstreamIteratorFunction, @Nullable Predicate<T> filter) {
this.workingRandom = workingRandom;
this.filter = filter;
var indexKey = keyUnpacker.apply(queryCompositeKey);
for (var entry : comparisonMap.entrySet()) {
if (boundaryReached(entry.getKey(), indexKey)) {
// Boundary reached; the remaining leaf indexers are out of range.
break;
}
var downstreamIndexer = entry.getValue();
var size = downstreamIndexer.size(queryCompositeKey);
if (size <= 0) {
continue;
}
buckets.add(new Bucket(downstreamIteratorFunction.apply(downstreamIndexer), size));
remainingTotal += size;
}
}

@Override
public boolean hasNext() {
if (hasNextComputed) {
return true;
}
while (remainingTotal > 0) {
var bucket = pickBucket();
if (!bucket.iterator.hasNext()) {
// The leaf indexer has no more elements; drop it from the draw.
remainingTotal -= bucket.remaining;
bucket.remaining = 0;
continue;
}
var candidate = bucket.iterator.next();
if (filter == null || filter.test(candidate)) {
next = candidate;
nextBucket = bucket;
hasNextComputed = true;
return true;
}
// Rejected by the filter; remove it so it is never drawn again and the weights stay exact.
bucket.iterator.remove();
bucket.remaining--;
remainingTotal--;
}
next = null;
nextBucket = null;
return false;
}

private Bucket pickBucket() {
var threshold = workingRandom.nextInt(remainingTotal);
var cumulative = 0;
for (var bucket : buckets) {
cumulative += bucket.remaining;
if (threshold < cumulative) {
return bucket;
}
}
throw new IllegalStateException(
"Impossible state: no leaf indexer selected for threshold (%d).".formatted(threshold));
}

@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
var result = next;
lastReturnedBucket = nextBucket;
hasNextComputed = false;
next = null;
nextBucket = null;
return result;
}

@Override
public void remove() {
if (downstreamIterator == null) {
if (lastReturnedBucket == null) {
throw new IllegalStateException("next() must be called before remove().");
}
downstreamIterator.remove();
lastReturnedBucket.iterator.remove();
lastReturnedBucket.remaining--;
remainingTotal--;
lastReturnedBucket = null;
}

}

private final class Bucket {

private final Iterator<T> iterator;
private int remaining;

private Bucket(Iterator<T> iterator, int remaining) {
this.iterator = iterator;
this.remaining = remaining;
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package ai.timefold.solver.core.impl.bavet.common.index;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;

import ai.timefold.solver.core.api.score.stream.Joiners;
import ai.timefold.solver.core.impl.bavet.bi.joiner.DefaultBiJoiner;
import ai.timefold.solver.core.impl.bavet.common.joiner.JoinerType;
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -89,6 +97,155 @@ void putRemoveSize() {
assertThat(forEachToTuples(indexer, 60)).isEmpty();
}

@Test
void randomIteratorIsFairAcrossLeafIndexersOfDifferentSizes() {
var indexer = newRandomAccessLessThanIndexer();
for (var i = 0; i < 100; i++) {
indexer.put(10, newTuple("age10-" + i));
}
var age20 = newTuple("age20");
indexer.put(20, age20);
var age30 = newTuple("age30");
indexer.put(30, age30);
var age40 = newTuple("age40");
indexer.put(40, age40);

var random = new Random(0);
var iterations = 200_000;
var selectionCountMap = new HashMap<UniTuple<String>, Integer>();
for (var i = 0; i < iterations; i++) {
var iterator = indexer.randomIterator(100, random);
assertThat(iterator).hasNext();
selectionCountMap.merge(iterator.next(), 1, Integer::sum);
}

var expectedTinyCount = iterations / 103;
assertThat(selectionCountMap.getOrDefault(age20, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2);
assertThat(selectionCountMap.getOrDefault(age30, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2);
assertThat(selectionCountMap.getOrDefault(age40, 0)).isBetween(expectedTinyCount / 2, expectedTinyCount * 2);
var tinyBucketSelectionCount = selectionCountMap.getOrDefault(age20, 0)
+ selectionCountMap.getOrDefault(age30, 0)
+ selectionCountMap.getOrDefault(age40, 0);
var bigBucketShare = (iterations - tinyBucketSelectionCount) / (double) iterations;
assertThat(bigBucketShare).isBetween(0.93, 0.99);
}

@Test
void randomIteratorSingleBucketDelegatesToLeaf() {
var indexer = newRandomAccessLessThanIndexer();
var age10 = newTuple("age10");
indexer.put(10, age10);
var age10b = newTuple("age10b");
indexer.put(10, age10b);
var age10c = newTuple("age10c");
indexer.put(10, age10c);

var iterator = indexer.randomIterator(100, new Random(0));
var resultList = new ArrayList<UniTuple<String>>();
while (iterator.hasNext()) {
resultList.add(iterator.next());
iterator.remove();
}

assertThat(resultList).containsExactlyInAnyOrder(age10, age10b, age10c);
assertThat(iterator.hasNext()).isFalse();
}

@Test
void randomIteratorSkipsOutOfRangeBuckets() {
var indexer = newRandomAccessLessThanIndexer();
indexer.put(10, newTuple("age10"));
indexer.put(20, newTuple("age20"));

var iterator = indexer.randomIterator(5, new Random(0));

assertThat(iterator.hasNext()).isFalse();
assertThatThrownBy(iterator::next)
.isInstanceOf(NoSuchElementException.class);
}

@Test
void randomIteratorRemoveBeforeNextThrows() {
var indexer = newRandomAccessLessThanIndexer();
indexer.put(10, newTuple("age10"));
indexer.put(20, newTuple("age20"));

var iterator = indexer.randomIterator(100, new Random(0));

assertThatThrownBy(iterator::remove)
.isInstanceOf(IllegalStateException.class);
}

@Test
void randomIteratorWithFilterRespectsPredicateAndIsComplete() {
var indexer = newRandomAccessLessThanIndexer();
var age10 = newTuple("age10");
indexer.put(10, age10);
var age10b = newTuple("age10b");
indexer.put(10, age10b);
var age10c = newTuple("age10c");
indexer.put(10, age10c);
var age10d = newTuple("age10d");
indexer.put(10, age10d);
var age10e = newTuple("age10e");
indexer.put(10, age10e);
var age20 = newTuple("age20");
indexer.put(20, age20);
var age20b = newTuple("age20b");
indexer.put(20, age20b);
var age20c = newTuple("age20c");
indexer.put(20, age20c);
var age20d = newTuple("age20d");
indexer.put(20, age20d);
var age20e = newTuple("age20e");
indexer.put(20, age20e);

var allowedSet = Set.of(age10, age10d, age20b, age20e);
var iterator = indexer.randomIterator(100, new Random(0), allowedSet::contains);
var resultList = new ArrayList<UniTuple<String>>();
while (iterator.hasNext()) {
resultList.add(iterator.next());
iterator.remove();
}

assertThat(resultList).containsExactlyInAnyOrder(age10, age10d, age20b, age20e);
}

@Test
void randomIteratorWithFilterIsFairOverSurvivingElements() {
var indexer = newRandomAccessLessThanIndexer();
// Big bucket: 100 tuples, but only one survives the filter.
var survivor10 = newTuple("survivor10");
indexer.put(10, survivor10);
for (var i = 0; i < 99; i++) {
indexer.put(10, newTuple("reject10-" + i));
}
// Tiny bucket: a single surviving tuple.
var survivor20 = newTuple("survivor20");
indexer.put(20, survivor20);

var allowedSet = Set.of(survivor10, survivor20);
var random = new Random(0);
var iterations = 100_000;
var survivor10Count = 0;
for (var i = 0; i < iterations; i++) {
var iterator = indexer.randomIterator(100, random, allowedSet::contains);
assertThat(iterator).hasNext();
if (iterator.next().equals(survivor10)) {
survivor10Count++;
}
}
// Despite the big bucket holding 100x more (unfiltered) tuples, both survivors are about equally likely.
// Under the old per-bucket weighting survivor10 would be picked roughly 100/101 of the time.
var survivor10Share = survivor10Count / (double) iterations;
assertThat(survivor10Share).isBetween(0.45, 0.55);
}

private static ComparisonIndexer<UniTuple<String>, Integer> newRandomAccessLessThanIndexer() {
return new ComparisonIndexer<>(JoinerType.LESS_THAN, KeyUnpacker.<Integer> single(),
RandomAccessLeafIndexer::new);
}

private static UniTuple<String> newTuple(String factA) {
return UniTuple.of(factA, 0);
}
Expand Down