Skip to content

Commit

Permalink
HSEARCH-4950 Update knn restrictions for Elasticsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-bekhta committed Jan 17, 2024
1 parent 1ed4729 commit 6f407ad
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public abstract class AbstractElasticsearchNestablePredicate extends AbstractEla


@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
List<String> nestedPathHierarchy = getNestedPathHierarchy();
String expectedParentNestedPath = context.getNestedPath();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import com.google.gson.JsonObject;

public abstract class AbstractElasticsearchPredicate extends ElasticsearchSearchPredicate {
public abstract class AbstractElasticsearchPredicate implements ElasticsearchSearchPredicate {

private static final JsonAccessor<Float> BOOST_ACCESSOR = JsonAccessor.root().property( "boost" ).asFloat();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,17 @@ private ElasticsearchBooleanPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
checkAcceptableWithin( context, mustClauses );
checkAcceptableWithin( context, shouldClauses );
checkAcceptableWithin( context, filterClauses );
checkAcceptableWithin( context, mustNotClauses );
public void checkNestableWithin(PredicateNestingContext context) {
	// Elasticsearch backend only.
	// Being called here means this bool predicate is itself being passed as a clause/filter/etc.
	// to some other predicate. A knn clause among our should clauses would then sit deeper than
	// a should clause of a top-level bool predicate, which is not acceptable.
	// Hence every clause collection is validated against a context that no longer accepts knn.
	PredicateNestingContext knnRejectingContext = context.rejectKnn();
	checkAcceptableWithin( knnRejectingContext, mustClauses );
	checkAcceptableWithin( knnRejectingContext, shouldClauses );
	checkAcceptableWithin( knnRejectingContext, filterClauses );
	checkAcceptableWithin( knnRejectingContext, mustNotClauses );
}

@Override
Expand Down Expand Up @@ -217,6 +223,8 @@ public void must(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
mustClauses.add( elasticsearchClause );

checkShouldClausesForKnnAcceptability();
}

@Override
Expand All @@ -227,15 +235,21 @@ public void mustNot(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
mustNotClauses.add( elasticsearchClause );

checkShouldClausesForKnnAcceptability();
}

@Override
public void should(SearchPredicate clause) {
if ( shouldClauses == null ) {
shouldClauses = new ArrayList<>();
}

ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.acceptsKnn() );
elasticsearchClause.checkNestableWithin(
!hasNonShouldClause()
? PredicateNestingContext.acceptsKnn()
: PredicateNestingContext.doesNotAcceptKnn() );
shouldClauses.add( elasticsearchClause );
}

Expand All @@ -247,6 +261,8 @@ public void filter(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
filterClauses.add( elasticsearchClause );

checkShouldClausesForKnnAcceptability();
}

@Override
Expand Down Expand Up @@ -335,6 +351,20 @@ private void optimizeClauseCollection(List<ElasticsearchSearchPredicate> collect
}
}

/*
 * Elasticsearch backend only.
 * Knn clauses may already have been added as should clauses while that was still allowed.
 * Once a non-should clause gets added to the same bool predicate, those should clauses
 * are re-validated against a context that does not accept knn, so that the check fails
 * if any of them is (or contains) a knn clause.
 * OpenSearch backend will not be affected by these checks.
 */
private void checkShouldClausesForKnnAcceptability() {
	if ( shouldClauses == null ) {
		// No should clause was ever added: nothing to re-validate.
		return;
	}
	for ( ElasticsearchSearchPredicate shouldClause : shouldClauses ) {
		shouldClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
	}
}

