Skip to content

Commit

Permalink
HSEARCH-4950 Update knn restrictions for Elasticsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-bekhta committed Jan 16, 2024
1 parent a159601 commit aeeb486
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public abstract class AbstractElasticsearchNestablePredicate extends AbstractEla


@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
List<String> nestedPathHierarchy = getNestedPathHierarchy();
String expectedParentNestedPath = context.getNestedPath();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import com.google.gson.JsonObject;

public abstract class AbstractElasticsearchPredicate extends ElasticsearchSearchPredicate {
public abstract class AbstractElasticsearchPredicate implements ElasticsearchSearchPredicate {

private static final JsonAccessor<Float> BOOST_ACCESSOR = JsonAccessor.root().property( "boost" ).asFloat();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ private ElasticsearchBooleanPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
checkAcceptableWithin( context, mustClauses );
checkAcceptableWithin( context, shouldClauses );
checkAcceptableWithin( context, filterClauses );
Expand Down Expand Up @@ -217,6 +217,8 @@ public void must(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
mustClauses.add( elasticsearchClause );

checkShouldClauses();
}

@Override
Expand All @@ -227,18 +229,32 @@ public void mustNot(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
mustNotClauses.add( elasticsearchClause );

checkShouldClauses();
}

@Override
public void should(SearchPredicate clause) {
if ( shouldClauses == null ) {
shouldClauses = new ArrayList<>();
}

ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.acceptsKnn() );
elasticsearchClause.checkNestableWithin(
!hasNonShouldClause() && maybeKnnClause( clause )
? PredicateNestingContext.acceptsKnn()
: PredicateNestingContext.doesNotAcceptKnn() );
shouldClauses.add( elasticsearchClause );
}

private boolean maybeKnnClause(SearchPredicate clause) {
return clause instanceof ElasticsearchKnnPredicate
|| ( clause instanceof ElasticsearchNamedPredicate
// TODO:
// && clause.providedPredicate instanceof ElasticsearchKnnPredicate
);
}

@Override
public void filter(SearchPredicate clause) {
if ( filterClauses == null ) {
Expand All @@ -247,6 +263,8 @@ public void filter(SearchPredicate clause) {
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
filterClauses.add( elasticsearchClause );

checkShouldClauses();
}

@Override
Expand Down Expand Up @@ -335,6 +353,20 @@ private void optimizeClauseCollection(List<ElasticsearchSearchPredicate> collect
}
}

/*
* For Elasticsearch backend only:
* It may be that we've added knn should clauses and all was fine.
* But now we are adding to such bool predicate a non-should clause.
* We want to make sure that in such case there are no knn clauses in should or fail.
* OpenSearch backend will not be affected by these checks.
*/
private void checkShouldClauses() {
if ( shouldClauses != null ) {
shouldClauses.forEach(
should -> should.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ) );
}
}

private void checkAndClearClauseCollections() {
if ( mustClauses != null && mustClauses.isEmpty() ) {
mustClauses = null;
Expand All @@ -348,6 +380,10 @@ private boolean hasAtLeastOneMustOrFilterPredicate() {
return mustClauses != null || filterClauses != null;
}

private boolean hasNonShouldClause() {
return mustClauses != null || filterClauses != null || mustNotClauses != null;
}

private boolean hasOnlyOneMustClause() {
return mustClauses != null
&& mustClauses.size() == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
if ( context.getNestedPath() != null || !context.acceptsKnnClause() ) {
throw log.cannotAddKnnClauseAtThisStep();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private ElasticsearchMatchAllPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private ElasticsearchMatchIdPredicate(Builder builder) {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ElasticsearchMatchNonePredicate extends AbstractElasticsearchPredicate {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private ElasticsearchNamedPredicate(Builder builder, ElasticsearchSearchPredicat
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
providedPredicate.checkNestableWithin( context );
super.checkNestableWithin( context );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@

import com.google.gson.JsonObject;

public abstract class ElasticsearchSearchPredicate implements SearchPredicate {
public interface ElasticsearchSearchPredicate extends SearchPredicate {

private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );
Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );

public abstract Set<String> indexNames();
Set<String> indexNames();

public final void checkNestableWithin(PredicateNestingContext context) {
doCheckNestableWithin( context.wrap( this ) );
}

protected abstract void doCheckNestableWithin(PredicateNestingContext context);
void checkNestableWithin(PredicateNestingContext context);

public abstract JsonObject toJsonQuery(PredicateRequestContext context);
JsonObject toJsonQuery(PredicateRequestContext context);

public static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
if ( !( predicate instanceof ElasticsearchSearchPredicate ) ) {
throw log.cannotMixElasticsearchSearchQueryWithOtherPredicates( predicate );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import com.google.gson.JsonObject;

class ElasticsearchUserProvidedJsonPredicate extends ElasticsearchSearchPredicate {
class ElasticsearchUserProvidedJsonPredicate implements ElasticsearchSearchPredicate {

private final Set<String> indexNames;
private final JsonObject json;
Expand All @@ -29,7 +29,7 @@ public Set<String> indexNames() {
}

@Override
public void doCheckNestableWithin(PredicateNestingContext context) {
public void checkNestableWithin(PredicateNestingContext context) {
// Nothing to do: we'll assume the user knows what they are doing.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ public class PredicateNestingContext {
private final String nestedPath;
private final boolean acceptsKnnClause;

private final Class<? extends ElasticsearchSearchPredicate> predicateType;

public static PredicateNestingContext acceptsKnn() {
return ACCEPTS_KNN;
}
Expand All @@ -27,14 +25,8 @@ public static PredicateNestingContext nested(String nestedPath) {
}

private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause) {
this( nestedPath, acceptsKnnClause, null );
}

private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause,
Class<? extends ElasticsearchSearchPredicate> predicateType) {
this.nestedPath = nestedPath;
this.acceptsKnnClause = acceptsKnnClause;
this.predicateType = predicateType;
}

private PredicateNestingContext(String nestedPath) {
Expand All @@ -53,11 +45,4 @@ public boolean acceptsKnnClause() {
return acceptsKnnClause;
}

public PredicateNestingContext wrap(ElasticsearchSearchPredicate elasticsearchSearchPredicate) {
return new PredicateNestingContext(
nestedPath,
acceptsKnnClause && this.predicateType == null,
elasticsearchSearchPredicate.getClass()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand Down Expand Up @@ -183,29 +184,64 @@ void knnPredicateInWrongPlace_addingPrebuiltKnn() {
StubMappingScope scope = index.createScope();

// all good:
SearchPredicate knn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
SearchPredicate boolKnnInShould = scope.predicate().bool().should( knn )
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();

// adding knn to non-should bool clauses is not ok,
// we only allow knn as a should clause!
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );

// adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
// is not ok !
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );

// adding as a knn filter:
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( boolKnnInShould ) );
SearchPredicate inlineKnn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
SearchPredicate namedKnn = scope.predicate().named( "knn-named" )
.param( "k", 25 )
.param( "vector", new float[] { 50.0f, 50.0f } ).toPredicate();

for ( SearchPredicate knn : Arrays.asList( inlineKnn, namedKnn ) ) {
// adding knn to non-should bool clauses is not ok,
// we only allow knn as a should clause!
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );

knnPredicateInWrongPlace( () -> scope.predicate().bool()
// should is ok
.should( knn )
// not ok since we already have a knn in should we can only add more should clauses
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.mustNot( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool()
// so far so good:
.filter( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
// cannot add a knn through should as we already have a non-should clause
.should( knn ) );


// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
SearchPredicate inlineBoolKnnInShould = scope.predicate().bool().should( knn )
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
SearchPredicate namedBoolKnnInShould = scope.predicate().named( "bool-knn-in-should-named" )
.param( "knn", knn )
.toPredicate();

for ( SearchPredicate boolKnnInShould : Arrays.asList( inlineBoolKnnInShould, namedBoolKnnInShould ) ) {
// adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
// is not ok !
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );

// adding as a knn filter:
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
knnPredicateInWrongPlace(
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f )
.filter( boolKnnInShould ) );
}

}
}

@Test
Expand Down Expand Up @@ -267,6 +303,22 @@ private static class PredicateIndexBinding {
nestedRating =
nested.field( "nestedRating", f -> f.asInteger().projectable( Projectable.YES ).sortable( Sortable.YES ) )
.toReference();

root.namedPredicate( "knn-named", context -> {
int k = context.param( "k", Integer.class );
float[] vector = context.param( "vector", float[].class );

return context.predicate().knn( k ).field( "location" )
.matching( vector )
.toPredicate();
} );

root.namedPredicate( "bool-knn-in-should-named", context -> {
SearchPredicate knn = context.param( "knn", SearchPredicate.class );

return context.predicate().bool().should( knn )
.should( context.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
} );
}

}
Expand Down

0 comments on commit aeeb486

Please sign in to comment.