diff --git a/CHANGELOG.md b/CHANGELOG.md index 7839d9825..b2ab6b446 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305) * Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318) * Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317) +* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) * Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java index 2ee1a32f9..c4fd05c03 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.core.rest.RestStatus; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -107,7 +108,7 @@ private void validateKNNInnerProductScriptScoreSearch(String testIndex, String t params.put(QUERY_VALUE, queryVector); params.put(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()); - Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k); + Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": 
failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java index 63b367b2d..4f6a1a6c4 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java @@ -6,64 +6,21 @@ package org.opensearch.knn.plugin.script; import org.apache.lucene.search.IndexSearcher; -import org.opensearch.knn.plugin.stats.KNNCounter; -import org.apache.lucene.index.LeafReaderContext; import org.opensearch.script.ScoreScript; +import org.opensearch.script.ScriptFactory; import org.opensearch.search.lookup.SearchLookup; -import java.io.IOException; import java.util.Map; -public class KNNScoreScriptFactory implements ScoreScript.LeafFactory { - private final Map params; - private final SearchLookup lookup; - private String similaritySpace; - private String field; - private Object query; - private KNNScoringSpace knnScoringSpace; - - private IndexSearcher searcher; - - public KNNScoreScriptFactory(Map params, SearchLookup lookup, IndexSearcher searcher) { - KNNCounter.SCRIPT_QUERY_REQUESTS.increment(); - this.params = params; - this.lookup = lookup; - this.field = getValue(params, "field").toString(); - this.similaritySpace = getValue(params, "space_type").toString(); - this.query = getValue(params, "query_value"); - this.searcher = searcher; - - this.knnScoringSpace = KNNScoringSpaceFactory.create( - this.similaritySpace, - this.query, - lookup.doc().mapperService().fieldType(this.field) - ); - } - - private Object getValue(Map params, String fieldName) { - final Object value = params.get(fieldName); - if (value != null) return value; - - KNNCounter.SCRIPT_QUERY_ERRORS.increment(); - throw new IllegalArgumentException("Missing parameter [" + fieldName + "]"); - } - +public class KNNScoreScriptFactory 
implements ScoreScript.Factory, ScriptFactory { @Override - public boolean needs_score() { - return false; + public boolean isResultDeterministic() { + // This implies the results are cacheable + return true; } - /** - * For each segment, supply the KNNScoreScript that should be used to re-score the documents returned from the - * query. Because the method to score the documents was set during factory construction, the scripts are agnostic of - * the similarity space. The KNNScoringSpace will return the correct script, given the query, the field type, and - * the similarity space. - * - * @param ctx LeafReaderContext for the segment - * @return ScoreScript to be executed - */ @Override - public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher); + public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup, IndexSearcher indexSearcher) { + return new KNNScoreScriptLeafFactory(params, lookup, indexSearcher); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptLeafFactory.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptLeafFactory.java new file mode 100644 index 000000000..1caca0d4b --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptLeafFactory.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.script; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.script.ScoreScript; +import org.opensearch.search.lookup.SearchLookup; + +/* + * A leaf factory that creates ScoreScript objects, one per segment. 
It is responsible for parsing the parameters + * passed in the query and creating the KNNScoringSpace used to score documents. + */ +public class KNNScoreScriptLeafFactory implements ScoreScript.LeafFactory { + private final Map params; + private final SearchLookup lookup; + private final String similaritySpace; + private final String field; + private final Object query; + private final KNNScoringSpace knnScoringSpace; + private final IndexSearcher searcher; + + public KNNScoreScriptLeafFactory(Map params, SearchLookup lookup, IndexSearcher searcher) { + KNNCounter.SCRIPT_QUERY_REQUESTS.increment(); + this.params = params; + this.lookup = lookup; + this.field = getValue(params, "field").toString(); + this.similaritySpace = getValue(params, "space_type").toString(); + this.query = getValue(params, "query_value"); + this.searcher = searcher; + + this.knnScoringSpace = KNNScoringSpaceFactory.create( + this.similaritySpace, + this.query, + lookup.doc().mapperService().fieldType(this.field) + ); + } + + private Object getValue(Map params, String fieldName) { + final Object value = params.get(fieldName); + if (value != null) return value; + + KNNCounter.SCRIPT_QUERY_ERRORS.increment(); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing parameter [%s]", fieldName)); + } + + @Override + public boolean needs_score() { + return false; + } + + /** + * For each segment, supply the KNNScoreScript that should be used to re-score the documents returned from the + * query. Because the method to score the documents was set during factory construction, the scripts are agnostic of + * the similarity space. The KNNScoringSpace will return the correct script, given the query, the field type, and + * the similarity space. 
+ * + * @param ctx LeafReaderContext for the segment + * @return ScoreScript to be executed + */ + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java index 42e0e90ec..61b26c760 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java @@ -39,7 +39,7 @@ public FactoryType compile(String name, String code, ScriptContext KNNCounter.SCRIPT_COMPILATION_ERRORS.increment(); throw new IllegalArgumentException("Unknown script name " + code); } - ScoreScript.Factory factory = KNNScoreScriptFactory::new; + ScoreScript.Factory factory = new KNNScoreScriptFactory(); return context.factoryClazz.cast(factory); } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 2c4f150c4..2d3f53580 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -350,7 +350,8 @@ public void testL2PainlessScriptingWithByteVectorDataType() throws Exception { Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, - 4 + 4, + Collections.emptyMap() ); Response response = client().performRequest(request); @@ -370,7 +371,8 @@ public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception { Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, - 4 + 4, + Collections.emptyMap() ); Response response = client().performRequest(request); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 1ee7875c6..59c4f8c0e 
100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -25,9 +25,11 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -451,7 +453,7 @@ public void testHammingScriptScore_Long() throws Exception { params1.put("field", FIELD_NAME); params1.put("query_value", queryValue1); params1.put("space_type", SpaceType.HAMMING_BIT.getValue()); - Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4); + Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4, Collections.emptyMap()); Response response1 = client().performRequest(request1); assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode())); @@ -491,7 +493,7 @@ public void testHammingScriptScore_Long() throws Exception { params2.put("field", FIELD_NAME); params2.put("query_value", queryValue2); params2.put("space_type", SpaceType.HAMMING_BIT.getValue()); - Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4); + Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap()); Response response2 = client().performRequest(request2); assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode())); @@ -561,7 +563,7 @@ public void testHammingScriptScore_Base64() throws Exception { params1.put("field", FIELD_NAME); params1.put("query_value", queryValue1); params1.put("space_type", SpaceType.HAMMING_BIT.getValue()); - Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4); + Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, 
qb1, params1, 4, Collections.emptyMap()); Response response1 = client().performRequest(request1); assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode())); @@ -601,7 +603,7 @@ public void testHammingScriptScore_Base64() throws Exception { params2.put("field", FIELD_NAME); params2.put("query_value", queryValue2); params2.put("space_type", SpaceType.HAMMING_BIT.getValue()); - Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4); + Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap()); Response response2 = client().performRequest(request2); assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode())); @@ -683,4 +685,110 @@ public void testKNNInnerProdScriptScore() throws Exception { assertEquals("4", results.get(2).getDocId()); assertEquals("1", results.get(3).getDocId()); } + + public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { + /* + * Create knn index and populate data + */ + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + Float[] f1 = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Float[] f2 = { 2.0f, 2.0f }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + Float[] f3 = { 4.0f, 4.0f }; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); + + Float[] f4 = { 3.0f, 3.0f }; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map scriptParams = new HashMap<>(); + /* + * "params": { + * "field": FIELD_NAME, "space_type": SpaceType.L2, + * "query_value": [1.0, 1.0] + * } + */ + float[] queryVector = { 1.0f, 1.0f }; + scriptParams.put("field", FIELD_NAME); + scriptParams.put("query_value", queryVector); + scriptParams.put("space_type", SpaceType.L2.getValue()); + Map searchParams = new HashMap<>(); + searchParams.put("request_cache", true); + + // 
first request with request cache enabled + Request firstScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams); + Response firstScriptQueryResponse = client().performRequest(firstScriptQueryRequest); + assertEquals( + firstScriptQueryRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(firstScriptQueryResponse.getStatusLine().getStatusCode()) + ); + + List results = parseSearchResponse(EntityUtils.toString(firstScriptQueryResponse.getEntity()), FIELD_NAME); + List expectedDocids = Arrays.asList("2", "4", "3", "1"); + + List actualDocids = new ArrayList<>(); + for (KNNResult result : results) { + actualDocids.add(result.getDocId()); + } + + assertEquals(4, results.size()); + assertEquals(expectedDocids, actualDocids); + + // fetch the index stats to inspect the request cache counters after the first request + Request firstStatsRequest = new Request("GET", "/" + INDEX_NAME + "/_stats"); + Response firstStatsResponse = client().performRequest(firstStatsRequest); + assertEquals( + firstStatsRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(firstStatsResponse.getStatusLine().getStatusCode()) + ); + String firstStatsResponseBody = EntityUtils.toString(firstStatsResponse.getEntity()); + Map firstQueryCacheMap = Optional.ofNullable( + createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), firstStatsResponseBody).map() + ) + .map(r -> (Map) r.get("indices")) + .map(i -> (Map) i.get(INDEX_NAME)) + .map(ind -> (Map) ind.get("total")) + .map(t -> (Map) t.get("request_cache")) + .orElseThrow(() -> new IllegalStateException("Query Cache Map not found")); + // assert that the first request was a request cache miss and not a hit + assertEquals(1, firstQueryCacheMap.get("miss_count")); + assertEquals(0, firstQueryCacheMap.get("hit_count")); + + // second request with request cache enabled + Request secondScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams); + 
Response secondScriptQueryResponse = client().performRequest(secondScriptQueryRequest); + assertEquals( + secondScriptQueryRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(secondScriptQueryResponse.getStatusLine().getStatusCode()) + ); + + Request secondStatsRequest = new Request("GET", "/" + INDEX_NAME + "/_stats"); + Response secondStatsResponse = client().performRequest(secondStatsRequest); + assertEquals( + secondStatsRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(secondStatsResponse.getStatusLine().getStatusCode()) + ); + String secondStatsResponseBody = EntityUtils.toString(secondStatsResponse.getEntity()); + Map secondQueryCacheMap = Optional.ofNullable( + createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), secondStatsResponseBody).map() + ) + .map(r -> (Map) r.get("indices")) + .map(i -> (Map) i.get(INDEX_NAME)) + .map(ind -> (Map) ind.get("total")) + .map(t -> (Map) t.get("request_cache")) + .orElseThrow(() -> new IllegalStateException("Query Cache Map not found")); + assertEquals(1, secondQueryCacheMap.get("miss_count")); + // assert that the request cache was hit at second request + assertEquals(1, secondQueryCacheMap.get("hit_count")); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index e2fd536f1..5fa88b0a5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -155,7 +155,15 @@ public void testL2ScriptScoreFails() throws Exception { private Request buildPainlessScoreScriptRequest(String source, int size, Map documents) throws Exception { buildTestIndex(documents); QueryBuilder qb = new MatchAllQueryBuilder(); - return constructScriptScoreContextSearchRequest(INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size); + return 
constructScriptScoreContextSearchRequest( + INDEX_NAME, + qb, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + size, + Collections.emptyMap() + ); } private Request buildPainlessScoreScriptRequest( @@ -166,7 +174,15 @@ private Request buildPainlessScoreScriptRequest( ) throws Exception { buildTestIndex(documents, properties); QueryBuilder qb = new MatchAllQueryBuilder(); - return constructScriptScoreContextSearchRequest(INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size); + return constructScriptScoreContextSearchRequest( + INDEX_NAME, + qb, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + size, + Collections.emptyMap() + ); } private Request buildPainlessScriptedMetricRequest( diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 56e63d53c..78cd84f05 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -11,6 +11,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.net.URIBuilder; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.xcontent.DeprecationHandler; @@ -60,6 +61,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.PriorityQueue; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -837,12 +839,13 @@ protected Request constructScriptedMetricAggregationSearchRequest( protected Request constructScriptScoreContextSearchRequest( String indexName, QueryBuilder qb, - Map params, + Map scriptParams, String language, String source, - int size + int size, + Map searchParams ) throws Exception { - Script script = buildScript(source, language, 
params); + Script script = buildScript(source, language, scriptParams); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query"); builder.startObject("script_score"); @@ -852,7 +855,13 @@ protected Request constructScriptScoreContextSearchRequest( builder.endObject(); builder.endObject(); builder.endObject(); - Request request = new Request("POST", "/" + indexName + "/_search"); + URIBuilder uriBuilder = new URIBuilder("/" + indexName + "/_search"); + if (Objects.nonNull(searchParams)) { + for (Map.Entry entry : searchParams.entrySet()) { + uriBuilder.addParameter(entry.getKey(), entry.getValue().toString()); + } + } + Request request = new Request("POST", uriBuilder.toString()); request.setJsonEntity(builder.toString()); return request; } @@ -873,15 +882,21 @@ protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder return request; } - protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, int size) - throws Exception { + protected Request constructKNNScriptQueryRequest( + String indexName, + QueryBuilder qb, + Map scriptParams, + int size, + Map searchParams + ) throws Exception { return constructScriptScoreContextSearchRequest( indexName, qb, - params, + scriptParams, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, - size + size, + searchParams ); } @@ -1132,7 +1147,7 @@ protected void validateKNNScriptScoreSearch(String testIndex, String testField, params.put(QUERY_VALUE, queryVector); params.put(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()); - Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k); + Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, 
RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -1158,7 +1173,8 @@ protected void validateKNNPainlessScriptScoreSearch(String testIndex, String tes Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, - k + k, + Collections.emptyMap() ); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));