Skip to content

Commit

Permalink
Add script_fields context support
Browse files Browse the repository at this point in the history
Include script_fields context to existing
supported context for knn methods.
Added test cases for method and doc values.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Aug 1, 2024
1 parent 27f3168 commit c9f1f11
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.painless.spi.PainlessExtension;
import org.opensearch.painless.spi.Allowlist;
import org.opensearch.painless.spi.AllowlistLoader;
import org.opensearch.script.FieldScript;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptedMetricAggContexts;
Expand All @@ -33,6 +34,8 @@ public Map<ScriptContext<?>, List<Allowlist>> getContextAllowlists() {
ScriptedMetricAggContexts.CombineScript.CONTEXT,
allowLists,
ScriptedMetricAggContexts.ReduceScript.CONTEXT,
allowLists,
FieldScript.CONTEXT,
allowLists
);
}
Expand Down
69 changes: 60 additions & 9 deletions src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@
package org.opensearch.knn.integ;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -16,16 +25,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.script.Script;

import java.io.IOException;
Expand Down Expand Up @@ -161,6 +161,43 @@ public void testL2ScriptScoreFails() throws Exception {
deleteKNNIndex(INDEX_NAME);
}

public void testCosineSimilarityScriptFields() throws Exception {
String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME);
String scriptFieldName = "similarity";
Request request = buildPainlessScriptFieldsRequest(source, 3, getCosineTestData(), scriptFieldName);
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

List<KNNResult> results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName);
assertEquals(3, results.size());

String[] expectedDocIDs = { "0", "1", "2" };
for (int i = 0; i < results.size(); i++) {
assertEquals(expectedDocIDs[i], results.get(i).getDocId());
}
deleteKNNIndex(INDEX_NAME);
}

public void testScriptFieldsGetValueReturnsDocValues() throws Exception {

String source = String.format("doc['%s'].value[0]", FIELD_NAME);
String scriptFieldName = "doc_value_field";
Map<String, Float[]> testData = getKnnVectorTestData();
Request request = buildPainlessScriptFieldsRequest(source, testData.size(), testData, scriptFieldName);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

List<KNNResult> results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName);
assertEquals(testData.size(), results.size());

String[] expectedDocIDs = { "1", "2", "3", "4" };
for (int i = 0; i < results.size(); i++) {
assertEquals(expectedDocIDs[i], results.get(i).getDocId());
}
deleteKNNIndex(INDEX_NAME);
}

private Request buildPainlessScoreScriptRequest(String source, int size, Map<String, Float[]> documents) throws Exception {
buildTestIndex(documents);
QueryBuilder qb = new MatchAllQueryBuilder();
Expand All @@ -175,6 +212,20 @@ private Request buildPainlessScoreScriptRequest(String source, int size, Map<Str
);
}

private Request buildPainlessScriptFieldsRequest(String source, int size, Map<String, Float[]> documents, String scriptFieldName)
throws Exception {
buildTestIndex(documents);
return constructScriptFieldsContextSearchRequest(
INDEX_NAME,
scriptFieldName,
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
size,
Collections.emptyMap()
);
}

private Request buildPainlessScoreScriptRequest(
String source,
int size,
Expand Down
56 changes: 56 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,31 @@ protected List<Float> parseSearchResponseScore(String responseBody, String field
return knnSearchResponses;
}

protected List<KNNResult> parseSearchResponseScriptFields(String responseBody, String scriptFieldName) throws IOException {
@SuppressWarnings("unchecked")
List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
MediaTypeRegistry.getDefaultMediaType().xContent(),
responseBody
).map().get("hits")).get("hits");

@SuppressWarnings("unchecked")
List<KNNResult> knnSearchResponses = hits.stream().map(hit -> {
@SuppressWarnings("unchecked")
final float[] vector = Floats.toArray(
Arrays.stream(
((ArrayList<Float>) ((Map<String, Object>) ((Map<String, Object>) hit).get("fields")).get(scriptFieldName)).toArray()
).map(Object::toString).map(Float::valueOf).collect(Collectors.toList())
);
return new KNNResult(
(String) ((Map<String, Object>) hit).get("_id"),
vector,
((Double) ((Map<String, Object>) hit).get("_score")).floatValue()
);
}).collect(Collectors.toList());

return knnSearchResponses;
}

/**
* Parse the response of Aggregation to extract the value
*/
Expand Down Expand Up @@ -1002,6 +1027,37 @@ protected Request constructScriptedMetricAggregationSearchRequest(
return request;
}

protected Request constructScriptFieldsContextSearchRequest(
String indexName,
String fieldName,
Map<String, Object> scriptParams,
String language,
String source,
int size,
Map<String, Object> searchParams
) throws Exception {
Script script = buildScript(source, language, scriptParams);
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query");
builder.startObject("match_all");
builder.endObject();
builder.endObject();
builder.startObject("script_fields");
builder.startObject(fieldName);
builder.field("script", script);
builder.endObject();
builder.endObject();
builder.endObject();
URIBuilder uriBuilder = new URIBuilder("/" + indexName + "/_search");
if (Objects.nonNull(searchParams)) {
for (Map.Entry<String, Object> entry : searchParams.entrySet()) {
uriBuilder.addParameter(entry.getKey(), entry.getValue().toString());
}
}
Request request = new Request("POST", uriBuilder.toString());
request.setJsonEntity(builder.toString());
return request;
}

protected Request constructScriptScoreContextSearchRequest(
String indexName,
QueryBuilder qb,
Expand Down

0 comments on commit c9f1f11

Please sign in to comment.