diff --git a/CHANGELOG.md b/CHANGELOG.md index 555983d01..08d4b697e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) * Disallow a vector field to have an invalid character for a physical file name. [#1936] (https://github.com/opensearch-project/k-NN/pull/1936) * Fixed and abstracted functionality for allocating index memory [#1933](https://github.com/opensearch-project/k-NN/pull/1933) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) ### Infrastructure ### Documentation ### Maintenance diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 69866da76..b3d4d7e7a 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -83,7 +83,7 @@ jlong IndexService::initIndex( std::unordered_map parameters ) { // Create index using Faiss factory method - std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread if(threadCount != 0) { @@ -91,17 +91,22 @@ jlong IndexService::initIndex( } // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + SetExtraParameters(jniUtil, env, parameters, index.get()); // Check that the index does not need to be trained - if(!indexWriter->is_trained) { + if(!index->is_trained) { throw std::runtime_error("Index is not trained"); } - std::unique_ptr idMap (faissMethods->indexIdMap(indexWriter.get())); + std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); + //Make the id map own the wrapped index so that the index is deleted when the id map is destroyed + idMap->own_fields = true; allocIndex(dynamic_cast(idMap->index), dim, numVectors); - indexWriter.release(); + + //Release ownership from the unique_ptr so that the underlying index is not deleted when it goes out of scope. The index is needed later + //in insert and write operations + index.release(); return reinterpret_cast(idMap.release()); } @@ -147,11 +152,8 @@ void IndexService::writeIndex( // Write the index to disk faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } catch(std::exception &e) { - delete idMap->index; throw std::runtime_error("Failed to write index to disk"); } - // Free the memory used by the index - delete idMap->index; } BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} @@ -175,25 +177,29 @@ jlong BinaryIndexService::initIndex( std::unordered_map parameters ) { // Create index using Faiss factory method - std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - + std::unique_ptr index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + SetExtraParameters(jniUtil, env, parameters, index.get()); // Check that the index does not need to be trained - if(!indexWriter->is_trained) { + if(!index->is_trained) { throw std::runtime_error("Index is not trained"); } - std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(index.get())); + //Make the id map own the wrapped index so that the index is deleted when the id map is destroyed + idMap->own_fields = true; allocIndex(dynamic_cast(idMap->index), dim, numVectors); - indexWriter.release(); + + //Release ownership from the unique_ptr so that the underlying index is not deleted when it goes out of scope. The index is needed later + //in insert and write operations + index.release(); return reinterpret_cast(idMap.release()); } @@ -240,12 +246,8 @@ void BinaryIndexService::writeIndex( // Write the index to disk faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); } catch(std::exception &e) { - delete idMap->index; throw std::runtime_error("Failed to write index to disk"); } - - // Free the memory used by the index - delete idMap->index; } } // namespace faiss_wrapper diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java index 9381f73e8..93703f194 100644 --- a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -8,7 +8,7 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import java.util.ArrayList; +import java.util.List; import java.util.Objects; @NoArgsConstructor(access = AccessLevel.PRIVATE) @@ -45,17 +45,19 @@ public static boolean isZeroVector(float[] vector) { return true; } - /** - * Creates an int 
overflow safe arraylist. If there is an overflow it will create a list with default initial size - * @param batchSize size to allocate - * @return an arrayList + /* + * Converts an integer List to an array + * @param integerList + * @return null if list is null or empty, int[] otherwise */ - public static ArrayList createArrayList(long batchSize) { - try { - return new ArrayList<>(Math.toIntExact(batchSize)); - } catch (Exception exception) { - // No-op + public static int[] intListToArray(final List integerList) { + if (integerList == null || integerList.isEmpty()) { + return null; + } + int[] intArray = new int[integerList.size()]; + for (int i = 0; i < integerList.size(); i++) { + intArray[i] = integerList.get(i); } - return new ArrayList<>(); + return intArray; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java similarity index 53% rename from src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java rename to src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 9ea61cffe..116de4460 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/VectorTransferIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -6,50 +6,62 @@ package org.opensearch.knn.index.codec.nativeindex; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.codec.transfer.OffHeapByteQuantizedVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapByteVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; -import 
org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; /** - * Transfers all vectors to offheap and then builds an index + * Transfers all vectors to off heap and then builds an index */ -final class VectorTransferIndexBuildStrategy implements NativeIndexBuildStrategy { +final class DefaultIndexBuildStrategy implements NativeIndexBuildStrategy { - private static VectorTransferIndexBuildStrategy INSTANCE = new VectorTransferIndexBuildStrategy(); + private static DefaultIndexBuildStrategy INSTANCE = new DefaultIndexBuildStrategy(); - public static VectorTransferIndexBuildStrategy getInstance() { + public static DefaultIndexBuildStrategy getInstance() { return INSTANCE; } - private VectorTransferIndexBuildStrategy() {} + private DefaultIndexBuildStrategy() {} public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { - // iterating it once to be safe - knnVectorValues.init(); - try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { - vectorTransfer.transferBatch(); - assert !vectorTransfer.hasNext(); + knnVectorValues.init(); // to get bytesPerVector + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + try (final OffHeapVectorTransfer vectorTransfer 
= getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + + final List tranferredDocIds = new ArrayList<>(); + while (knnVectorValues.docId() != NO_MORE_DOCS) { + // append is true here so off heap memory buffer isn't overwritten + vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true); + tranferredDocIds.add(knnVectorValues.docId()); + knnVectorValues.nextDoc(); + } + vectorTransfer.flush(true); final Map params = indexInfo.getParameters(); + long vectorAddress = vectorTransfer.getVectorAddress(); // Currently this is if else as there are only two cases, with more cases this will have to be made // more maintainable if (params.containsKey(MODEL_ID)) { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( - vectorTransfer.getTransferredDocsIds(), - vectorTransfer.getVectorAddress(), + intListToArray(tranferredDocIds), + vectorAddress, knnVectorValues.dimension(), indexInfo.getIndexPath(), (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), @@ -61,8 +73,8 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector } else { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndex( - vectorTransfer.getTransferredDocsIds(), - vectorTransfer.getVectorAddress(), + intListToArray(tranferredDocIds), + vectorAddress, knnVectorValues.dimension(), indexInfo.getIndexPath(), indexInfo.getParameters(), @@ -77,13 +89,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector } } - private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues knnVectorValues) throws IOException { + private OffHeapVectorTransfer getVectorTransfer(VectorDataType vectorDataType, int transferLimit) throws IOException { switch (vectorDataType) { case FLOAT: - return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues, knnVectorValues.totalLiveDocs()); + return (OffHeapVectorTransfer) new 
OffHeapFloatVectorTransfer(transferLimit); case BINARY: case BYTE: - return new OffHeapByteQuantizedVectorTransfer<>((KNNVectorValues) knnVectorValues, knnVectorValues.totalLiveDocs()); + return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); default: throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 4efb4c231..c8e119936 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -5,21 +5,26 @@ package org.opensearch.knn.index.codec.nativeindex; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.codec.transfer.OffHeapByteQuantizedVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapByteVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapFloatVectorTransfer; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; + /** * Iteratively builds 
the index. */ @@ -33,7 +38,8 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { private MemOptimizedNativeIndexBuildStrategy() {} - public void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + // Needed to make sure we dont get 0 dimensions while initializing index knnVectorValues.init(); KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); @@ -48,18 +54,49 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues ) ); - try (final VectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), knnVectorValues)) { + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { - while (vectorTransfer.hasNext()) { - vectorTransfer.transferBatch(); - long vectorAddress = vectorTransfer.getVectorAddress(); - int[] docs = vectorTransfer.getTransferredDocsIds(); + final List tranferredDocIds = new ArrayList<>(); + while (knnVectorValues.docId() != NO_MORE_DOCS) { + // append is false to be able to reuse the memory location + boolean transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false); + tranferredDocIds.add(knnVectorValues.docId()); + if (transferred) { + // Insert vectors + long vectorAddress = vectorTransfer.getVectorAddress(); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexParameters, + indexMemoryAddress, + engine + ); + return null; + }); + tranferredDocIds.clear(); + } + knnVectorValues.nextDoc(); + } - // Insert vectors + boolean 
flush = vectorTransfer.flush(false); + // Need to make sure that the flushed vectors are indexed + if (flush) { + long vectorAddress = vectorTransfer.getVectorAddress(); AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.insertToIndex(docs, vectorAddress, knnVectorValues.dimension(), indexParameters, indexMemoryAddress, engine); + JNIService.insertToIndex( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexParameters, + indexMemoryAddress, + engine + ); return null; }); + tranferredDocIds.clear(); } // Write vector @@ -73,14 +110,13 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues } } - // TODO: Will probably need a factory once quantization is added - private VectorTransfer getVectorTransfer(VectorDataType vectorDataType, KNNVectorValues knnVectorValues) throws IOException { + private OffHeapVectorTransfer getVectorTransfer(final VectorDataType vectorDataType, final int transferLimit) { switch (vectorDataType) { case FLOAT: - return new OffHeapFloatVectorTransfer((KNNFloatVectorValues) knnVectorValues); + return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit); case BINARY: case BYTE: - return new OffHeapByteQuantizedVectorTransfer<>((KNNVectorValues) knnVectorValues); + return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); default: throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 92d29b9ba..8a565785e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -75,7 +75,7 @@ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWrit if (iterative) { return new 
NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); } - return new NativeIndexWriter(state, fieldInfo, VectorTransferIndexBuildStrategy.getInstance()); + return new NativeIndexWriter(state, fieldInfo, DefaultIndexBuildStrategy.getInstance()); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java deleted file mode 100644 index ff346a810..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryQuantizedVectorTransfer.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.commons.lang.StringUtils; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; - -import java.io.IOException; -import java.util.List; - -/** - * Transfer quantized binary vectors to off heap memory - * The reason this is different from {@link OffHeapByteQuantizedVectorTransfer} is because of allocation and deallocation - * of memory on JNI layer. 
Use this if unsigned int is needed on JNI layer - */ -public final class OffHeapBinaryQuantizedVectorTransfer extends OffHeapQuantizedVectorTransfer { - - public OffHeapBinaryQuantizedVectorTransfer(KNNVectorValues vectorValues, Long batchSize) { - super(vectorValues, batchSize, (vector, state) -> (byte[]) vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); - } - - public OffHeapBinaryQuantizedVectorTransfer(KNNVectorValues vectorValues) { - this(vectorValues, null); - } - - @Override - public void close() { - super.close(); - // TODO: deallocate the memory location - } - - @Override - protected long transfer(List vectorsToTransfer, boolean append) throws IOException { - // TODO: call to JNIService to transfer vector - return 0L; - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java new file mode 100644 index 000000000..9271e8f9d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized binary vectors to off heap memory + * The reason this is different from {@link OffHeapByteVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. 
Use this if unsigned int is needed on JNI layer + */ +public final class OffHeapBinaryVectorTransfer extends OffHeapVectorTransfer { + + public OffHeapBinaryVectorTransfer(int transferLimit) { + super(transferLimit); + } + + @Override + public void close() { + super.close(); + // TODO: deallocate the memory location + } + + @Override + protected long transfer(List vectorsToTransfer, boolean append) throws IOException { + // TODO: call to JNIService to transfer vector + return 0L; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java deleted file mode 100644 index e5c8d3e12..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteQuantizedVectorTransfer.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.commons.lang.StringUtils; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.jni.JNICommons; - -import java.io.IOException; -import java.util.List; - -/** - * Transfer quantized byte vectors to off heap memory. - * The reason this is different from {@link OffHeapBinaryQuantizedVectorTransfer} is because of allocation and deallocation - * of memory on JNI layer. 
Use this if signed int is needed on JNI layer - */ -public final class OffHeapByteQuantizedVectorTransfer extends OffHeapQuantizedVectorTransfer { - - public OffHeapByteQuantizedVectorTransfer(KNNVectorValues vectorValues, final Long batchSize) throws IOException { - super(vectorValues, batchSize, (vector, state) -> (byte[]) vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); - } - - public OffHeapByteQuantizedVectorTransfer(KNNVectorValues vectorValues) throws IOException { - this(vectorValues, null); - } - - @Override - protected long transfer(List batch, boolean append) throws IOException { - return JNICommons.storeByteVectorData(getVectorAddress(), batch.toArray(new byte[][] {}), batchSize * batch.get(0).length, append); - } - - @Override - public void close() { - super.close(); - JNICommons.freeByteVectorData(getVectorAddress()); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java new file mode 100644 index 000000000..c587dfd30 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized byte vectors to off heap memory. + * The reason this is different from {@link OffHeapBinaryVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. 
Use this if signed int is needed on JNI layer + */ +public final class OffHeapByteVectorTransfer extends OffHeapVectorTransfer { + + public OffHeapByteVectorTransfer(int transferLimit) { + super(transferLimit); + } + + @Override + protected long transfer(List batch, boolean append) throws IOException { + return JNICommons.storeByteVectorData( + getVectorAddress(), + batch.toArray(new byte[][] {}), + (long) batch.get(0).length * transferLimit, + append + ); + } + + @Override + public void close() { + super.close(); + JNICommons.freeByteVectorData(getVectorAddress()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java index 66246494e..af8fb4165 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -5,8 +5,6 @@ package org.opensearch.knn.index.codec.transfer; -import org.apache.commons.lang.StringUtils; -import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.jni.JNICommons; import java.io.IOException; @@ -15,14 +13,10 @@ /** * Transfer float vectors to off heap memory. 
*/ -public final class OffHeapFloatVectorTransfer extends OffHeapQuantizedVectorTransfer { +public final class OffHeapFloatVectorTransfer extends OffHeapVectorTransfer { - public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues, Long batchSize) throws IOException { - super(vectorValues, batchSize, (vector, state) -> vector, StringUtils.EMPTY, DEFAULT_COMPRESSION_FACTOR); - } - - public OffHeapFloatVectorTransfer(KNNFloatVectorValues vectorValues) throws IOException { - this(vectorValues, null); + public OffHeapFloatVectorTransfer(int transferLimit) { + super(transferLimit); } @Override @@ -30,7 +24,7 @@ protected long transfer(final List vectorsToTransfer, boolean append) t return JNICommons.storeVectorData( getVectorAddress(), vectorsToTransfer.toArray(new float[][] {}), - this.batchSize * vectorsToTransfer.get(0).length, + (long) vectorsToTransfer.get(0).length * this.transferLimit, append ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java deleted file mode 100644 index 7ddd93620..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapQuantizedVectorTransfer.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import lombok.Getter; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; - -import java.io.IOException; -import java.util.List; -import java.util.function.BiFunction; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.opensearch.knn.common.KNNVectorUtil.createArrayList; - -/** - * The class is intended to transfer {@link KNNVectorValues} to off heap memory. If also provides and ability to quantize the vector - * before it is transferred to offHeap memory. 
The class is not thread safe - * - * @param an array of primitive type - * @param an array of primitive type after being quantized - */ -abstract class OffHeapQuantizedVectorTransfer implements VectorTransfer { - - protected static final int DEFAULT_COMPRESSION_FACTOR = 1; - - @Getter - private long vectorAddress; - @Getter - private int[] transferredDocsIds; - private final int transferLimit; - // Keeping this as a member variable as this should not be changed considering the vector address is reused between batches - protected long batchSize; - - private final List vectorsToTransfer; - private final List transferredDocIdsList; - - private final KNNVectorValues vectorValues; - - // TODO: Replace with actual quantization parameters - private final BiFunction quantizer; - private final String quantizationState; - - public OffHeapQuantizedVectorTransfer( - final KNNVectorValues vectorValues, - final Long batchSize, - final BiFunction quantizer, - final String quantizationState, - final int compressionFactor - ) { - assert vectorValues.docId() != -1 : "vectorValues docId must be set, iterate it once for vector transfer to succeed"; - assert vectorValues.docId() != NO_MORE_DOCS : "vectorValues already iterated, Nothing to transfer"; - - this.quantizer = quantizer; - this.quantizationState = quantizationState; - this.transferLimit = (int) Math.max( - 1, - (int) KNNSettings.getVectorStreamingMemoryLimit().getBytes() / (vectorValues.bytesPerVector() / compressionFactor) - ); - this.batchSize = batchSize == null ? 
transferLimit : batchSize; - this.vectorsToTransfer = createArrayList(this.batchSize); - this.transferredDocIdsList = createArrayList(this.batchSize); - this.vectorValues = vectorValues; - this.vectorAddress = 0; // we can allocate initial memory here, currently storeVectorData takes care of it - } - - @Override - public void transferBatch() throws IOException { - if (vectorValues.docId() == NO_MORE_DOCS) { - // Throwing instead of returning so there is no way client can go into an infinite loop - throw new IllegalStateException("No more vectors available to transfer"); - } - - assert vectorsToTransfer.isEmpty() : "Last batch wasn't transferred"; - assert transferredDocIdsList.isEmpty() : "Last batch wasn't transferred"; - - int totalDocsTransferred = 0; - boolean freshBatch = true; - - // TODO: Create non-final QuantizationOutput once here and then reuse the output - while (vectorValues.docId() != NO_MORE_DOCS && totalDocsTransferred < batchSize) { - V quantizedVector = quantizer.apply(vectorValues.conditionalCloneVector(), quantizationState); - - transferredDocIdsList.add(vectorValues.docId()); - vectorsToTransfer.add(quantizedVector); - if (vectorsToTransfer.size() == transferLimit) { - vectorAddress = transfer(vectorsToTransfer, !freshBatch); - vectorsToTransfer.clear(); - freshBatch = false; - } - vectorValues.nextDoc(); - totalDocsTransferred++; - } - - // Handle batchSize < transferLimit - if (!vectorsToTransfer.isEmpty()) { - vectorAddress = transfer(vectorsToTransfer, !freshBatch); - vectorsToTransfer.clear(); - } - - this.transferredDocsIds = new int[transferredDocIdsList.size()]; - for (int i = 0; i < transferredDocIdsList.size(); i++) { - transferredDocsIds[i] = transferredDocIdsList.get(i); - } - transferredDocIdsList.clear(); - } - - @Override - public boolean hasNext() { - return vectorValues.docId() != NO_MORE_DOCS; - } - - @Override - public void close() { - transferredDocIdsList.clear(); - transferredDocsIds = null; - vectorAddress = 0; - } - - 
protected abstract long transfer(final List vectorsToTransfer, final boolean append) throws IOException; -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java new file mode 100644 index 000000000..0d357d909 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.Getter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + *

+ * The class is intended to transfer {@link KNNVectorValues} to off heap memory. + * It also provides an ability to quantize the vector before it is transferred to off heap memory. + * The ability to quantize is added so as not to iterate {@link KNNVectorValues} multiple times. + *

+ * + *

+ * The class is not thread safe. + *

+ * + * @param byte[] or float[] + */ +public abstract class OffHeapVectorTransfer implements Closeable { + + @Getter + private long vectorAddress; + protected final int transferLimit; + + private final List vectorsToTransfer; + + public OffHeapVectorTransfer(final int transferLimit) { + this.transferLimit = transferLimit; + this.vectorsToTransfer = new ArrayList<>(transferLimit); + this.vectorAddress = 0; + } + + public boolean transfer(T vector, boolean append) throws IOException { + vectorsToTransfer.add(vector); + if (vectorsToTransfer.size() == this.transferLimit) { + vectorAddress = transfer(vectorsToTransfer, append); + vectorsToTransfer.clear(); + return true; + } + return false; + } + + public boolean flush(boolean append) throws IOException { + // flush before closing + if (!vectorsToTransfer.isEmpty()) { + vectorAddress = transfer(vectorsToTransfer, append); + return true; + } + return false; + } + + public void close() { + vectorAddress = 0; + } + + protected abstract long transfer(final List vectorsToTransfer, boolean append) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java deleted file mode 100644 index fd76a7861..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import java.io.Closeable; -import java.io.IOException; - -/** - * An interface to transfer vectors from one memory location to another - * Class is Closeable to be able to release memory once done - */ -public interface VectorTransfer extends Closeable { - - /** - * Transfer a batch of vectors from one location to another - * The batch size here is intended to be constant for multiple transfers so should be encapsulated in the - * implementation. 
A new batch size should require another instance - * @throws IOException - */ - void transferBatch() throws IOException; - - /** - * Indicates if there are more vectors to transfer - * @return - */ - boolean hasNext(); - - /** - * Gives the docIds for transfered vectors - * @return - */ - int[] getTransferredDocsIds(); - - /** - * @return the memory address of the vectors transferred - */ - long getVectorAddress(); -} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index d11739ee6..9268aa4e5 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -35,7 +35,6 @@ public float[] conditionalCloneVector() throws IOException { if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { return Arrays.copyOf(vector, vector.length); } - return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index c2da8cde1..765c41c3f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -59,9 +59,7 @@ public void init() throws IOException { * @return T an array of byte[], float[] Or a deep copy of it * @throws IOException */ - public T conditionalCloneVector() throws IOException { - return getVector(); - } + public abstract T conditionalCloneVector() throws IOException; /** * Dimension of vector is returned. Do call getVector function first before calling this function otherwise you will get 0 value. 
diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 80fa7c835..9a7228add 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -16,6 +16,8 @@ import java.util.Locale; +import lombok.experimental.UtilityClass; + /** * A utility class for models. */ diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java new file mode 100644 index 000000000..ae4b90022 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; + +public class DefaultIndexBuildStrategyTests extends OpenSearchTestCase { + + private ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + + @Before + public void init() { + vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + } + + @SneakyThrows + 
public void testBuildAndWrite() { + // Given + final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + docsWithFieldSet, + docs + ); + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class) + ) { + + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndexFromScratch(3, 2, Map.of("index", "param"), KNNEngine.NMSLIB)).thenReturn(100L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.NMSLIB) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .build(); + + // When + DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.createIndex( + eq(new int[] { 0, 1, 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq("indexPath"), + eq(Map.of("index", "param")), + eq(KNNEngine.NMSLIB) + ) + ); + mockedJNIService.verifyNoMoreInteractions(); + assertNotEquals(0L, vectorAddressCaptor.getValue().longValue()); + } + } + + @SneakyThrows + public void testBuildAndWriteWithModel() { + // Given + final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + byte[] modelBlob = new byte[] { 1 }; + + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues( + 
VectorDataType.FLOAT, + docsWithFieldSet, + docs + ); + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class) + ) { + + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when( + () -> JNIService.initIndexFromScratch(3, 2, Map.of("model_id", "id", "model_blob", modelBlob), KNNEngine.NMSLIB) + ).thenReturn(100L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.NMSLIB) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("model_id", "id", "model_blob", modelBlob)) + .build(); + + // When + DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.createIndexFromTemplate( + eq(new int[] { 0, 1, 2 }), + vectorAddressCaptor.capture(), + eq(2), + eq("indexPath"), + eq(modelBlob), + eq(Map.of("model_id", "id", "model_blob", modelBlob)), + eq(KNNEngine.NMSLIB) + ) + ); + mockedJNIService.verifyNoMoreInteractions(); + assertNotEquals(0L, vectorAddressCaptor.getValue().longValue()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java new file mode 100644 index 000000000..001676f5b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import 
org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; + +public class MemOptimizedNativeIndexBuildStrategyTests extends OpenSearchTestCase { + + @SneakyThrows + public void testBuildAndWrite() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + + final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + docsWithFieldSet, + docs + ); + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class) + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndexFromScratch(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + mockedJNIService.verify( + () -> 
JNIService.initIndexFromScratch( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + // Then + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + + assertNotEquals(0L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java index f4aed1049..e811ca363 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -6,129 +6,84 @@ package org.opensearch.knn.index.codec.transfer; import lombok.SneakyThrows; -import org.apache.lucene.index.DocsWithFieldSet; -import org.junit.Before; -import org.mockito.Mock; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import 
org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; -import java.util.Map; - -import static org.mockito.Mockito.when; -import static org.opensearch.knn.index.KNNSettings.KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; +import java.util.List; public class OffHeapVectorTransferTests extends KNNTestCase { - @Mock - ClusterSettings clusterSettings; - - @Before - @Override - public void setUp() throws Exception { - super.setUp(); - - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - - KNNSettings.state().setClusterService(clusterService); - } - @SneakyThrows public void testFloatTransfer() { - // Given - when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(16)); - final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); - DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); - docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); - - //Transfer 1 vector - KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - VectorTransfer vectorTransfer; - - //Transfer batch, limit == batch size - knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues); - testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); - - //Transfer batch, limit < batch size - knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues, 5L); - vectorTransfer.transferBatch(); - assertNotEquals(0, 
vectorTransfer.getVectorAddress()); - assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); - - //Transfer batch, limit > batch size - knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapFloatVectorTransfer(knnVectorValues, 1L); - testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + List vectors = List.of( + new float[] { 0.1f, 0.2f }, + new float[] { 0.2f, 0.3f }, + new float[] { 0.3f, 0.4f }, + new float[] { 0.3f, 0.4f }, + new float[] { 0.3f, 0.4f } + ); + + OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); } @SneakyThrows public void testByteTransfer() { - // Given - when(clusterSettings.get(KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING)).thenReturn(new ByteSizeValue(4)); - final Map docs = Map.of(0, new byte[] { 1, 2 }, 1, new byte[] { 2, 3 }, 2, new byte[] { 3, 4 }); - DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); - docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); - - //Transfer 1 vector - KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); 
knnVectorValues.getVector(); - VectorTransfer vectorTransfer; - - //Transfer batch, limit == batch size - knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues); - testTransferBatchVectors(vectorTransfer, new int[][] { { 0, 1 }, { 2 } }, 2); - - //Transfer batch, limit < batch size - knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues, 5L); - vectorTransfer.transferBatch(); - assertNotEquals(0, vectorTransfer.getVectorAddress()); - assertArrayEquals(new int[] {0, 1, 2}, vectorTransfer.getTransferredDocsIds()); - - //Transfer batch, limit > batch size - knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, docsWithFieldSet, docs); - knnVectorValues.nextDoc(); knnVectorValues.getVector(); - vectorTransfer = new OffHeapByteQuantizedVectorTransfer<>(knnVectorValues, 1L); - testTransferBatchVectors(vectorTransfer, new int[][] { { 0 }, { 1 }, { 2 } }, 3); + List vectors = List.of( + new byte[] { 0, 1 }, + new byte[] { 2, 3 }, + new byte[] { 4, 5 }, + new byte[] { 6, 7 }, + new byte[] { 8, 9 } + ); + + OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + 
assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); } - // TODO: Add a unit test for binary - @SneakyThrows - private void testTransferBatchVectors(VectorTransfer vectorTransfer, int[][] expectedDocIds, int expectedIterations) { - long vectorAddress = 0L; - try { - int iteration = 0; - while (vectorTransfer.hasNext()) { - vectorTransfer.transferBatch(); - if (iteration != 0) { - assertEquals("Vector address shouldn't be different", vectorAddress, vectorTransfer.getVectorAddress()); - } else { - assertEquals(0, vectorAddress); - vectorAddress = vectorTransfer.getVectorAddress(); - } - assertArrayEquals(expectedDocIds[iteration], vectorTransfer.getTransferredDocsIds()); - iteration++; - } - assertEquals(expectedIterations, iteration); - } finally { - vectorTransfer.close(); - assertEquals(vectorTransfer.getVectorAddress(), 0); - assertNull(vectorTransfer.getTransferredDocsIds()); - } + public void testBinaryTransfer() { + List vectors = List.of( + new byte[] { 0, 1 }, + new byte[] { 2, 3 }, + new byte[] { 4, 5 }, + new byte[] { 6, 7 }, + new byte[] { 8, 9 } + ); + + OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); } }