Skip to content

Commit

Permalink
HSEARCH-4950 Hide knn clauses in a predicate request context
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-bekhta committed Dec 19, 2023
1 parent 879ade6 commit ac135d1
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope;
import org.hibernate.search.engine.search.predicate.spi.SearchPredicateBuilder;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

public abstract class AbstractElasticsearchPredicate implements ElasticsearchSearchPredicate {
Expand All @@ -37,29 +36,6 @@ public Set<String> indexNames() {
return indexNames;
}

@Override
public JsonElement toJsonKnn(PredicateRequestContext context) {
JsonElement result = doToJsonKnn( context );

if ( result == null ) {
return null;
}

// in case of withConstantScore boots is set by constant_score clause
if ( boost != null ) {
if ( result.isJsonArray() ) {
for ( JsonElement element : result.getAsJsonArray() ) {
BOOST_ACCESSOR.set( element.getAsJsonObject(), boost );
}
}
else {
BOOST_ACCESSOR.set( result.getAsJsonObject(), boost );
}
}

return result;
}

@Override
public JsonObject toJsonQuery(PredicateRequestContext context) {
JsonObject outerObject = new JsonObject();
Expand All @@ -77,10 +53,6 @@ public JsonObject toJsonQuery(PredicateRequestContext context) {
protected abstract JsonObject doToJsonQuery(PredicateRequestContext context,
JsonObject outerObject, JsonObject innerObject);

protected JsonElement doToJsonKnn(PredicateRequestContext context) {
return null;
}

protected boolean hasNoModifiers() {
return !withConstantScore && boost == null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,16 @@
import org.hibernate.search.engine.search.predicate.spi.BooleanPredicateBuilder;
import org.hibernate.search.util.common.logging.impl.LoggerFactory;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

class ElasticsearchBooleanPredicate extends AbstractElasticsearchPredicate {

private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );

static final String MUST_PROPERTY_NAME = "must";
static final String MUST_NOT_PROPERTY_NAME = "must_not";
static final String SHOULD_PROPERTY_NAME = "should";
static final String FILTER_PROPERTY_NAME = "filter";
private static final String MUST_PROPERTY_NAME = "must";
private static final String MUST_NOT_PROPERTY_NAME = "must_not";
private static final String SHOULD_PROPERTY_NAME = "should";
private static final String FILTER_PROPERTY_NAME = "filter";

private static final JsonAccessor<String> MINIMUM_SHOULD_MATCH_ACCESSOR =
JsonAccessor.root().property( "minimum_should_match" ).asString();
Expand Down Expand Up @@ -73,22 +71,6 @@ public void checkNestableWithin(String expectedParentNestedPath) {
checkNestableWithin( expectedParentNestedPath, mustNotClauses );
}

@Override
protected JsonArray doToJsonKnn(PredicateRequestContext context) {
JsonArray knns = new JsonArray();
if ( shouldClauses == null ) {
return null;
}

for ( ElasticsearchSearchPredicate clause : shouldClauses ) {
JsonElement knn = clause.toJsonKnn( context );
if ( knn != null ) {
knns.add( knn.getAsJsonObject() );
}
}
return knns.isEmpty() ? null : knns;
}

@Override
protected JsonObject doToJsonQuery(PredicateRequestContext context,
JsonObject outerObject, JsonObject innerObject) {
Expand Down Expand Up @@ -122,9 +104,18 @@ private void contributeClauses(PredicateRequestContext context, JsonObject inner
}

for ( ElasticsearchSearchPredicate clause : clauses ) {
JsonObject clauseQuery = clause.toJsonQuery( context );
if ( clauseQuery != null ) {
GsonUtils.setOrAppendToArray( innerObject, occurProperty, clauseQuery );
JsonObject jsonQuery = clause.toJsonQuery( context );
if ( jsonQuery == null ) {
if ( !SHOULD_PROPERTY_NAME.equals( occurProperty ) ) {
throw log.knnPredicateCanOnlyBeShouldClause();
}
// This is an exceptional case for a KNN search on Elasticsearch distribution.
// A Knn predicate would contribute to a knn clause inside the request context itself
// and we do not want to add this json as a clause to the bool predicate.
// So the predicate returns null as JSON query and we ignore it.
}
else {
GsonUtils.setOrAppendToArray( innerObject, occurProperty, jsonQuery );
}
}
}
Expand Down Expand Up @@ -207,35 +198,31 @@ public void must(SearchPredicate clause) {
if ( mustClauses == null ) {
mustClauses = new ArrayList<>();
}
mustClauses.add( ElasticsearchSearchPredicate.from( scope, clause )
.checkAcceptableAsBoolPredicateClause( MUST_PROPERTY_NAME ) );
mustClauses.add( ElasticsearchSearchPredicate.from( scope, clause ) );
}

@Override
public void mustNot(SearchPredicate clause) {
if ( mustNotClauses == null ) {
mustNotClauses = new ArrayList<>();
}
mustNotClauses.add( ElasticsearchSearchPredicate.from( scope, clause )
.checkAcceptableAsBoolPredicateClause( MUST_NOT_PROPERTY_NAME ) );
mustNotClauses.add( ElasticsearchSearchPredicate.from( scope, clause ) );
}

@Override
public void should(SearchPredicate clause) {
if ( shouldClauses == null ) {
shouldClauses = new ArrayList<>();
}
shouldClauses.add( ElasticsearchSearchPredicate.from( scope, clause )
.checkAcceptableAsBoolPredicateClause( SHOULD_PROPERTY_NAME ) );
shouldClauses.add( ElasticsearchSearchPredicate.from( scope, clause ) );
}

@Override
public void filter(SearchPredicate clause) {
if ( filterClauses == null ) {
filterClauses = new ArrayList<>();
}
filterClauses.add( ElasticsearchSearchPredicate.from( scope, clause )
.checkAcceptableAsBoolPredicateClause( FILTER_PROPERTY_NAME ) );
filterClauses.add( ElasticsearchSearchPredicate.from( scope, clause ) );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,18 @@ private ElasticsearchImpl(Builder<?> builder) {
}

@Override
protected JsonObject doToJsonKnn(PredicateRequestContext context) {
JsonObject object = new JsonObject();
FIELD_ACCESSOR.set( object, absoluteFieldPath );
K_ACCESSOR.set( object, k );
public JsonObject toJsonQuery(PredicateRequestContext context) {
JsonObject knn = new JsonObject();
FIELD_ACCESSOR.set( knn, absoluteFieldPath );
K_ACCESSOR.set( knn, k );
if ( filter != null ) {
FILTER_ACCESSOR.set( object, filter.toJsonQuery( context ) );
FILTER_ACCESSOR.set( knn, filter.toJsonQuery( context ) );
}
NUM_CANDIDATES_ACCESSOR.set( object, numberOfCandidates != null ? numberOfCandidates : k );
QUERY_VECTOR_ACCESSOR.set( object, vector );
NUM_CANDIDATES_ACCESSOR.set( knn, numberOfCandidates != null ? numberOfCandidates : k );
QUERY_VECTOR_ACCESSOR.set( knn, vector );

return object;
}
context.contributeKnnClause( knn );

@Override
public JsonObject toJsonQuery(PredicateRequestContext context) {
return null;
}

Expand All @@ -180,14 +177,6 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
throw new AssertionFailure( "Shouldn't be reached since we've overridden the toJsonQuery" );
}

@Override
public ElasticsearchSearchPredicate checkAcceptableAsBoolPredicateClause(String clauseType) {
if ( !ElasticsearchBooleanPredicate.SHOULD_PROPERTY_NAME.equals( clauseType ) ) {
throw log.knnPredicateCanOnlyBeShouldClause();
}
return super.checkAcceptableAsBoolPredicateClause( clauseType );
}

@Override
public void checkNestableWithin(String expectedParentNestedPath) {
if ( expectedParentNestedPath != null ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.hibernate.search.engine.search.predicate.SearchPredicate;
import org.hibernate.search.util.common.logging.impl.LoggerFactory;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

public interface ElasticsearchSearchPredicate extends SearchPredicate {
Expand All @@ -25,14 +24,8 @@ public interface ElasticsearchSearchPredicate extends SearchPredicate {

void checkNestableWithin(String expectedParentNestedPath);

default ElasticsearchSearchPredicate checkAcceptableAsBoolPredicateClause(String clauseType) {
return this;
}

JsonObject toJsonQuery(PredicateRequestContext context);

JsonElement toJsonKnn(PredicateRequestContext rootPredicateContext);

static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
if ( !( predicate instanceof ElasticsearchSearchPredicate ) ) {
throw log.cannotMixElasticsearchSearchQueryWithOtherPredicates( predicate );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

class ElasticsearchUserProvidedJsonPredicate implements ElasticsearchSearchPredicate {
Expand Down Expand Up @@ -38,9 +37,4 @@ public void checkNestableWithin(String expectedParentNestedPath) {
public JsonObject toJsonQuery(PredicateRequestContext context) {
return json;
}

@Override
public JsonElement toJsonKnn(PredicateRequestContext rootPredicateContext) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@
*/
package org.hibernate.search.backend.elasticsearch.search.predicate.impl;

import org.hibernate.search.backend.elasticsearch.gson.impl.JsonAccessor;
import org.hibernate.search.backend.elasticsearch.gson.impl.JsonObjectAccessor;
import org.hibernate.search.backend.elasticsearch.lowlevel.query.impl.Queries;
import org.hibernate.search.engine.backend.session.spi.BackendSessionContext;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;

public class PredicateRequestContext {

private final BackendSessionContext sessionContext;
private final String nestedPath;
private JsonElement jsonKnn = JsonNull.INSTANCE;

public PredicateRequestContext(BackendSessionContext sessionContext) {
this.sessionContext = sessionContext;
Expand All @@ -34,4 +43,44 @@ public PredicateRequestContext withNestedPath(String path) {
public String getNestedPath() {
return nestedPath;
}

public void contributeKnnClause(JsonObject knn) {
if ( jsonKnn.isJsonNull() ) {
jsonKnn = knn;
}
else if ( jsonKnn.isJsonArray() ) {
jsonKnn.getAsJsonArray().add( knn );
}
else {
JsonArray array = new JsonArray();
array.add( jsonKnn );
array.add( knn );
jsonKnn = array;
}
}

public JsonElement knnSearch(JsonArray filters) {
if ( jsonKnn.isJsonNull() ) {
return jsonKnn;
}
return addFiltersToKnn( jsonKnn, filters );
}

private static JsonElement addFiltersToKnn(JsonElement jsonKnn, JsonArray filters) {
if ( filters == null || filters.isEmpty() ) {
return jsonKnn;
}

if ( jsonKnn.isJsonArray() ) {
for ( JsonElement jsonElement : jsonKnn.getAsJsonArray() ) {
addFiltersToKnn( jsonElement.getAsJsonObject(), filters );
}
}
return jsonKnn;
}

private static void addFiltersToKnn(JsonObject jsonKnn, JsonArray filters) {
JsonObjectAccessor filterAccessor = JsonAccessor.root().property( "filter" ).asObject();
filterAccessor.set( jsonKnn, Queries.boolFilter( filterAccessor.getOrCreate( jsonKnn, Queries::matchAll ), filters ) );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.concurrent.TimeUnit;

import org.hibernate.search.backend.elasticsearch.gson.impl.JsonAccessor;
import org.hibernate.search.backend.elasticsearch.gson.impl.JsonObjectAccessor;
import org.hibernate.search.backend.elasticsearch.logging.impl.Log;
import org.hibernate.search.backend.elasticsearch.lowlevel.query.impl.Queries;
import org.hibernate.search.backend.elasticsearch.orchestration.impl.ElasticsearchParallelWorkOrchestrator;
Expand Down Expand Up @@ -74,7 +73,6 @@ public class ElasticsearchSearchQueryBuilder<H>

private final Set<String> routingKeys;
private JsonObject jsonPredicate;
private JsonElement jsonKnn;
private JsonArray jsonSort;
private Map<DistanceSortKey, Integer> distanceSorts;
private Map<AggregationKey<?>, ElasticsearchSearchAggregation<?>> aggregations;
Expand Down Expand Up @@ -113,7 +111,6 @@ public ElasticsearchSearchQueryBuilder(
public void predicate(SearchPredicate predicate) {
ElasticsearchSearchPredicate elasticsearchPredicate = ElasticsearchSearchPredicate.from( scope, predicate );
this.jsonPredicate = elasticsearchPredicate.toJsonQuery( rootPredicateContext );
this.jsonKnn = elasticsearchPredicate.toJsonKnn( rootPredicateContext );
}

@Override
Expand Down Expand Up @@ -240,8 +237,9 @@ public ElasticsearchSearchQuery<H> build() {
payload.add( "query", jsonQuery );
}

if ( jsonKnn != null ) {
payload.add( "knn", addFiltersToKnn( jsonKnn, filters ) );
JsonElement jsonKnn = rootPredicateContext.knnSearch( filters );
if ( !jsonKnn.isJsonNull() ) {
payload.add( "knn", jsonKnn );
}

if ( jsonSort != null ) {
Expand Down Expand Up @@ -294,23 +292,4 @@ public ElasticsearchSearchQuery<H> build() {
scrollTimeout, totalHitCountThreshold
);
}

private static JsonElement addFiltersToKnn(JsonElement jsonKnn, JsonArray filters) {
if ( filters == null || filters.isEmpty() ) {
return jsonKnn;
}

if ( jsonKnn.isJsonArray() ) {
for ( JsonElement jsonElement : jsonKnn.getAsJsonArray() ) {
addFiltersToKnn( jsonElement.getAsJsonObject(), filters );
}
}
return jsonKnn;
}

private static JsonElement addFiltersToKnn(JsonObject jsonKnn, JsonArray filters) {
JsonObjectAccessor filterAccessor = JsonAccessor.root().property( "filter" ).asObject();
filterAccessor.set( jsonKnn, Queries.boolFilter( filterAccessor.getOrCreate( jsonKnn, Queries::matchAll ), filters ) );
return jsonKnn;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ public void filter(SearchPredicate filter) {
this.filter = LuceneSearchPredicate.from( scope, filter );
}

@Override
public void numberOfCandidates(int numberOfCandidates) {
throw log.knnNumberOfCandidatesUnsupportedOption();
}

@Override
public SearchPredicate build() {
return new LuceneKnnPredicate( this );
Expand Down

0 comments on commit ac135d1

Please sign in to comment.