Skip to content

Commit

Permalink
Support sub agg in filter rewrite optimization
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Feb 24, 2025
1 parent 0714a1b commit 3a0f7cf
Show file tree
Hide file tree
Showing 14 changed files with 1,132 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -712,9 +712,9 @@ static class AllPermissionCheck implements BootstrapCheck {

/**
 * Fails the bootstrap check when the all permission has been granted, because doing so
 * effectively disables security; succeeds otherwise.
 *
 * @param context the bootstrap context (not consulted by this check)
 * @return a failure result if {@code isAllPermissionGranted()} reports true, otherwise success
 */
@Override
public final BootstrapCheckResult check(BootstrapContext context) {
    // Keep this check live: skipping it would let a node start with security effectively disabled.
    if (isAllPermissionGranted()) {
        return BootstrapCheck.BootstrapCheckResult.failure("granting the all permission effectively disables security");
    }
    return BootstrapCheckResult.success();
}

Expand Down
27 changes: 19 additions & 8 deletions server/src/main/java/org/opensearch/common/Rounding.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import org.opensearch.common.LocalTimeOffset.Gap;
import org.opensearch.common.LocalTimeOffset.Overlap;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.round.Roundable;
import org.opensearch.common.round.RoundableFactory;
import org.opensearch.common.time.DateUtils;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -62,6 +60,7 @@
import java.time.temporal.TemporalQueries;
import java.time.zone.ZoneOffsetTransition;
import java.time.zone.ZoneRules;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -455,7 +454,7 @@ protected Prepared maybeUseArray(long minUtcMillis, long maxUtcMillis, int max)
values = ArrayUtil.grow(values, i + 1);
values[i++] = rounded;
}
return new ArrayRounding(RoundableFactory.create(values, i), this);
return new ArrayRounding(values, i, this);
}
}

