Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix script score queries not getting cached #1367

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.core.rest.RestStatus;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -107,7 +108,7 @@ private void validateKNNInnerProductScriptScoreSearch(String testIndex, String t
params.put(QUERY_VALUE, queryVector);
params.put(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue());

Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k);
Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap());
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,21 @@
package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.ScriptFactory;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.Map;

public class KNNScoreScriptFactory implements ScoreScript.LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
private String similaritySpace;
private String field;
private Object query;
private KNNScoringSpace knnScoringSpace;

private IndexSearcher searcher;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();
this.params = params;
this.lookup = lookup;
this.field = getValue(params, "field").toString();
this.similaritySpace = getValue(params, "space_type").toString();
this.query = getValue(params, "query_value");
this.searcher = searcher;

this.knnScoringSpace = KNNScoringSpaceFactory.create(
this.similaritySpace,
this.query,
lookup.doc().mapperService().fieldType(this.field)
);
}

private Object getValue(Map<String, Object> params, String fieldName) {
final Object value = params.get(fieldName);
if (value != null) return value;

KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Missing parameter [" + fieldName + "]");
}

public class KNNScoreScriptFactory implements ScoreScript.Factory, ScriptFactory {
@Override
public boolean needs_score() {
return false;
public boolean isResultDeterministic() {
// This implies the results are cacheable
return true;
}

/**
* For each segment, supply the KNNScoreScript that should be used to re-score the documents returned from the
* query. Because the method to score the documents was set during factory construction, the scripts are agnostic of
* the similarity space. The KNNScoringSpace will return the correct script, given the query, the field type, and
* the similarity space.
*
* @param ctx LeafReaderContext for the segment
* @return ScoreScript to be executed
*/
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher);
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return new KNNScoreScriptLeafFactory(params, lookup, indexSearcher);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.script;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.script.ScoreScript;
import org.opensearch.search.lookup.SearchLookup;

/*
* A factory that creates KNNScoreScriptLeafFactory objects. The factory is responsible for parsing the parameters
* passed in the query and creating the KNNScoreScriptLeafFactory object.
*/
public class KNNScoreScriptLeafFactory implements ScoreScript.LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
private final String similaritySpace;
private final String field;
private final Object query;
private final KNNScoringSpace knnScoringSpace;
private final IndexSearcher searcher;

public KNNScoreScriptLeafFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();
this.params = params;
this.lookup = lookup;
this.field = getValue(params, "field").toString();
this.similaritySpace = getValue(params, "space_type").toString();
this.query = getValue(params, "query_value");
this.searcher = searcher;

this.knnScoringSpace = KNNScoringSpaceFactory.create(
this.similaritySpace,
this.query,
lookup.doc().mapperService().fieldType(this.field)
);
}

private Object getValue(Map<String, Object> params, String fieldName) {
final Object value = params.get(fieldName);
if (value != null) return value;

KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing parameter [%s]", fieldName));
}

@Override
public boolean needs_score() {
return false;
}

/**
* For each segment, supply the KNNScoreScript that should be used to re-score the documents returned from the
* query. Because the method to score the documents was set during factory construction, the scripts are agnostic of
* the similarity space. The KNNScoringSpace will return the correct script, given the query, the field type, and
* the similarity space.
*
* @param ctx LeafReaderContext for the segment
* @return ScoreScript to be executed
*/
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public <FactoryType> FactoryType compile(String name, String code, ScriptContext
KNNCounter.SCRIPT_COMPILATION_ERRORS.increment();
throw new IllegalArgumentException("Unknown script name " + code);
}
ScoreScript.Factory factory = KNNScoreScriptFactory::new;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we need, since KNNScoreScriptFactory does not have a no-argument constructor that matches what's expected by the interface, we can't use the method reference KNNScoreScriptFactory::new here. And the IDE
will report error "Cannot resolve constructor 'KNNScoreScriptFactory'"

ScoreScript.Factory factory = new KNNScoreScriptFactory();
return context.factoryClazz.cast(factory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ public void testL2PainlessScriptingWithByteVectorDataType() throws Exception {
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
4
4,
Collections.emptyMap()
);

Response response = client().performRequest(request);
Expand All @@ -370,7 +371,8 @@ public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception {
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
4
4,
Collections.emptyMap()
);

Response response = client().performRequest(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -451,7 +453,7 @@ public void testHammingScriptScore_Long() throws Exception {
params1.put("field", FIELD_NAME);
params1.put("query_value", queryValue1);
params1.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4);
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4, Collections.emptyMap());
Response response1 = client().performRequest(request1);
assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -491,7 +493,7 @@ public void testHammingScriptScore_Long() throws Exception {
params2.put("field", FIELD_NAME);
params2.put("query_value", queryValue2);
params2.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4);
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap());
Response response2 = client().performRequest(request2);
assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -561,7 +563,7 @@ public void testHammingScriptScore_Base64() throws Exception {
params1.put("field", FIELD_NAME);
params1.put("query_value", queryValue1);
params1.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4);
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4, Collections.emptyMap());
Response response1 = client().performRequest(request1);
assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -601,7 +603,7 @@ public void testHammingScriptScore_Base64() throws Exception {
params2.put("field", FIELD_NAME);
params2.put("query_value", queryValue2);
params2.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4);
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap());
Response response2 = client().performRequest(request2);
assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -683,4 +685,110 @@ public void testKNNInnerProdScriptScore() throws Exception {
assertEquals("4", results.get(2).getDocId());
assertEquals("1", results.get(3).getDocId());
}