private void checkAndClearClauseCollections() {
if ( mustClauses != null && mustClauses.isEmpty() ) {
mustClauses = null;
Expand All @@ -348,6 +378,10 @@ private boolean hasAtLeastOneMustOrFilterPredicate() {
return mustClauses != null || filterClauses != null;
}

/*
 * Tells whether any clause collection other than "should" (must, filter or mustNot)
 * has been initialized, i.e. is non-null.
 */
private boolean hasNonShouldClause() {
	boolean noNonShouldCollection = mustClauses == null
			&& filterClauses == null
			&& mustNotClauses == null;
	return !noNonShouldCollection;
}

private boolean hasOnlyOneMustClause() {
return mustClauses != null
&& mustClauses.size() == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
if ( context.getNestedPath() != null || !context.acceptsKnnClause() ) {
throw log.cannotAddKnnClauseAtThisStep();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private ElasticsearchMatchAllPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do: no validation is performed against the nesting context here,
// so the context argument is left unused.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private ElasticsearchMatchIdPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do: no validation is performed against the nesting context here,
// so the context argument is left unused.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ElasticsearchMatchNonePredicate extends AbstractElasticsearchPredicate {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do: no validation is performed against the nesting context here,
// so the context argument is left unused.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private ElasticsearchNamedPredicate(Builder builder, ElasticsearchSearchPredicat
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Validate the wrapped (user-provided) predicate against the same nesting context first...
providedPredicate.checkNestableWithin( context );
// ...then run the check inherited from the superclass on this named predicate itself.
super.checkNestableWithin( context );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@

import com.google.gson.JsonObject;

public abstract class ElasticsearchSearchPredicate implements SearchPredicate {
public interface ElasticsearchSearchPredicate extends SearchPredicate {

private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );
Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );

public abstract Set<String> indexNames();
Set<String> indexNames();

public final void checkNestableWithin(PredicateNestingContext context) {
doCheckNestableWithin( context.wrap( this ) );
}

protected abstract void doCheckNestableWithin(PredicateNestingContext context);
void checkNestableWithin(PredicateNestingContext context);

public abstract JsonObject toJsonQuery(PredicateRequestContext context);
JsonObject toJsonQuery(PredicateRequestContext context);

public static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
if ( !( predicate instanceof ElasticsearchSearchPredicate ) ) {
throw log.cannotMixElasticsearchSearchQueryWithOtherPredicates( predicate );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import com.google.gson.JsonObject;

class ElasticsearchUserProvidedJsonPredicate extends ElasticsearchSearchPredicate {
class ElasticsearchUserProvidedJsonPredicate implements ElasticsearchSearchPredicate {

private final Set<String> indexNames;
private final JsonObject json;
Expand All @@ -29,7 +29,7 @@ public Set<String> indexNames() {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do: we'll assume the user knows what they are doing.
// The nesting context is therefore not inspected for user-provided JSON.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ public class PredicateNestingContext {
private final String nestedPath;
private final boolean acceptsKnnClause;

private final Class<? extends ElasticsearchSearchPredicate> predicateType;

// Static factory returning the shared ACCEPTS_KNN context instead of allocating a new one.
// NOTE(review): ACCEPTS_KNN is declared outside this view; presumably it has
// acceptsKnnClause == true and no nested path -- confirm against its declaration.
public static PredicateNestingContext acceptsKnn() {
return ACCEPTS_KNN;
}
Expand All @@ -27,14 +25,8 @@ public static PredicateNestingContext nested(String nestedPath) {
}

private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause) {
this( nestedPath, acceptsKnnClause, null );
}

private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause,
Class<? extends ElasticsearchSearchPredicate> predicateType) {
this.nestedPath = nestedPath;
this.acceptsKnnClause = acceptsKnnClause;
this.predicateType = predicateType;
}

private PredicateNestingContext(String nestedPath) {
Expand All @@ -53,11 +45,13 @@ public boolean acceptsKnnClause() {
return acceptsKnnClause;
}

public PredicateNestingContext wrap(ElasticsearchSearchPredicate elasticsearchSearchPredicate) {
return new PredicateNestingContext(
nestedPath,
acceptsKnnClause && this.predicateType == null,
elasticsearchSearchPredicate.getClass()
);
/*
 * Returns a context equivalent to this one, except that knn clauses are not accepted.
 * This instance is returned as-is when it already rejects knn; the shared
 * DOES_NOT_ACCEPT_KNN constant is used when there is no nested path; otherwise a new
 * context carrying the same nested path is created.
 */
public PredicateNestingContext rejectKnn() {
	if ( acceptsKnnClause ) {
		return nestedPath == null
				? DOES_NOT_ACCEPT_KNN
				: new PredicateNestingContext( nestedPath, false );
	}
	// Already rejecting knn: no need for another instance.
	return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand Down Expand Up @@ -183,29 +184,68 @@ void knnPredicateInWrongPlace_addingPrebuiltKnn() {
StubMappingScope scope = index.createScope();

// all good:
SearchPredicate knn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
SearchPredicate boolKnnInShould = scope.predicate().bool().should( knn )
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();

// adding knn to non-should bool clauses is not ok,
// we only allow knn as a should clause!
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );

// adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
// is not ok !
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );

// adding as a knn filter:
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( boolKnnInShould ) );
SearchPredicate inlineKnn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
SearchPredicate namedKnn = scope.predicate().named( "knn-named" )
.param( "k", 25 )
.param( "vector", new float[] { 50.0f, 50.0f } ).toPredicate();

for ( SearchPredicate knn : Arrays.asList( inlineKnn, namedKnn ) ) {
// adding knn to non-should bool clauses is not ok,
// we only allow knn as a should clause!
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );

knnPredicateInWrongPlace( () -> scope.predicate().bool()
// should is ok
.should( knn )
// not ok: a knn clause is already present in should, so only more should clauses can be added
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.mustNot( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.filter( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );

// nested knn predicates are not allowed
knnPredicateInWrongPlace( () -> scope.predicate().nested( "object" ).add( knn ) );

// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
SearchPredicate inlineBoolKnnInShould = scope.predicate().bool().should( knn )
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
SearchPredicate namedBoolKnnInShould = scope.predicate().named( "bool-knn-in-should-named" )
.param( "knn", knn )
.toPredicate();

for ( SearchPredicate boolKnnInShould : Arrays.asList( inlineBoolKnnInShould, namedBoolKnnInShould ) ) {
// adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
// is not ok !
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );

// adding as a knn filter:
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f )
.filter( boolKnnInShould ) );

// nested knn predicates are not allowed
knnPredicateInWrongPlace( () -> scope.predicate().nested( "object" ).add( boolKnnInShould ) );
}
}
}

@Test
Expand Down Expand Up @@ -267,6 +307,22 @@ private static class PredicateIndexBinding {
nestedRating =
nested.field( "nestedRating", f -> f.asInteger().projectable( Projectable.YES ).sortable( Sortable.YES ) )
.toReference();

root.namedPredicate( "knn-named", context -> {
int k = context.param( "k", Integer.class );
float[] vector = context.param( "vector", float[].class );

return context.predicate().knn( k ).field( "location" )
.matching( vector )
.toPredicate();
} );

root.namedPredicate( "bool-knn-in-should-named", context -> {
SearchPredicate knn = context.param( "knn", SearchPredicate.class );

return context.predicate().bool().should( knn )
.should( context.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
} );
}

}
Expand Down

0 comments on commit 6f407ad

Please sign in to comment.