Expand All @@ -464,17 +463,26 @@ protected Prepared maybeUseArray(long minUtcMillis, long maxUtcMillis, int max)
* pre-calculated round-down points to speed up lookups.
*/
private static class ArrayRounding implements Prepared {
private final Roundable roundable;
private final long[] values;
private final int max;
private final Prepared delegate;

public ArrayRounding(Roundable roundable, Prepared delegate) {
this.roundable = roundable;
private ArrayRounding(long[] values, int max, Prepared delegate) {
this.values = values;
this.max = max;
this.delegate = delegate;
}

@Override
public long round(long utcMillis) {
return roundable.floor(utcMillis);
assert values[0] <= utcMillis : utcMillis + " must be after " + values[0];
int idx = Arrays.binarySearch(values, 0, max, utcMillis);
assert idx != -1 : "The insertion point is before the array! This should have tripped the assertion above.";
assert -1 - idx <= values.length : "This insertion point is after the end of the array.";
if (idx < 0) {
idx = -2 - idx;
}
return values[idx];
}

@Override
Expand Down Expand Up @@ -724,7 +732,10 @@ private class FixedNotToMidnightRounding extends TimeUnitPreparedRounding {

@Override
public long round(long utcMillis) {
return offset.localToUtcInThisOffset(unit.roundFloor(offset.utcToLocalTime(utcMillis)));
long localTime = offset.utcToLocalTime(utcMillis);
long roundedLocalTime = unit.roundFloor(localTime);
return offset.localToUtcInThisOffset(roundedLocalTime);
// return offset.localToUtcInThisOffset(unit.roundFloor(offset.utcToLocalTime(utcMillis)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
import java.util.stream.Collectors;

import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;

/**
* Main aggregator that aggregates docs from multiple aggregations
Expand Down Expand Up @@ -563,14 +562,23 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
}
}

/**
 * Tries to answer the aggregation for this leaf through the filter rewrite optimization.
 * First calls {@code finishLeaf()}, then delegates to
 * {@code filterRewriteOptimizationContext.tryOptimize}, passing {@code incrementBucketDocCount}
 * as the doc-count callback and the result of {@code segmentMatchAll(context, ctx)}.
 *
 * @param ctx the leaf (segment) to precompute for
 * @return whatever {@code tryOptimize} returns for this leaf
 * @throws IOException propagated from the calls made here
 */
@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}
// @Override
// protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
// finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
// return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
// }

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(
ctx,
this::incrementBucketDocCount,
segmentMatchAll(context, ctx),
collectableSubAggregators,
sub
);
if (optimized) throw new CollectionTerminatedException();

finishLeaf();

boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,23 @@

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;

/**
* This interface provides a bridge between an aggregator and the optimization context, allowing
Expand All @@ -35,6 +42,8 @@
*/
public abstract class AggregatorBridge {

static final Logger logger = LogManager.getLogger(Helper.loggerName);

/**
* The field type associated with this aggregator bridge.
*/
Expand Down Expand Up @@ -79,12 +88,46 @@ void setRangesConsumer(Consumer<Ranges> setRanges) {
* @param incrementDocCount a consumer to increment the document count for a range bucket. The first parameter is the key (ordinal) of the bucket, the second is the document count to add
* @param ranges
*/
abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize(
abstract FilterRewriteOptimizationContext.OptimizeResult tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
Ranges ranges,
Supplier<DocIdSetBuilder> disBuilderSupplier
) throws IOException;

/**
 * Shared driver for bridge implementations: locates the first range that overlaps the
 * point tree of {@code values} and, if there is one, hands a range collector to
 * {@code multiRangesTraverse}.
 *
 * @param values             point values whose point tree is traversed
 * @param incrementDocCount  callback invoked with (bucket ordinal, document count)
 * @param ranges             the ranges to match against the point tree
 * @param disBuilderSupplier supplier of doc id set builders, passed through to the collector
 * @param getBucketOrd       maps a range index to its bucket ordinal
 * @param size               passed through to the collector
 * @return a freshly constructed result when no range index is found
 *         ({@code firstRangeIndex} returns a negative value); otherwise the result of the traversal
 * @throws IOException propagated from the point tree access or the traversal
 */
static FilterRewriteOptimizationContext.OptimizeResult getResult(
    PointValues values,
    BiConsumer<Long, Long> incrementDocCount,
    Ranges ranges,
    Supplier<DocIdSetBuilder> disBuilderSupplier,
    Function<Integer, Long> getBucketOrd,
    int size
) throws IOException {
    final PointValues.PointTree pointTree = values.getPointTree();
    final FilterRewriteOptimizationContext.OptimizeResult result = new FilterRewriteOptimizationContext.OptimizeResult();

    final int firstRange = ranges.firstRangeIndex(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
    if (firstRange < 0) {
        logger.debug("No ranges match the query, skip the fast filter optimization");
        return result;
    }

    // Translate a (range index, count) pair into a (bucket ordinal, count) call on the caller's consumer.
    final BiConsumer<Integer, Integer> countByRange = (rangeIndex, count) -> {
        long ordinal = getBucketOrd.apply(rangeIndex);
        incrementDocCount.accept(ordinal, (long) count);
    };

    return multiRangesTraverse(
        pointTree,
        new PointTreeTraversal.RangeCollectorForPointTree(
            ranges,
            countByRange,
            size,
            firstRange,
            disBuilderSupplier,
            getBucketOrd,
            result
        )
    );
}

/**
* Checks whether the top level query matches all documents on the segment
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;

/**
 * A composite iterator over multiple DocIdSetIterators where each document
 * belongs to exactly one bucket within a single segment.
 * <p>
 * The bucket of a document is the index, in the array passed to the constructor, of the
 * iterator that produced it. {@code null} entries in that array are skipped. Documents are
 * returned in increasing doc id order by always picking the smallest doc id among the
 * iterators that are not yet exhausted; if several iterators sit on the same doc id, the one
 * that comes first in the active list is reported.
 */
public class CompositeDocIdSetIterator extends DocIdSetIterator {
    private final DocIdSetIterator[] iterators;

    // Bucket indices of the iterators that are not yet known to be exhausted;
    // only the first numActiveIterators entries are meaningful.
    private final int[] activeIterators;
    private int numActiveIterators;

    private int currentDoc = -1;
    private int currentBucket = -1;

    /**
     * @param iterators one iterator per bucket, indexed by bucket; entries may be {@code null}
     *                  for buckets without documents
     */
    public CompositeDocIdSetIterator(DocIdSetIterator[] iterators) {
        this.iterators = iterators;
        int numBuckets = iterators.length;
        this.activeIterators = new int[numBuckets];
        this.numActiveIterators = 0;

        // Only non-null iterators take part in the iteration.
        for (int i = 0; i < numBuckets; i++) {
            if (iterators[i] != null) {
                activeIterators[numActiveIterators++] = i;
            }
        }
    }

    @Override
    public int docID() {
        return currentDoc;
    }

    /**
     * @return the bucket of the current document, or {@code -1} before iteration starts
     *         and after exhaustion
     */
    public int getCurrentBucket() {
        return currentBucket;
    }

    @Override
    public int nextDoc() throws IOException {
        // Once exhausted, currentDoc + 1 would overflow to Integer.MIN_VALUE; stay exhausted instead.
        if (currentDoc == NO_MORE_DOCS) {
            return NO_MORE_DOCS;
        }
        return advance(currentDoc + 1);
    }

    @Override
    public int advance(int target) throws IOException {
        if (target == NO_MORE_DOCS || numActiveIterators == 0) {
            return exhaust();
        }

        int minDoc = NO_MORE_DOCS;
        int minBucket = -1;
        int remainingActive = 0; // Counter for non-exhausted iterators

        // Only check currently active iterators
        for (int i = 0; i < numActiveIterators; i++) {
            int bucket = activeIterators[i];
            DocIdSetIterator iterator = iterators[bucket];

            // Iterators already at or beyond the target keep their position.
            int doc = iterator.docID();
            if (doc < target) {
                doc = iterator.advance(target);
            }

            if (doc == NO_MORE_DOCS) {
                // Iterator is exhausted, don't include it in active set
                continue;
            }

            // Keep this iterator in the active set; compacting in place preserves relative order.
            activeIterators[remainingActive] = bucket;
            remainingActive++;

            if (doc < minDoc) {
                minDoc = doc;
                minBucket = bucket;
            }
        }

        // Update count of active iterators
        numActiveIterators = remainingActive;

        currentDoc = minDoc;
        currentBucket = minBucket;

        return currentDoc;
    }

    @Override
    public long cost() {
        // Sum over the iterators that are still active, so this shrinks as iterators are exhausted.
        long cost = 0;
        for (int i = 0; i < numActiveIterators; i++) {
            DocIdSetIterator iterator = iterators[activeIterators[i]];
            cost += iterator.cost();
        }
        return cost;
    }

    /**
     * Marks this iterator as exhausted. The active set is cleared as well so that a later
     * call cannot resume from a sub-iterator that was never advanced to the end.
     */
    private int exhaust() {
        numActiveIterators = 0;
        currentDoc = NO_MORE_DOCS;
        currentBucket = -1;
        return NO_MORE_DOCS;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.Rounding;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -22,8 +23,7 @@
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
import java.util.function.Supplier;

/**
* For date histogram aggregation
Expand Down Expand Up @@ -127,27 +127,31 @@ private DateFieldMapper.DateFieldType getFieldType() {
return (DateFieldMapper.DateFieldType) fieldType;
}

/**
 * Returns the number of buckets after which collection may stop early.
 * This implementation returns {@link Integer#MAX_VALUE}.
 *
 * @return the bucket count limit
 */
protected int getSize() {
    final int bucketLimit = Integer.MAX_VALUE;
    return bucketLimit;
}

@Override
final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
final FilterRewriteOptimizationContext.OptimizeResult tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
Ranges ranges,
Supplier<DocIdSetBuilder> disBuilderSupplier
) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {

Function<Integer, Long> getBucketOrd = (activeIndex) -> {
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(bucketOrd, (long) docCount);
return getBucketOrd(bucketOrdProducer().apply(rangeStart));
};

return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
return getResult(values, incrementDocCount, ranges, disBuilderSupplier, getBucketOrd, size);
}

private static long getBucketOrd(long bucketOrd) {
Expand Down
Loading

0 comments on commit 3a0f7cf

Please sign in to comment.