From 0df5f62dda15c8c57b66ab6515cf191a62b28ab0 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 18 Feb 2025 08:49:26 -0800 Subject: [PATCH] Fix derived source for binary and byte vectors For binary and byte vectors, for derived source, we were not formatting them before adding them back to the source. Thus, they were binary strings in the source. This change fixes this formatting to format them as ints before adding back. Signed-off-by: John Mazanec --- CHANGELOG.md | 1 + ...AbstractPerFieldDerivedVectorInjector.java | 39 ++ .../NestedPerFieldDerivedVectorInjector.java | 6 +- .../RootPerFieldDerivedVectorInjector.java | 4 +- .../mapper/KNNVectorFieldMapperUtil.java | 6 +- .../mapper/KNNVectorFieldMapperUtilTests.java | 11 + .../opensearch/knn/integ/DerivedSourceIT.java | 349 +++++++++++++++--- .../org/opensearch/knn/KNNRestTestCase.java | 49 ++- .../java/org/opensearch/knn/TestUtils.java | 18 + 9 files changed, 410 insertions(+), 73 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/derivedsource/AbstractPerFieldDerivedVectorInjector.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 97e8142671..224876bcab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +* Fix derived source for binary and byte vectors [#2533](https://github.com/opensearch-project/k-NN/pull/2533/) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/AbstractPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/AbstractPerFieldDerivedVectorInjector.java new file mode 100644 index 0000000000..bba5e14c19 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/AbstractPerFieldDerivedVectorInjector.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.common.FieldInfoExtractor; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; + +@Log4j2 +abstract class AbstractPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { + /** + * Utility method for formatting the vector values based on the vector data type. KNNVectorValues must be advanced + * to the correct position. + * + * @param fieldInfo fieldinfo for the vector field + * @param vectorValues vector values of the field. getVector or getConditionalVector should return expected vector. + * @return vector formatted based on the vector data type + * @throws IOException if unable to deserialize stored vector + */ + protected Object formatVector(FieldInfo fieldInfo, KNNVectorValues vectorValues) throws IOException { + Object vectorValue = vectorValues.getVector(); + // If the vector value is a byte[], we must deserialize + if (vectorValue instanceof byte[]) { + BytesRef vectorBytesRef = new BytesRef((byte[]) vectorValue); + VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo); + return KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType); + } + return vectorValues.conditionalCloneVector(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java index 7e28156703..dc378ec378 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/NestedPerFieldDerivedVectorInjector.java @@ -29,7 +29,7 @@ @Log4j2 @AllArgsConstructor -public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { +public class NestedPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector { private final FieldInfo childFieldInfo; private final DerivedSourceReaders derivedSourceReaders; @@ -116,7 +116,7 @@ public void inject(int parentDocId, Map sourceAsMap) throws IOEx reconstructedSource.add(position, new HashMap<>()); positions.add(position, docId); } - reconstructedSource.get(position).put(childFieldName, vectorValues.conditionalCloneVector()); + reconstructedSource.get(position).put(childFieldName, formatVector(childFieldInfo, vectorValues)); offsetPositionsIndex = position + 1; } sourceAsMap.put(parentFieldName, reconstructedSource); @@ -137,7 +137,7 @@ private void injectObject(int docId, Map sourceAsMap) throws IOE String field = fields[i]; currentMap = (Map) currentMap.computeIfAbsent(field, k -> new HashMap<>()); } - currentMap.put(fields[fields.length - 1], vectorValues.getVector()); + currentMap.put(fields[fields.length - 1], formatVector(childFieldInfo, vectorValues)); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java index 430fd24ae1..e9c4d21a68 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/RootPerFieldDerivedVectorInjector.java @@ -16,7 +16,7 @@ /** * {@link PerFieldDerivedVectorInjector} for root fields (i.e. non nested fields). */ -class RootPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector { +class RootPerFieldDerivedVectorInjector extends AbstractPerFieldDerivedVectorInjector { private final FieldInfo fieldInfo; private final CheckedSupplier, IOException> vectorValuesSupplier; @@ -40,7 +40,7 @@ public RootPerFieldDerivedVectorInjector(FieldInfo fieldInfo, DerivedSourceReade public void inject(int docId, Map sourceAsMap) throws IOException { KNNVectorValues vectorValues = vectorValuesSupplier.get(); if (vectorValues.docId() == docId || vectorValues.advance(docId) == docId) { - sourceAsMap.put(fieldInfo.name, vectorValues.conditionalCloneVector()); + sourceAsMap.put(fieldInfo.name, formatVector(fieldInfo, vectorValues)); } } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 1240098191..d2f9c21b7e 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -86,10 +86,10 @@ public static StoredField createStoredFieldForFloatVector(String name, float[] v * @return either int[] or float[] of corresponding vector */ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) { - if (VectorDataType.BYTE == vectorDataType) { + if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) { byte[] bytes = storedVector.bytes; - int[] byteAsIntArray = new int[bytes.length]; - Arrays.setAll(byteAsIntArray, i -> bytes[i]); + int[] byteAsIntArray = new int[storedVector.length]; + Arrays.setAll(byteAsIntArray, i -> bytes[i + storedVector.offset]); return byteAsIntArray; } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 8bcf9fdbeb..b41b8349cb 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -41,6 +41,17 @@ public void testStoredFields_whenVectorIsByteType_thenSucceed() { assertArrayEquals(byteAsIntArray, (int[]) vector); } + public void testStoredFields_whenVectorIsBinaryType_thenSucceed() { + StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR); + assertEquals(TEST_FIELD_NAME, storedField.name()); + assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes); + Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BINARY); + assertTrue(vector instanceof int[]); + int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length]; + Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]); + assertArrayEquals(byteAsIntArray, (int[]) vector); + } + public void testStoredFields_whenVectorIsFloatType_thenSucceed() { StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR); assertEquals(TEST_FIELD_NAME, storedField.name()); diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index ad5ef811b5..9255e9863e 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -5,7 +5,6 @@ package org.opensearch.knn.integ; -import com.google.common.primitives.Floats; import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; @@ -26,10 +25,12 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Function; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.TYPE; import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Integration tests for derived source feature for vector fields. Currently, with derived source, there are @@ -92,59 +93,64 @@ public class DerivedSourceIT extends KNNRestTestCase { */ @SneakyThrows public void testFlatBaseCase() { + String mapping = createVectorNonNestedMappings(TEST_DIMENSION, null); List indexConfigContexts = List.of( IndexConfigContext.builder() .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_ENABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> { - bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 32, 0.1f); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_DISABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> { - bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 0.1f); + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 32, 0.1f); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_ENABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_DISABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_ENABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex @@ -154,10 +160,247 @@ public void testFlatBaseCase() { .vectorFieldNames(List.of(FIELD_NAME)) .dimension(TEST_DIMENSION) .settings(DERIVED_DISABLED_SETTINGS) - .mapping(createVectorNonNestedMappings(TEST_DIMENSION)) + .mapping(mapping) .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for + * random documents to ensure it works robustly. To ensure correctness, we repeat same operations against an + * index without derived source enabled (baseline). + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128, + * "data_type": "byte" + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128, + * "data_type": "byte" + * } + * } + * } + * } + */ + @SneakyThrows + public void testFlatByteBaseCase() { + String mapping = createVectorNonNestedMappings(TEST_DIMENSION, "byte"); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 8, 0.1f); + refreshAllIndices(); + }) + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 8, 0.1f); + refreshAllIndices(); + }) + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 1)) + .build() + + ); + testDerivedSourceE2E(indexConfigContexts); + } + + /** + * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for + * random documents to ensure it works robustly. To ensure correctness, we repeat same operations against an + * index without derived source enabled (baseline). + * Test mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": true + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128, + * "data_type": "binary" + * } + * } + * } + * } + * Baseline mapping: + * { + * "settings": { + * "index.knn" true, + * "index.knn.derived_source.enabled": false + * }, + * "mappings":{ + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 128, + * "data_type": "binary" + * } + * } + * } + * } + */ + @SneakyThrows + public void testFlatBinaryBaseCase() { + String mapping = createVectorNonNestedMappings(TEST_DIMENSION, "binary"); + List indexConfigContexts = List.of( + IndexConfigContext.builder() + .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 1, 0.1f); + refreshAllIndices(); + }) + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) + .build(), + IndexConfigContext.builder() + .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> { + bulkIngestRandomVectorsWithSkips(context.indexName, FIELD_NAME, context.docCount, context.dimension, 1, 0.1f); + refreshAllIndices(); + }) + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) + .build(), + IndexConfigContext.builder() + .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) + .build(), + IndexConfigContext.builder() + .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) + .build(), + IndexConfigContext.builder() + .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_ENABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) + .build(), + IndexConfigContext.builder() + .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) + .vectorFieldNames(List.of(FIELD_NAME)) + .dimension(TEST_DIMENSION) + .settings(DERIVED_DISABLED_SETTINGS) + .mapping(mapping) + .isNested(false) + .docCount(DOCS) + .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomByteVector(c.dimension, 8)) .build() ); @@ -252,6 +495,7 @@ public void testMultiFlatFields() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -273,6 +517,7 @@ public void testMultiFlatFields() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -283,6 +528,7 @@ public void testMultiFlatFields() { .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -293,6 +539,7 @@ public void testMultiFlatFields() { .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -303,6 +550,7 @@ public void testMultiFlatFields() { .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -313,6 +561,7 @@ public void testMultiFlatFields() { .isNested(false) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build() ); @@ -369,7 +618,7 @@ public void testMultiFlatFields() { * } */ public void testNestedSingleDocBasic() { - String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION, null); List indexConfigContexts = List.of( IndexConfigContext.builder() .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -390,6 +639,7 @@ public void testNestedSingleDocBasic() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -410,6 +660,7 @@ public void testNestedSingleDocBasic() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -420,6 +671,7 @@ public void testNestedSingleDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -430,6 +682,7 @@ public void testNestedSingleDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -440,6 +693,7 @@ public void testNestedSingleDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -450,6 +704,7 @@ public void testNestedSingleDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build() ); @@ -507,7 +762,7 @@ public void testNestedSingleDocBasic() { */ @SneakyThrows public void testNestedMultiDocBasic() { - String nestedMapping = createVectorNestedMappings(TEST_DIMENSION); + String nestedMapping = createVectorNestedMappings(TEST_DIMENSION, null); List indexConfigContexts = List.of( IndexConfigContext.builder() .indexName(("original-enable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -529,6 +784,7 @@ public void testNestedMultiDocBasic() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -550,6 +806,7 @@ public void testNestedMultiDocBasic() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -560,6 +817,7 @@ public void testNestedMultiDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -570,6 +828,7 @@ public void testNestedMultiDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -580,6 +839,7 @@ public void testNestedMultiDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -590,6 +850,7 @@ public void testNestedMultiDocBasic() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build() ); @@ -725,6 +986,7 @@ public void testObjectFieldTypes() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("original-disable-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -751,6 +1013,7 @@ public void testObjectFieldTypes() { ); refreshAllIndices(); }) + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -767,6 +1030,7 @@ public void testObjectFieldTypes() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("e2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -783,6 +1047,7 @@ public void testObjectFieldTypes() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2e-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -799,6 +1064,7 @@ public void testObjectFieldTypes() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build(), IndexConfigContext.builder() .indexName(("d2d-" + getTestName() + randomAlphaOfLength(6)).toLowerCase(Locale.ROOT)) @@ -815,6 +1081,7 @@ public void testObjectFieldTypes() { .isNested(true) .docCount(DOCS) .indexIngestor(context -> {}) // noop for reindex + .updateVectorSupplier((c) -> randomFloatVector(c.dimension)) .build() ); @@ -909,25 +1176,15 @@ private void testUpdate(List indexConfigContexts) { IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; - float[] updateVector = randomFloatVector(derivedSourceDisabledContext.dimension); + Object updateVector = derivedSourceDisabledContext.updateVectorSupplier.apply(derivedSourceDisabledContext); // Update via POST //_doc/ for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { - updateKnnDoc( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithVectorUpdate), - fieldName, - Floats.asList(updateVector).toArray() - ); + updateKnnDoc(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorUpdate), fieldName, updateVector); } for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { - updateKnnDoc( - originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithVectorUpdate), - fieldName, - Floats.asList(updateVector).toArray() - ); + updateKnnDoc(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorUpdate), fieldName, updateVector); } refreshAllIndices(); assertDocsMatch( @@ -952,7 +1209,7 @@ private void testUpdate(List indexConfigContexts) { originalIndexNameDerivedSourceEnabled, String.valueOf(docWithVectorUpdateFromAPI), fieldName, - Floats.asList(updateVector).toArray() + updateVector ); } for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { @@ -960,7 +1217,7 @@ private void testUpdate(List indexConfigContexts) { originalIndexNameDerivedSourceDisabled, String.valueOf(docWithVectorUpdateFromAPI), fieldName, - Floats.asList(updateVector).toArray() + updateVector ); } refreshAllIndices(); @@ -972,20 +1229,10 @@ private void testUpdate(List indexConfigContexts) { // Update by query for (String fieldName : derivedSourceEnabledContext.vectorFieldNames) { - updateKnnDocByQuery( - originalIndexNameDerivedSourceEnabled, - String.valueOf(docWithUpdateByQuery), - fieldName, - Floats.asList(updateVector).toArray() - ); + updateKnnDocByQuery(originalIndexNameDerivedSourceEnabled, String.valueOf(docWithUpdateByQuery), fieldName, updateVector); } for (String fieldName : derivedSourceDisabledContext.vectorFieldNames) { - updateKnnDocByQuery( - originalIndexNameDerivedSourceDisabled, - String.valueOf(docWithUpdateByQuery), - fieldName, - Floats.asList(updateVector).toArray() - ); + updateKnnDocByQuery(originalIndexNameDerivedSourceDisabled, String.valueOf(docWithUpdateByQuery), fieldName, updateVector); } refreshAllIndices(); assertDocsMatch( @@ -1185,6 +1432,7 @@ private static class IndexConfigContext { boolean isNested; int docCount; CheckedConsumer indexIngestor; + Function updateVectorSupplier; } @SneakyThrows @@ -1216,22 +1464,23 @@ private void assertDocMatches(int docId, String index1, String index2) { } @SneakyThrows - private String createVectorNonNestedMappings(final int dimension) { + private String createVectorNonNestedMappings(final int dimension, String dataType) { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject(PROPERTIES_FIELD) .startObject(FIELD_NAME) .field(TYPE, TYPE_KNN_VECTOR) - .field(DIMENSION, dimension) - .endObject() - .endObject() - .endObject(); + .field(DIMENSION, dimension); + if (dataType != null) { + builder.field(VECTOR_DATA_TYPE_FIELD, dataType); + } + builder.endObject().endObject().endObject(); return builder.toString(); } @SneakyThrows - private String createVectorNestedMappings(final int dimension) { + private String createVectorNestedMappings(final int dimension, String dataType) { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject(PROPERTIES_FIELD) @@ -1240,13 +1489,11 @@ private String createVectorNestedMappings(final int dimension) { .startObject(PROPERTIES_FIELD) .startObject(FIELD_NAME) .field(TYPE, TYPE_KNN_VECTOR) - .field(DIMENSION, dimension) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - + .field(DIMENSION, dimension); + if (dataType != null) { + builder.field(VECTOR_DATA_TYPE_FIELD, dataType); + } + builder.endObject().endObject().endObject().endObject().endObject(); return builder.toString(); } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index ed2ffc54a6..44b72599bb 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -791,7 +791,7 @@ protected void addDocWithBinaryField(String index, String docId, String fieldNam /** * Update a KNN Doc with a new vector for the given fieldName */ - protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void updateKnnDoc(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); String parent = ParentChildHelper.getParentField(fieldName); @@ -811,7 +811,7 @@ protected void updateKnnDoc(String index, String docId, String fieldName, Object /** * Update a KNN Doc using the POST /\/_update/\. Only the vector field will be updated. */ - protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void updateKnnDocWithUpdateAPI(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_update/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("doc"); String parent = ParentChildHelper.getParentField(fieldName); @@ -826,7 +826,7 @@ protected void updateKnnDocWithUpdateAPI(String index, String docId, String fiel assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } - protected void updateKnnDocByQuery(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void updateKnnDocByQuery(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_update_by_query?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -1382,23 +1382,33 @@ public Map xContentBuilderToMap(XContentBuilder xContentBuilder) } public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { - // TODO: Do better on this one - float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); - for (int i = 0; i < numVectors; i++) { - float[] vector = vectors[i]; - addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); - } + bulkIngestRandomVectorsWithSkips(indexName, fieldName, numVectors, dimension, 32, 0.0f); } - public void bulkIngestRandomVectorsWithSkips(String indexName, String fieldName, int numVectors, int dimension, float skipProb) - throws IOException { - float[][] vectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + public void bulkIngestRandomVectorsWithSkips( + String indexName, + String fieldName, + int numVectors, + int dimension, + int bitsPerDimension, + float skipProb + ) throws IOException { + float[][] floatVectors = null; + int[][] intVectors = null; + if (bitsPerDimension == 32) { + floatVectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1); + } else if (bitsPerDimension == 8) { + intVectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 1, 1); + } else { + intVectors = TestUtils.randomlyGenerateStandardVectors(numVectors, dimension, 8, 1); + } + Random random = new Random(); random.setSeed(2); for (int i = 0; i < numVectors; i++) { - float[] vector = vectors[i]; + Object vector = floatVectors == null ? intVectors[i] : floatVectors[i]; if (random.nextFloat() > skipProb) { - addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, vector); } else { addDocWithNumericField(indexName, String.valueOf(i + 1), "numeric-field", 1); } @@ -1635,6 +1645,17 @@ public void bulkIngestRandomVectorsWithNestedField(String indexName, String nest } } + public int[] randomByteVector(int dimension, int dimPerByte) { + int numDims = dimension / dimPerByte; + byte[] byteVector = new byte[numDims]; + random().nextBytes(byteVector); + int[] vector = new int[numDims]; + for (int j = 0; j < numDims; j++) { + vector[j] = byteVector[j]; + } + return vector; + } + // Method that adds multiple documents into the index using Bulk API public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException { Request request = new Request("POST", "/_bulk"); diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 6fd584aeee..840335fc8f 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -128,6 +128,24 @@ public static float[][] randomlyGenerateStandardVectors(int numVectors, int dime return standardVectors; } + // Generating vectors using random function with a seed which makes these vectors standard and generate same vectors for each run. + public static int[][] randomlyGenerateStandardVectors(int numVectors, int dimensions, int dimPerByte, int seed) { + int numDims = dimensions / dimPerByte; + int[][] standardVectors = new int[numVectors][numDims]; + Random rand = new Random(seed); + + for (int i = 0; i < numVectors; i++) { + byte[] byteVector = new byte[numDims]; + rand.nextBytes(byteVector); + int[] vector = new int[numDims]; + for (int j = 0; j < numDims; j++) { + vector[j] = byteVector[j]; + } + standardVectors[i] = vector; + } + return standardVectors; + } + public static float[][] generateRandomVectors(int numVectors, int dimensions) { float[][] randomVectors = new float[numVectors][dimensions];