From aeeb486f8a902583e2b122f12b4d896b0227f28a Mon Sep 17 00:00:00 2001 From: marko-bekhta Date: Tue, 16 Jan 2024 17:04:40 +0100 Subject: [PATCH] HSEARCH-4950 Update knn restrictions for Elasticsearch --- ...bstractElasticsearchNestablePredicate.java | 2 +- .../impl/AbstractElasticsearchPredicate.java | 2 +- .../impl/ElasticsearchBooleanPredicate.java | 40 +++++++- .../impl/ElasticsearchKnnPredicate.java | 2 +- .../impl/ElasticsearchMatchAllPredicate.java | 2 +- .../impl/ElasticsearchMatchIdPredicate.java | 2 +- .../impl/ElasticsearchMatchNonePredicate.java | 2 +- .../impl/ElasticsearchNamedPredicate.java | 2 +- .../impl/ElasticsearchSearchPredicate.java | 16 ++- ...lasticsearchUserProvidedJsonPredicate.java | 4 +- .../impl/PredicateNestingContext.java | 15 --- .../ElasticsearchKnnPredicateSpecificsIT.java | 98 ++++++++++++++----- 12 files changed, 128 insertions(+), 59 deletions(-) diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchNestablePredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchNestablePredicate.java index 9ff4274d05b..8cf2d914fe4 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchNestablePredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchNestablePredicate.java @@ -25,7 +25,7 @@ public abstract class AbstractElasticsearchNestablePredicate extends AbstractEla @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { List nestedPathHierarchy = getNestedPathHierarchy(); String expectedParentNestedPath = context.getNestedPath(); diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchPredicate.java index 1dc7d385136..4e0c7612355 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchPredicate.java @@ -14,7 +14,7 @@ import com.google.gson.JsonObject; -public abstract class AbstractElasticsearchPredicate extends ElasticsearchSearchPredicate { +public abstract class AbstractElasticsearchPredicate implements ElasticsearchSearchPredicate { private static final JsonAccessor BOOST_ACCESSOR = JsonAccessor.root().property( "boost" ).asFloat(); diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchBooleanPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchBooleanPredicate.java index 5201efbcc3c..bc4934333e5 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchBooleanPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchBooleanPredicate.java @@ -64,7 +64,7 @@ private ElasticsearchBooleanPredicate(Builder builder) { } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { checkAcceptableWithin( context, mustClauses ); checkAcceptableWithin( context, shouldClauses ); checkAcceptableWithin( context, filterClauses ); @@ -217,6 +217,8 @@ public void must(SearchPredicate clause) { ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause ); elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ); mustClauses.add( elasticsearchClause ); + + checkShouldClauses(); } @Override @@ -227,6 +229,8 @@ public void mustNot(SearchPredicate clause) { ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause ); elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ); mustNotClauses.add( elasticsearchClause ); + + checkShouldClauses(); } @Override @@ -234,11 +238,23 @@ public void should(SearchPredicate clause) { if ( shouldClauses == null ) { shouldClauses = new ArrayList<>(); } + ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause ); - elasticsearchClause.checkNestableWithin( PredicateNestingContext.acceptsKnn() ); + elasticsearchClause.checkNestableWithin( + !hasNonShouldClause() && maybeKnnClause( clause ) + ? PredicateNestingContext.acceptsKnn() + : PredicateNestingContext.doesNotAcceptKnn() ); shouldClauses.add( elasticsearchClause ); } + private boolean maybeKnnClause(SearchPredicate clause) { + return clause instanceof ElasticsearchKnnPredicate + || ( clause instanceof ElasticsearchNamedPredicate + // TODO: + // && clause.providedPredicate instanceof ElasticsearchKnnPredicate + ); + } + @Override public void filter(SearchPredicate clause) { if ( filterClauses == null ) { @@ -247,6 +263,8 @@ public void filter(SearchPredicate clause) { ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause ); elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ); filterClauses.add( elasticsearchClause ); + + checkShouldClauses(); } @Override @@ -335,6 +353,20 @@ private void optimizeClauseCollection(List collect } } + /* + * For Elasticsearch backend only: + * It may be that we've added knn should clauses and all was fine. + * But now we are adding to such bool predicate a non-should clause. + * We want to make sure that in such case there are no knn clauses in should or fail. + * OpenSearch backend will not be affected by these checks. + */ + private void checkShouldClauses() { + if ( shouldClauses != null ) { + shouldClauses.forEach( + should -> should.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ) ); + } + } + private void checkAndClearClauseCollections() { if ( mustClauses != null && mustClauses.isEmpty() ) { mustClauses = null; @@ -348,6 +380,10 @@ private boolean hasAtLeastOneMustOrFilterPredicate() { return mustClauses != null || filterClauses != null; } + private boolean hasNonShouldClause() { + return mustClauses != null || filterClauses != null || mustNotClauses != null; + } + private boolean hasOnlyOneMustClause() { return mustClauses != null && mustClauses.size() == 1 diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchKnnPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchKnnPredicate.java index 82bbd88bb8f..7176e8a33f7 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchKnnPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchKnnPredicate.java @@ -183,7 +183,7 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { if ( context.getNestedPath() != null || !context.acceptsKnnClause() ) { throw log.cannotAddKnnClauseAtThisStep(); } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchAllPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchAllPredicate.java index db410843ff0..1d70f0c6fb9 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchAllPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchAllPredicate.java @@ -23,7 +23,7 @@ private ElasticsearchMatchAllPredicate(Builder builder) { } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { // Nothing to do } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchIdPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchIdPredicate.java index 62fe5c659fc..3b61b8e81af 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchIdPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchIdPredicate.java @@ -59,7 +59,7 @@ private ElasticsearchMatchIdPredicate(Builder builder) { } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { // Nothing to do } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchNonePredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchNonePredicate.java index fe45ed4a06a..5e3e37cdbb2 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchNonePredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchNonePredicate.java @@ -24,7 +24,7 @@ class ElasticsearchMatchNonePredicate extends AbstractElasticsearchPredicate { } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { // Nothing to do } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchNamedPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchNamedPredicate.java index 3ffb4fc1cfc..90ea72375a4 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchNamedPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchNamedPredicate.java @@ -38,7 +38,7 @@ private ElasticsearchNamedPredicate(Builder builder, ElasticsearchSearchPredicat } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { providedPredicate.checkNestableWithin( context ); super.checkNestableWithin( context ); } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchSearchPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchSearchPredicate.java index 50003d363bf..af501e21566 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchSearchPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchSearchPredicate.java @@ -16,21 +16,17 @@ import com.google.gson.JsonObject; -public abstract class ElasticsearchSearchPredicate implements SearchPredicate { +public interface ElasticsearchSearchPredicate extends SearchPredicate { - private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() ); + Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() ); - public abstract Set indexNames(); + Set indexNames(); - public final void checkNestableWithin(PredicateNestingContext context) { - doCheckNestableWithin( context.wrap( this ) ); - } - - protected abstract void doCheckNestableWithin(PredicateNestingContext context); + void checkNestableWithin(PredicateNestingContext context); - public abstract JsonObject toJsonQuery(PredicateRequestContext context); + JsonObject toJsonQuery(PredicateRequestContext context); - public static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope scope, SearchPredicate predicate) { + static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope scope, SearchPredicate predicate) { if ( !( predicate instanceof ElasticsearchSearchPredicate ) ) { throw log.cannotMixElasticsearchSearchQueryWithOtherPredicates( predicate ); } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchUserProvidedJsonPredicate.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchUserProvidedJsonPredicate.java index 74d338b3251..e30c6c57a8c 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchUserProvidedJsonPredicate.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchUserProvidedJsonPredicate.java @@ -12,7 +12,7 @@ import com.google.gson.JsonObject; -class ElasticsearchUserProvidedJsonPredicate extends ElasticsearchSearchPredicate { +class ElasticsearchUserProvidedJsonPredicate implements ElasticsearchSearchPredicate { private final Set indexNames; private final JsonObject json; @@ -29,7 +29,7 @@ public Set indexNames() { } @Override - public void doCheckNestableWithin(PredicateNestingContext context) { + public void checkNestableWithin(PredicateNestingContext context) { // Nothing to do: we'll assume the user knows what they are doing. } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/PredicateNestingContext.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/PredicateNestingContext.java index 5ba3497f3f0..f895ab3fd83 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/PredicateNestingContext.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/PredicateNestingContext.java @@ -12,8 +12,6 @@ public class PredicateNestingContext { private final String nestedPath; private final boolean acceptsKnnClause; - private final Class predicateType; - public static PredicateNestingContext acceptsKnn() { return ACCEPTS_KNN; } @@ -27,14 +25,8 @@ public static PredicateNestingContext nested(String nestedPath) { } private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause) { - this( nestedPath, acceptsKnnClause, null ); - } - - private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause, - Class predicateType) { this.nestedPath = nestedPath; this.acceptsKnnClause = acceptsKnnClause; - this.predicateType = predicateType; } private PredicateNestingContext(String nestedPath) { @@ -53,11 +45,4 @@ public boolean acceptsKnnClause() { return acceptsKnnClause; } - public PredicateNestingContext wrap(ElasticsearchSearchPredicate elasticsearchSearchPredicate) { - return new PredicateNestingContext( - nestedPath, - acceptsKnnClause && this.predicateType == null, - elasticsearchSearchPredicate.getClass() - ); - } } diff --git a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/search/ElasticsearchKnnPredicateSpecificsIT.java b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/search/ElasticsearchKnnPredicateSpecificsIT.java index 2a5b2e65035..2986190921e 100644 --- a/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/search/ElasticsearchKnnPredicateSpecificsIT.java +++ b/integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/search/ElasticsearchKnnPredicateSpecificsIT.java @@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assumptions.assumeFalse; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -183,29 +184,64 @@ void knnPredicateInWrongPlace_addingPrebuiltKnn() { StubMappingScope scope = index.createScope(); // all good: - SearchPredicate knn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate(); - // we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate - SearchPredicate boolKnnInShould = scope.predicate().bool().should( knn ) - .should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate(); - - // adding knn to non-should bool clauses is not ok, - // we only allow knn as a should clause! - knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) ); - knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) ); - knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) ); - - // adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one) - // is not ok ! - knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) ); - knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) ); - knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) ); - knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) ); - - // adding as a knn filter: - knnPredicateInWrongPlace( - () -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) ); - knnPredicateInWrongPlace( - () -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( boolKnnInShould ) ); + SearchPredicate inlineKnn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate(); + SearchPredicate namedKnn = scope.predicate().named( "knn-named" ) + .param( "k", 25 ) + .param( "vector", new float[] { 50.0f, 50.0f } ).toPredicate(); + + for ( SearchPredicate knn : Arrays.asList( inlineKnn, namedKnn ) ) { + // adding knn to non-should bool clauses is not ok, + // we only allow knn as a should clause! + knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) ); + + knnPredicateInWrongPlace( () -> scope.predicate().bool() + // should is ok + .should( knn ) + // not ok since we already have a knn in should we can only add more should clauses + .must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool() + // so far so good: + .must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) + // cannot add a knn through should as we already have a non-should clause + .should( knn ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool() + // so far so good: + .mustNot( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) + // cannot add a knn through should as we already have a non-should clause + .should( knn ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool() + // so far so good: + .filter( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) + // cannot add a knn through should as we already have a non-should clause + .should( knn ) ); + + + // we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate + SearchPredicate inlineBoolKnnInShould = scope.predicate().bool().should( knn ) + .should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate(); + SearchPredicate namedBoolKnnInShould = scope.predicate().named( "bool-knn-in-should-named" ) + .param( "knn", knn ) + .toPredicate(); + + for ( SearchPredicate boolKnnInShould : Arrays.asList( inlineBoolKnnInShould, namedBoolKnnInShould ) ) { + // adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one) + // is not ok ! + knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) ); + knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) ); + + // adding as a knn filter: + knnPredicateInWrongPlace( + () -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) ); + knnPredicateInWrongPlace( + () -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ) + .filter( boolKnnInShould ) ); + } + + } } @Test @@ -267,6 +303,22 @@ private static class PredicateIndexBinding { nestedRating = nested.field( "nestedRating", f -> f.asInteger().projectable( Projectable.YES ).sortable( Sortable.YES ) ) .toReference(); + + root.namedPredicate( "knn-named", context -> { + int k = context.param( "k", Integer.class ); + float[] vector = context.param( "vector", float[].class ); + + return context.predicate().knn( k ).field( "location" ) + .matching( vector ) + .toPredicate(); + } ); + + root.namedPredicate( "bool-knn-in-should-named", context -> { + SearchPredicate knn = context.param( "knn", SearchPredicate.class ); + + return context.predicate().bool().should( knn ) + .should( context.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate(); + } ); } }