Support sub agg in filter rewrite optimization #17447

Draft · wants to merge 3 commits into base: main
@@ -94,7 +94,7 @@
import java.util.stream.Collectors;

import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;
import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll;

/**
* Main aggregator that aggregates docs from multiple aggregations
@@ -563,14 +563,23 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
}
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}
// @Override
// protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
// finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
// return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
// }
Comment on lines +566 to +570
Member

I assume this will be removed when this is ready for review. :)


@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(
ctx,
this::incrementBucketDocCount,
segmentMatchAll(context, ctx),
collectableSubAggregators,
sub
);
if (optimized) throw new CollectionTerminatedException();
Member

why create a separate variable here instead of just calling if (filterRewriteOptimizationContext.tryOptimize(...)) { throw... } ?
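
For illustration, a minimal sketch of the inlined form this comment suggests (hypothetical, not part of this diff):

```java
// Hypothetical rewrite of the lines above: drop the local variable and
// terminate collection directly when the optimization applies.
if (filterRewriteOptimizationContext.tryOptimize(
    ctx,
    this::incrementBucketDocCount,
    segmentMatchAll(context, ctx),
    collectableSubAggregators,
    sub
)) {
    throw new CollectionTerminatedException();
}
```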


finishLeaf();

boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;
@@ -8,16 +8,23 @@

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;

/**
* This interface provides a bridge between an aggregator and the optimization context, allowing
@@ -35,6 +42,8 @@
*/
public abstract class AggregatorBridge {

static final Logger logger = LogManager.getLogger(Helper.loggerName);

/**
* The field type associated with this aggregator bridge.
*/
@@ -79,12 +88,46 @@ void setRangesConsumer(Consumer<Ranges> setRanges) {
* @param incrementDocCount a consumer to increment the document count for a range bucket. The first parameter is document count, the second is the key of the bucket
* @param ranges
*/
abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize(
abstract FilterRewriteOptimizationContext.OptimizeResult tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
Ranges ranges,
Supplier<DocIdSetBuilder> disBuilderSupplier
) throws IOException;

static FilterRewriteOptimizationContext.OptimizeResult getResult(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges,
Supplier<DocIdSetBuilder> disBuilderSupplier,
Function<Integer, Long> getBucketOrd,
int size
) throws IOException {
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long bucketOrd = getBucketOrd.apply(activeIndex);
incrementDocCount.accept(bucketOrd, (long) docCount);
};

PointValues.PointTree tree = values.getPointTree();
FilterRewriteOptimizationContext.OptimizeResult optimizeResult = new FilterRewriteOptimizationContext.OptimizeResult();
int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue());
if (activeIndex < 0) {
logger.debug("No ranges match the query, skip the fast filter optimization");
return optimizeResult;
}
PointTreeTraversal.RangeCollectorForPointTree collector = new PointTreeTraversal.RangeCollectorForPointTree(
ranges,
incrementFunc,
size,
activeIndex,
disBuilderSupplier,
getBucketOrd,
optimizeResult
);

return multiRangesTraverse(tree, collector);
}

/**
* Checks whether the top level query matches all documents on the segment
*
@@ -11,6 +11,7 @@
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.Rounding;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
@@ -22,8 +23,7 @@
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
import java.util.function.Supplier;

/**
* For date histogram aggregation
@@ -127,27 +127,31 @@ private DateFieldMapper.DateFieldType getFieldType() {
return (DateFieldMapper.DateFieldType) fieldType;
}

/**
* Get the size of buckets to stop early
*/
protected int getSize() {
return Integer.MAX_VALUE;
}

@Override
final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
final FilterRewriteOptimizationContext.OptimizeResult tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
Ranges ranges,
Supplier<DocIdSetBuilder> disBuilderSupplier
) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {

Function<Integer, Long> getBucketOrd = (activeIndex) -> {
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(bucketOrd, (long) docCount);
return getBucketOrd(bucketOrdProducer().apply(rangeStart));
};

return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
return getResult(values, incrementDocCount, ranges, disBuilderSupplier, getBucketOrd, size);
}

private static long getBucketOrd(long bucketOrd) {
@@ -14,12 +14,17 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.search.aggregations.BucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

@@ -42,6 +47,8 @@ public final class FilterRewriteOptimizationContext {

private Ranges ranges; // built at shard level

private int subAggLength;

// debug info related fields
private final AtomicInteger leafNodeVisited = new AtomicInteger();
private final AtomicInteger innerNodeVisited = new AtomicInteger();
@@ -65,7 +72,8 @@ public FilterRewriteOptimizationContext(
private boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) throws IOException {
if (context.maxAggRewriteFilters() == 0) return false;

if (parent != null || subAggLength != 0) return false;
if (parent != null) return false;
this.subAggLength = subAggLength;

boolean canOptimize = aggregatorBridge.canOptimize();
if (canOptimize) {
@@ -96,8 +104,13 @@ void setRanges(Ranges ranges) {
* @param incrementDocCount consume the doc_count results for certain ordinal
* @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment
*/
public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount, boolean segmentMatchAll)
throws IOException {
public boolean tryOptimize(
final LeafReaderContext leafCtx,
final BiConsumer<Long, Long> incrementDocCount,
boolean segmentMatchAll,
BucketCollector collectableSubAggregators,
LeafBucketCollector sub
) throws IOException {
segments.incrementAndGet();
if (!canOptimize) {
return false;
@@ -123,12 +136,43 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
Ranges ranges = getRanges(leafCtx, segmentMatchAll);
if (ranges == null) return false;

consumeDebugInfo(aggregatorBridge.tryOptimize(values, incrementDocCount, ranges));
Supplier<DocIdSetBuilder> disBuilderSupplier = null;
if (subAggLength != 0) {
disBuilderSupplier = () -> {
try {
return new DocIdSetBuilder(leafCtx.reader().maxDoc(), values, aggregatorBridge.fieldType.name());
} catch (IOException e) {
throw new RuntimeException(e);
}
Comment on lines +144 to +146
Member

Why are we catching a checked exception and rethrowing an unchecked exception here when the method is documented to throw IOException?
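
One possible way to address this, sketched here as an assumption rather than as part of the diff, is to wrap the checked exception in java.io.UncheckedIOException, which keeps the original cause and signals an I/O failure more precisely than a bare RuntimeException:

```java
// Hypothetical alternative for the supplier above (requires importing
// java.io.UncheckedIOException): the wrapper preserves the IOException
// cause when a lambda cannot declare the checked exception itself.
disBuilderSupplier = () -> {
    try {
        return new DocIdSetBuilder(leafCtx.reader().maxDoc(), values, aggregatorBridge.fieldType.name());
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    }
};
```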

};
}
OptimizeResult optimizeResult = aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, disBuilderSupplier);
consumeDebugInfo(optimizeResult);

optimizedSegments.incrementAndGet();
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
logger.debug("Crossed leaf nodes: {}, inner nodes: {}", leafNodeVisited, innerNodeVisited);

if (subAggLength == 0) {
return true;
}

// Handle sub aggregation
for (int bucketOrd = 0; bucketOrd < optimizeResult.builders.length; bucketOrd++) {
logger.debug("Collecting bucket {} for sub aggregation", bucketOrd);
DocIdSetBuilder builder = optimizeResult.builders[bucketOrd];
if (builder == null) {
continue;
}
DocIdSetIterator iterator = optimizeResult.builders[bucketOrd].build().iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
int currentDoc = iterator.docID();
sub.collect(currentDoc, bucketOrd);
}
// resetting the sub collector after processing each bucket
sub = collectableSubAggregators.getLeafCollector(leafCtx);
}

return true;
}

@@ -160,10 +204,12 @@ private Ranges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMa
/**
* Contains debug info of BKD traversal to show in profile
*/
static class DebugInfo {
static class OptimizeResult {
private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited
private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited

public DocIdSetBuilder[] builders;

void visitLeaf() {
leafNodeVisited.incrementAndGet();
}
@@ -173,7 +219,7 @@ void visitInner() {
}
}

void consumeDebugInfo(DebugInfo debug) {
void consumeDebugInfo(OptimizeResult debug) {
leafNodeVisited.addAndGet(debug.leafNodeVisited.get());
innerNodeVisited.addAndGet(debug.innerNodeVisited.get());
}