Skip to content

Commit

Permalink
in middle
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 6, 2025
1 parent a4a507e commit af0e58c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void testKNNCosineScriptScore() throws Exception {
float[] indexVector2 = { 8.1f, 9.1f, 0.3f };
float[] queryVector = { 3.0f, 4.0f, 5.5f };
if (isRunningAgainstOldCluster()) {
createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3));
addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1);
validateScore(1, queryVector, new float[] { VectorUtil.cosine(queryVector, indexVector1) });
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.knn.bwc;

import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.client.Response;
import org.opensearch.knn.index.SpaceType;

import java.util.List;

import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

public class ScriptScoringIT extends AbstractRollingUpgradeTestCase {
Expand Down Expand Up @@ -54,4 +59,60 @@ public void validateKNNL2ScriptScoreOnUpgrade(int totalDocsCount, int docId) thr
validateKNNScriptScoreSearch(testIndex, TEST_FIELD, DIMENSIONS, totalDocsCount, K, SpaceType.L2);
}

// KNN script scoring for space_type "cosine"
public void testKNNCosineScriptScore() throws Exception {
float[] indexVector1 = { 1.1f, 2.1f, 3.3f };
float[] indexVector2 = { 8.1f, 9.1f, 10.3f };
float[] indexVector3 = { 9.1f, 10.1f, 11.3f };
float[] indexVector4 = { 10.1f, 11.1f, 12.3f };
float[] queryVector = { 3.0f, 4.0f, 13.5f };
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
int k = 10;
switch (getClusterType()) {
case OLD:
createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3));
addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1);
validateScore(k, queryVector, List.of(VectorUtil.cosine(queryVector, indexVector1)));
break;
case MIXED:
if (isFirstMixedRound()) {
addKnnDoc(testIndex, "2", TEST_FIELD, indexVector2);
validateScore(
k,
queryVector,
List.of(VectorUtil.cosine(queryVector, indexVector1), VectorUtil.cosine(queryVector, indexVector2))
);
} else {
addKnnDoc(testIndex, "3", TEST_FIELD, indexVector3);
validateScore(
k,
queryVector,
List.of(
VectorUtil.cosine(queryVector, indexVector1),
VectorUtil.cosine(queryVector, indexVector2),
VectorUtil.cosine(queryVector, indexVector3)
)
);
}
break;
case UPGRADED:
addKnnDoc(testIndex, "3", TEST_FIELD, indexVector3);
validateScore(
k,
queryVector,
List.of(
VectorUtil.cosine(queryVector, indexVector1),
VectorUtil.cosine(queryVector, indexVector2),
VectorUtil.cosine(queryVector, indexVector3),
VectorUtil.cosine(queryVector, indexVector4)
)
);
}
}

private void validateScore(int k, float[] queryVector, List<Float> expectedScores) throws Exception {
final Response responseBody = executeKNNScriptScoreRequest(testIndex, TEST_FIELD, k, SpaceType.COSINESIMIL, queryVector);
List<Float> actualScores = parseSearchResponseScore(EntityUtils.toString(responseBody.getEntity()), TEST_FIELD);
assertEquals(expectedScores, actualScores);
}
}

0 comments on commit af0e58c

Please sign in to comment.