public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
Float[] f1 = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Float[] f2 = { 2.0f, 2.0f };
addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2);

Float[] f3 = { 4.0f, 4.0f };
addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3);

Float[] f4 = { 3.0f, 3.0f };
addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4);

/**
* Construct Search Request
*/
QueryBuilder qb = new MatchAllQueryBuilder();
Map<String, Object> scriptParams = new HashMap<>();
/*
* params": {
* "field": "my_dense_vector",
* "vector": [2.0, 2.0]
* }
*/
float[] queryVector = { 1.0f, 1.0f };
scriptParams.put("field", FIELD_NAME);
scriptParams.put("query_value", queryVector);
scriptParams.put("space_type", SpaceType.L2.getValue());
Map<String, Object> searchParams = new HashMap<>();
searchParams.put("request_cache", true);

// first request with request cache enabled
Request firstScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams);
Response firstScriptQueryResponse = client().performRequest(firstScriptQueryRequest);
assertEquals(
firstScriptQueryRequest.getEndpoint() + ": failed",
RestStatus.OK,
RestStatus.fromCode(firstScriptQueryResponse.getStatusLine().getStatusCode())
);

List<KNNResult> results = parseSearchResponse(EntityUtils.toString(firstScriptQueryResponse.getEntity()), FIELD_NAME);
List<String> expectedDocids = Arrays.asList("2", "4", "3", "1");

List<String> actualDocids = new ArrayList<>();
for (KNNResult result : results) {
actualDocids.add(result.getDocId());
}

assertEquals(4, results.size());
assertEquals(expectedDocids, actualDocids);

// assert that the request cache was hit missed at first request
Request firstStatsRequest = new Request("GET", "/" + INDEX_NAME + "/_stats");
Response firstStatsResponse = client().performRequest(firstStatsRequest);
assertEquals(
firstStatsRequest.getEndpoint() + ": failed",
RestStatus.OK,
RestStatus.fromCode(firstStatsResponse.getStatusLine().getStatusCode())
);
String firstStatsResponseBody = EntityUtils.toString(firstStatsResponse.getEntity());
Map<String, Object> firstQueryCacheMap = Optional.ofNullable(
createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), firstStatsResponseBody).map()
)
.map(r -> (Map<String, Object>) r.get("indices"))
.map(i -> (Map<String, Object>) i.get(INDEX_NAME))
.map(ind -> (Map<String, Object>) ind.get("total"))
.map(t -> (Map<String, Object>) t.get("request_cache"))
.orElseThrow(() -> new IllegalStateException("Query Cache Map not found"));
// assert that the request cache was hit missed at first request
assertEquals(1, firstQueryCacheMap.get("miss_count"));
assertEquals(0, firstQueryCacheMap.get("hit_count"));

// second request with request cache enabled
Request secondScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams);
Response secondScriptQueryResponse = client().performRequest(secondScriptQueryRequest);
assertEquals(
firstScriptQueryRequest.getEndpoint() + ": failed",
RestStatus.OK,
RestStatus.fromCode(secondScriptQueryResponse.getStatusLine().getStatusCode())
);

Request secondStatsRequest = new Request("GET", "/" + INDEX_NAME + "/_stats");
Response secondStatsResponse = client().performRequest(secondStatsRequest);
assertEquals(
secondStatsRequest.getEndpoint() + ": failed",
RestStatus.OK,
RestStatus.fromCode(secondStatsResponse.getStatusLine().getStatusCode())
);
String secondStatsResponseBody = EntityUtils.toString(secondStatsResponse.getEntity());
Map<String, Object> secondQueryCacheMap = Optional.ofNullable(
createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), secondStatsResponseBody).map()
)
.map(r -> (Map<String, Object>) r.get("indices"))
.map(i -> (Map<String, Object>) i.get(INDEX_NAME))
.map(ind -> (Map<String, Object>) ind.get("total"))
.map(t -> (Map<String, Object>) t.get("request_cache"))
.orElseThrow(() -> new IllegalStateException("Query Cache Map not found"));
assertEquals(1, secondQueryCacheMap.get("miss_count"));
// assert that the request cache was hit at second request
assertEquals(1, secondQueryCacheMap.get("hit_count"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,15 @@ public void testL2ScriptScoreFails() throws Exception {
private Request buildPainlessScoreScriptRequest(String source, int size, Map<String, Float[]> documents) throws Exception {
buildTestIndex(documents);
QueryBuilder qb = new MatchAllQueryBuilder();
return constructScriptScoreContextSearchRequest(INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size);
return constructScriptScoreContextSearchRequest(
INDEX_NAME,
qb,
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
size,
Collections.emptyMap()
);
}

private Request buildPainlessScoreScriptRequest(
Expand All @@ -166,7 +174,15 @@ private Request buildPainlessScoreScriptRequest(
) throws Exception {
buildTestIndex(documents, properties);
QueryBuilder qb = new MatchAllQueryBuilder();
return constructScriptScoreContextSearchRequest(INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size);
return constructScriptScoreContextSearchRequest(
INDEX_NAME,
qb,
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
size,
Collections.emptyMap()
);
}

private Request buildPainlessScriptedMetricRequest(
Expand Down
Loading
Loading