Skip to content

Add pipeline aggregations to NativeSearchQuery. #1809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
import org.elasticsearch.index.reindex.UpdateByQueryRequest;
import org.elasticsearch.index.reindex.UpdateByQueryRequestBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
Expand Down Expand Up @@ -1119,9 +1118,11 @@ private void prepareNativeSearch(NativeSearchQuery query, SearchSourceBuilder so
}

if (!isEmpty(query.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : query.getAggregations()) {
sourceBuilder.aggregation(aggregationBuilder);
}
query.getAggregations().forEach(sourceBuilder::aggregation);
}

if (!isEmpty(query.getPipelineAggregations())) {
query.getPipelineAggregations().forEach(sourceBuilder::aggregation);
}

}
Expand All @@ -1144,9 +1145,11 @@ private void prepareNativeSearch(SearchRequestBuilder searchRequestBuilder, Nati
}

if (!isEmpty(nativeSearchQuery.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : nativeSearchQuery.getAggregations()) {
searchRequestBuilder.addAggregation(aggregationBuilder);
}
nativeSearchQuery.getAggregations().forEach(searchRequestBuilder::addAggregation);
}

if (!isEmpty(nativeSearchQuery.getPipelineAggregations())) {
nativeSearchQuery.getPipelineAggregations().forEach(searchRequestBuilder::addAggregation);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
Expand All @@ -48,6 +49,7 @@ public class NativeSearchQuery extends AbstractQuery {
private final List<ScriptField> scriptFields = new ArrayList<>();
@Nullable private CollapseBuilder collapseBuilder;
@Nullable private List<AbstractAggregationBuilder<?>> aggregations;
@Nullable private List<PipelineAggregationBuilder> pipelineAggregations;
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
@Nullable private List<IndexBoost> indicesBoost;
Expand Down Expand Up @@ -143,6 +145,11 @@ public List<AbstractAggregationBuilder<?>> getAggregations() {
return aggregations;
}

@Nullable
public List<PipelineAggregationBuilder> getPipelineAggregations() {
return pipelineAggregations;
}

public void addAggregation(AbstractAggregationBuilder<?> aggregationBuilder) {

if (aggregations == null) {
Expand All @@ -156,6 +163,10 @@ public void setAggregations(List<AbstractAggregationBuilder<?>> aggregations) {
this.aggregations = aggregations;
}

public void setPipelineAggregations(List<PipelineAggregationBuilder> pipelineAggregationBuilders) {
this.pipelineAggregations = pipelineAggregationBuilders;
}

@Nullable
public List<IndexBoost> getIndicesBoost() {
return indicesBoost;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
Expand Down Expand Up @@ -55,6 +56,7 @@ public class NativeSearchQueryBuilder {
private final List<ScriptField> scriptFields = new ArrayList<>();
private final List<SortBuilder<?>> sortBuilders = new ArrayList<>();
private final List<AbstractAggregationBuilder<?>> aggregationBuilders = new ArrayList<>();
private final List<PipelineAggregationBuilder> pipelineAggregationBuilders = new ArrayList<>();
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
private Pageable pageable = Pageable.unpaged();
Expand Down Expand Up @@ -105,6 +107,14 @@ public NativeSearchQueryBuilder addAggregation(AbstractAggregationBuilder<?> agg
return this;
}

/**
* @since 4.3
*/
public NativeSearchQueryBuilder addAggregation(PipelineAggregationBuilder pipelineAggregationBuilder) {
this.pipelineAggregationBuilders.add(pipelineAggregationBuilder);
return this;
}

public NativeSearchQueryBuilder withHighlightBuilder(HighlightBuilder highlightBuilder) {
this.highlightBuilder = highlightBuilder;
return this;
Expand Down Expand Up @@ -239,6 +249,10 @@ public NativeSearchQuery build() {
nativeSearchQuery.setAggregations(aggregationBuilders);
}

if (!isEmpty(pipelineAggregationBuilders)) {
nativeSearchQuery.setPipelineAggregations(pipelineAggregationBuilders);
}

if (minScore > 0) {
nativeSearchQuery.setMinScore(minScore);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.assertj.core.api.Assertions.*;
import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.elasticsearch.search.aggregations.AggregationBuilders.*;
import static org.elasticsearch.search.aggregations.PipelineAggregatorBuilders.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.Integer;

Expand All @@ -26,9 +27,14 @@
import java.util.List;

import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.pipeline.InternalStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.ParsedStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.StatsBucket;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -109,7 +115,7 @@ public void after() {
indexOperations.delete();
}

@Test
@Test // DATAES-96
public void shouldReturnAggregatedResponseForGivenSearchQuery() {

// given
Expand All @@ -130,6 +136,56 @@ public void shouldReturnAggregatedResponseForGivenSearchQuery() {
assertThat(searchHits.hasSearchHits()).isFalse();
}

@Test // #1255
@DisplayName("should work with pipeline aggregations")
void shouldWorkWithPipelineAggregations() {

IndexInitializer.init(operations.indexOps(PipelineAggsEntity.class));
operations.save( //
new PipelineAggsEntity("1-1", "one"), //
new PipelineAggsEntity("2-1", "two"), //
new PipelineAggsEntity("2-2", "two"), //
new PipelineAggsEntity("3-1", "three"), //
new PipelineAggsEntity("3-2", "three"), //
new PipelineAggsEntity("3-3", "three") //
); //

NativeSearchQuery searchQuery = new NativeSearchQueryBuilder() //
.withQuery(matchAllQuery()) //
.withSearchType(SearchType.DEFAULT) //
.addAggregation(terms("keyword_aggs").field("keyword")) //
.addAggregation(statsBucket("keyword_bucket_stats", "keyword_aggs._count")) //
.withMaxResults(0) //
.build();

SearchHits<PipelineAggsEntity> searchHits = operations.search(searchQuery, PipelineAggsEntity.class);

Aggregations aggregations = searchHits.getAggregations();
assertThat(aggregations).isNotNull();
assertThat(aggregations.asMap().get("keyword_aggs")).isNotNull();
Aggregation keyword_bucket_stats = aggregations.asMap().get("keyword_bucket_stats");
assertThat(keyword_bucket_stats).isInstanceOf(StatsBucket.class);
if (keyword_bucket_stats instanceof ParsedStatsBucket) {
// Rest client
ParsedStatsBucket statsBucket = (ParsedStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
if (keyword_bucket_stats instanceof InternalStatsBucket) {
// transport client
InternalStatsBucket statsBucket = (InternalStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
}

// region entities
@Document(indexName = "test-index-articles-core-aggregation")
static class ArticleEntity {

Expand Down Expand Up @@ -256,4 +312,34 @@ public IndexQuery buildIndex() {
}
}

@Document(indexName = "pipeline-aggs")
static class PipelineAggsEntity {
@Id private String id;
@Field(type = Keyword) private String keyword;

public PipelineAggsEntity() {}

public PipelineAggsEntity(String id, String keyword) {
this.id = id;
this.keyword = keyword;
}

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}

public String getKeyword() {
return keyword;
}

public void setKeyword(String keyword) {
this.keyword = keyword;
}
}
// endregion

}