Skip to content

Commit aeeb486

Browse files
committed
HSEARCH-4950 Update knn restrictions for Elasticsearch
1 parent a159601 commit aeeb486

File tree

12 files changed

+128
-59
lines changed

12 files changed

+128
-59
lines changed

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchNestablePredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public abstract class AbstractElasticsearchNestablePredicate extends AbstractEla
2525

2626

2727
@Override
28-
public void doCheckNestableWithin(PredicateNestingContext context) {
28+
public void checkNestableWithin(PredicateNestingContext context) {
2929
List<String> nestedPathHierarchy = getNestedPathHierarchy();
3030
String expectedParentNestedPath = context.getNestedPath();
3131

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/AbstractElasticsearchPredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import com.google.gson.JsonObject;
1616

17-
public abstract class AbstractElasticsearchPredicate extends ElasticsearchSearchPredicate {
17+
public abstract class AbstractElasticsearchPredicate implements ElasticsearchSearchPredicate {
1818

1919
private static final JsonAccessor<Float> BOOST_ACCESSOR = JsonAccessor.root().property( "boost" ).asFloat();
2020

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchBooleanPredicate.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private ElasticsearchBooleanPredicate(Builder builder) {
6464
}
6565

6666
@Override
67-
public void doCheckNestableWithin(PredicateNestingContext context) {
67+
public void checkNestableWithin(PredicateNestingContext context) {
6868
checkAcceptableWithin( context, mustClauses );
6969
checkAcceptableWithin( context, shouldClauses );
7070
checkAcceptableWithin( context, filterClauses );
@@ -217,6 +217,8 @@ public void must(SearchPredicate clause) {
217217
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
218218
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
219219
mustClauses.add( elasticsearchClause );
220+
221+
checkShouldClauses();
220222
}
221223

222224
@Override
@@ -227,18 +229,32 @@ public void mustNot(SearchPredicate clause) {
227229
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
228230
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
229231
mustNotClauses.add( elasticsearchClause );
232+
233+
checkShouldClauses();
230234
}
231235

232236
@Override
233237
public void should(SearchPredicate clause) {
234238
if ( shouldClauses == null ) {
235239
shouldClauses = new ArrayList<>();
236240
}
241+
237242
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
238-
elasticsearchClause.checkNestableWithin( PredicateNestingContext.acceptsKnn() );
243+
elasticsearchClause.checkNestableWithin(
244+
!hasNonShouldClause() && maybeKnnClause( clause )
245+
? PredicateNestingContext.acceptsKnn()
246+
: PredicateNestingContext.doesNotAcceptKnn() );
239247
shouldClauses.add( elasticsearchClause );
240248
}
241249

250+
private boolean maybeKnnClause(SearchPredicate clause) {
251+
return clause instanceof ElasticsearchKnnPredicate
252+
|| ( clause instanceof ElasticsearchNamedPredicate
253+
// TODO:
254+
// && clause.providedPredicate instanceof ElasticsearchKnnPredicate
255+
);
256+
}
257+
242258
@Override
243259
public void filter(SearchPredicate clause) {
244260
if ( filterClauses == null ) {
@@ -247,6 +263,8 @@ public void filter(SearchPredicate clause) {
247263
ElasticsearchSearchPredicate elasticsearchClause = ElasticsearchSearchPredicate.from( scope, clause );
248264
elasticsearchClause.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() );
249265
filterClauses.add( elasticsearchClause );
266+
267+
checkShouldClauses();
250268
}
251269

252270
@Override
@@ -335,6 +353,20 @@ private void optimizeClauseCollection(List<ElasticsearchSearchPredicate> collect
335353
}
336354
}
337355

356+
/*
357+
* For Elasticsearch backend only:
358+
* It may be that we've added knn should clauses and all was fine.
359+
* But now we are adding a non-should clause to such a bool predicate.
360+
* In such a case, we want to make sure that there are no knn clauses among the should clauses, and fail otherwise.
361+
* OpenSearch backend will not be affected by these checks.
362+
*/
363+
private void checkShouldClauses() {
364+
if ( shouldClauses != null ) {
365+
shouldClauses.forEach(
366+
should -> should.checkNestableWithin( PredicateNestingContext.doesNotAcceptKnn() ) );
367+
}
368+
}
369+
338370
private void checkAndClearClauseCollections() {
339371
if ( mustClauses != null && mustClauses.isEmpty() ) {
340372
mustClauses = null;
@@ -348,6 +380,10 @@ private boolean hasAtLeastOneMustOrFilterPredicate() {
348380
return mustClauses != null || filterClauses != null;
349381
}
350382

383+
private boolean hasNonShouldClause() {
384+
return mustClauses != null || filterClauses != null || mustNotClauses != null;
385+
}
386+
351387
private boolean hasOnlyOneMustClause() {
352388
return mustClauses != null
353389
&& mustClauses.size() == 1

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchKnnPredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
183183
}
184184

185185
@Override
186-
public void doCheckNestableWithin(PredicateNestingContext context) {
186+
public void checkNestableWithin(PredicateNestingContext context) {
187187
if ( context.getNestedPath() != null || !context.acceptsKnnClause() ) {
188188
throw log.cannotAddKnnClauseAtThisStep();
189189
}

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchAllPredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ private ElasticsearchMatchAllPredicate(Builder builder) {
2323
}
2424

2525
@Override
26-
public void doCheckNestableWithin(PredicateNestingContext context) {
26+
public void checkNestableWithin(PredicateNestingContext context) {
2727
// Nothing to do
2828
}
2929

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchIdPredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private ElasticsearchMatchIdPredicate(Builder builder) {
5959
}
6060

6161
@Override
62-
public void doCheckNestableWithin(PredicateNestingContext context) {
62+
public void checkNestableWithin(PredicateNestingContext context) {
6363
// Nothing to do
6464
}
6565

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchMatchNonePredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ElasticsearchMatchNonePredicate extends AbstractElasticsearchPredicate {
2424
}
2525

2626
@Override
27-
public void doCheckNestableWithin(PredicateNestingContext context) {
27+
public void checkNestableWithin(PredicateNestingContext context) {
2828
// Nothing to do
2929
}
3030

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchNamedPredicate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ private ElasticsearchNamedPredicate(Builder builder, ElasticsearchSearchPredicat
3838
}
3939

4040
@Override
41-
public void doCheckNestableWithin(PredicateNestingContext context) {
41+
public void checkNestableWithin(PredicateNestingContext context) {
4242
providedPredicate.checkNestableWithin( context );
4343
super.checkNestableWithin( context );
4444
}

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchSearchPredicate.java

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,17 @@
1616

1717
import com.google.gson.JsonObject;
1818

19-
public abstract class ElasticsearchSearchPredicate implements SearchPredicate {
19+
public interface ElasticsearchSearchPredicate extends SearchPredicate {
2020

21-
private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );
21+
Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );
2222

23-
public abstract Set<String> indexNames();
23+
Set<String> indexNames();
2424

25-
public final void checkNestableWithin(PredicateNestingContext context) {
26-
doCheckNestableWithin( context.wrap( this ) );
27-
}
28-
29-
protected abstract void doCheckNestableWithin(PredicateNestingContext context);
25+
void checkNestableWithin(PredicateNestingContext context);
3026

31-
public abstract JsonObject toJsonQuery(PredicateRequestContext context);
27+
JsonObject toJsonQuery(PredicateRequestContext context);
3228

33-
public static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
29+
static ElasticsearchSearchPredicate from(ElasticsearchSearchIndexScope<?> scope, SearchPredicate predicate) {
3430
if ( !( predicate instanceof ElasticsearchSearchPredicate ) ) {
3531
throw log.cannotMixElasticsearchSearchQueryWithOtherPredicates( predicate );
3632
}

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/ElasticsearchUserProvidedJsonPredicate.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import com.google.gson.JsonObject;
1414

15-
class ElasticsearchUserProvidedJsonPredicate extends ElasticsearchSearchPredicate {
15+
class ElasticsearchUserProvidedJsonPredicate implements ElasticsearchSearchPredicate {
1616

1717
private final Set<String> indexNames;
1818
private final JsonObject json;
@@ -29,7 +29,7 @@ public Set<String> indexNames() {
2929
}
3030

3131
@Override
32-
public void doCheckNestableWithin(PredicateNestingContext context) {
32+
public void checkNestableWithin(PredicateNestingContext context) {
3333
// Nothing to do: we'll assume the user knows what they are doing.
3434
}
3535

backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/predicate/impl/PredicateNestingContext.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ public class PredicateNestingContext {
1212
private final String nestedPath;
1313
private final boolean acceptsKnnClause;
1414

15-
private final Class<? extends ElasticsearchSearchPredicate> predicateType;
16-
1715
public static PredicateNestingContext acceptsKnn() {
1816
return ACCEPTS_KNN;
1917
}
@@ -27,14 +25,8 @@ public static PredicateNestingContext nested(String nestedPath) {
2725
}
2826

2927
private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause) {
30-
this( nestedPath, acceptsKnnClause, null );
31-
}
32-
33-
private PredicateNestingContext(String nestedPath, boolean acceptsKnnClause,
34-
Class<? extends ElasticsearchSearchPredicate> predicateType) {
3528
this.nestedPath = nestedPath;
3629
this.acceptsKnnClause = acceptsKnnClause;
37-
this.predicateType = predicateType;
3830
}
3931

4032
private PredicateNestingContext(String nestedPath) {
@@ -53,11 +45,4 @@ public boolean acceptsKnnClause() {
5345
return acceptsKnnClause;
5446
}
5547

56-
public PredicateNestingContext wrap(ElasticsearchSearchPredicate elasticsearchSearchPredicate) {
57-
return new PredicateNestingContext(
58-
nestedPath,
59-
acceptsKnnClause && this.predicateType == null,
60-
elasticsearchSearchPredicate.getClass()
61-
);
62-
}
6348
}

integrationtest/backend/elasticsearch/src/test/java/org/hibernate/search/integrationtest/backend/elasticsearch/search/ElasticsearchKnnPredicateSpecificsIT.java

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.junit.jupiter.api.Assumptions.assumeFalse;
1212
import static org.junit.jupiter.api.Assumptions.assumeTrue;
1313

14+
import java.util.Arrays;
1415
import java.util.List;
1516
import java.util.Map;
1617
import java.util.function.Consumer;
@@ -183,29 +184,64 @@ void knnPredicateInWrongPlace_addingPrebuiltKnn() {
183184
StubMappingScope scope = index.createScope();
184185

185186
// all good:
186-
SearchPredicate knn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
187-
// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
188-
SearchPredicate boolKnnInShould = scope.predicate().bool().should( knn )
189-
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
190-
191-
// adding knn to non-should bool clauses is not ok,
192-
// we only allow knn as a should clause!
193-
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
194-
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
195-
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );
196-
197-
// adding boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
198-
// is not ok !
199-
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
200-
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
201-
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
202-
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );
203-
204-
// adding as a knn filter:
205-
knnPredicateInWrongPlace(
206-
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
207-
knnPredicateInWrongPlace(
208-
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( boolKnnInShould ) );
187+
SearchPredicate inlineKnn = scope.predicate().knn( 15 ).field( "location" ).matching( 50.0f, 50.0f ).toPredicate();
188+
SearchPredicate namedKnn = scope.predicate().named( "knn-named" )
189+
.param( "k", 25 )
190+
.param( "vector", new float[] { 50.0f, 50.0f } ).toPredicate();
191+
192+
for ( SearchPredicate knn : Arrays.asList( inlineKnn, namedKnn ) ) {
193+
// adding knn to non-should bool clauses is not ok,
194+
// we only allow knn as a should clause!
195+
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( knn ) );
196+
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( knn ) );
197+
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( knn ) );
198+
199+
knnPredicateInWrongPlace( () -> scope.predicate().bool()
200+
// should is ok
201+
.should( knn )
202+
// not ok: since we already have a knn in should, we can only add more should clauses
203+
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ) );
204+
knnPredicateInWrongPlace( () -> scope.predicate().bool()
205+
// so far so good:
206+
.must( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
207+
// cannot add a knn through should as we already have a non-should clause
208+
.should( knn ) );
209+
knnPredicateInWrongPlace( () -> scope.predicate().bool()
210+
// so far so good:
211+
.mustNot( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
212+
// cannot add a knn through should as we already have a non-should clause
213+
.should( knn ) );
214+
knnPredicateInWrongPlace( () -> scope.predicate().bool()
215+
// so far so good:
216+
.filter( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) )
217+
// cannot add a knn through should as we already have a non-should clause
218+
.should( knn ) );
219+
220+
221+
// we add multiple clauses to prevent "optimizations" leading to bool predicate being replaced by a simple knn predicate
222+
SearchPredicate inlineBoolKnnInShould = scope.predicate().bool().should( knn )
223+
.should( scope.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
224+
SearchPredicate namedBoolKnnInShould = scope.predicate().named( "bool-knn-in-should-named" )
225+
.param( "knn", knn )
226+
.toPredicate();
227+
228+
for ( SearchPredicate boolKnnInShould : Arrays.asList( inlineBoolKnnInShould, namedBoolKnnInShould ) ) {
229+
// adding a boolean predicate with a should knn clause as any boolean clause (nesting a correct bool into another one)
230+
// is not ok !
231+
knnPredicateInWrongPlace( () -> scope.predicate().bool().should( boolKnnInShould ) );
232+
knnPredicateInWrongPlace( () -> scope.predicate().bool().must( boolKnnInShould ) );
233+
knnPredicateInWrongPlace( () -> scope.predicate().bool().mustNot( boolKnnInShould ) );
234+
knnPredicateInWrongPlace( () -> scope.predicate().bool().filter( boolKnnInShould ) );
235+
236+
// adding as a knn filter:
237+
knnPredicateInWrongPlace(
238+
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f ).filter( knn ) );
239+
knnPredicateInWrongPlace(
240+
() -> scope.predicate().knn( 10 ).field( "location" ).matching( 50.0f, 50.0f )
241+
.filter( boolKnnInShould ) );
242+
}
243+
244+
}
209245
}
210246

211247
@Test
@@ -267,6 +303,22 @@ private static class PredicateIndexBinding {
267303
nestedRating =
268304
nested.field( "nestedRating", f -> f.asInteger().projectable( Projectable.YES ).sortable( Sortable.YES ) )
269305
.toReference();
306+
307+
root.namedPredicate( "knn-named", context -> {
308+
int k = context.param( "k", Integer.class );
309+
float[] vector = context.param( "vector", float[].class );
310+
311+
return context.predicate().knn( k ).field( "location" )
312+
.matching( vector )
313+
.toPredicate();
314+
} );
315+
316+
root.namedPredicate( "bool-knn-in-should-named", context -> {
317+
SearchPredicate knn = context.param( "knn", SearchPredicate.class );
318+
319+
return context.predicate().bool().should( knn )
320+
.should( context.predicate().match().field( "parking" ).matching( Boolean.TRUE ) ).toPredicate();
321+
} );
270322
}
271323

272324
}

0 commit comments

Comments
 (0)