Skip to content

Commit

Permalink
Add download + indexOutput#write implementation to RemoteIndexBuildStrategy
Browse files Browse the repository at this point in the history

Signed-off-by: Jay Deng <[email protected]>
  • Loading branch information
jed326 authored and Jay Deng committed Feb 26, 2025
1 parent 5873add commit 6187dfb
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* [Remote Vector Index Build] Introduce Remote Native Index Build feature flag, settings, and initial skeleton [#2525](https://github.com/opensearch-project/k-NN/pull/2525)
* [Remote Vector Index Build] Implement vector data upload and vector data size threshold setting [#2550](https://github.com/opensearch-project/k-NN/pull/2550)
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
### Enhancements
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.store.IndexOutputWithBuffer;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.repositories.blobstore.BlobStoreRepository;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
Expand Down Expand Up @@ -210,4 +214,30 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio
return new InputStreamContainer(vectorValuesInputStream, size, position);
});
}

/**
 * Downloads the pre-built vector index blob at {@code path} from the configured repository and
 * streams it into the provided {@link IndexOutputWithBuffer}.
 *
 * @param path repository-relative file path, delimited by "/"; must end with the FAISS extension
 * @param indexOutputWithBuffer destination wrapper around the segment's {@link org.apache.lucene.store.IndexOutput}
 * @throws IllegalArgumentException if the path is null/empty or has the wrong file extension
 * @throws IOException if reading the blob or writing to the IndexOutput fails
 */
@Override
public void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
    if (path == null || path.isEmpty()) {
        throw new IllegalArgumentException("download path is null or empty");
    }
    Path downloadPath = Paths.get(path);
    String fileName = downloadPath.getFileName().toString();
    if (!fileName.endsWith(KNNEngine.FAISS.getExtension())) {
        // Fixed broken log placeholder: was "[{}}" which never renders the extension.
        log.error("download path [{}] does not end with extension [{}]", downloadPath, KNNEngine.FAISS.getExtension());
        throw new IllegalArgumentException("download path has incorrect file extension");
    }

    // Rebuild the blob container path from the parent directories of the download path,
    // e.g. "base/dir/file.faiss" -> container "base/dir", blob name "file.faiss".
    BlobPath blobContainerPath = new BlobPath();
    if (downloadPath.getParent() != null) {
        for (Path p : downloadPath.getParent()) {
            blobContainerPath = blobContainerPath.add(p.getFileName().toString());
        }
    }

    BlobContainer blobContainer = repository.blobStore().blobContainer(blobContainerPath);
    // TODO: We are using the sequential download API as multi-part parallel download is difficult for us to implement today and
    // requires some changes in core. For more details, see: https://github.com/opensearch-project/k-NN/issues/2464
    // try-with-resources ensures the blob stream is always closed, even if the write fails
    // (the original leaked the stream returned by readBlob).
    try (InputStream graphStream = blobContainer.readBlob(fileName)) {
        indexOutputWithBuffer.writeFromStreamWithBuffer(graphStream);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

stopWatch = new StopWatch().start();
awaitVectorBuild();
String downloadPath = awaitVectorBuild();
time_in_millis = stopWatch.stop().totalTime().millis();
log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

stopWatch = new StopWatch().start();
vectorRepositoryAccessor.readFromRepository();
vectorRepositoryAccessor.readFromRepository(downloadPath, indexInfo.getIndexOutputWithBuffer());
time_in_millis = stopWatch.stop().totalTime().millis();
log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
} catch (Exception e) {
Expand Down Expand Up @@ -174,8 +174,9 @@ private void submitVectorBuild() {

/**
* Wait on remote vector build to complete
* @return String The path from which we should perform download, delimited by "/"
*/
private void awaitVectorBuild() {
private String awaitVectorBuild() throws NotImplementedException {
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

package org.opensearch.knn.index.codec.nativeindex.remote;

import org.apache.commons.lang.NotImplementedException;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.store.IndexOutputWithBuffer;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;
Expand All @@ -15,7 +15,7 @@
/**
* Interface which dictates how we interact with a {@link org.opensearch.repositories.blobstore.BlobStoreRepository} from {@link RemoteIndexBuildStrategy}
*/
public interface VectorRepositoryAccessor {
interface VectorRepositoryAccessor {
/**
* This method is responsible for writing both the vector blobs and doc ids provided by {@param knnVectorValuesSupplier} to the configured repository
*
Expand All @@ -35,8 +35,9 @@ void writeToRepository(

/**
* Read constructed vector file from remote repository and write to IndexOutput
* @param path File path as String
* @param indexOutputWithBuffer {@link IndexOutputWithBuffer} which will be used to write to the underlying {@link org.apache.lucene.store.IndexOutput}
* @throws IOException
*/
default void readFromRepository() {
throw new NotImplementedException();
}
void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,25 @@
import org.apache.lucene.store.IndexOutput;

import java.io.IOException;
import java.io.InputStream;

/**
* Wrapper around {@link IndexOutput} to perform writes in a buffered manner. This class is created per flush/merge, and may be used twice if
* {@link org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy} needs to fall back to a different build strategy.
*/
public class IndexOutputWithBuffer {
// Underlying `IndexOutput` obtained from Lucene's Directory.
private IndexOutput indexOutput;
// Write buffer. Native engine will copy bytes into this buffer.
// Allocating 64KB here since it show better performance in NMSLIB with the size. (We had slightly improvement in FAISS than having 4KB)
// NMSLIB writes an adjacent list size first, then followed by serializing the list. Since we usually have more adjacent lists, having
// 64KB to accumulate bytes as possible to reduce the times of calling `writeBytes`.
private byte[] buffer = new byte[64 * 1024];
private static final int CHUNK_SIZE = 64 * 1024;
private final byte[] buffer;

/**
 * Wraps the given {@link IndexOutput}, allocating a fresh {@code CHUNK_SIZE}-byte
 * write buffer dedicated to this instance.
 */
public IndexOutputWithBuffer(IndexOutput indexOutput) {
    this.buffer = new byte[CHUNK_SIZE];
    this.indexOutput = indexOutput;
}

// This method will be called in JNI layer which precisely knows
Expand All @@ -33,6 +40,43 @@ public void writeBytes(int length) {
}
}

/**
 * Buffered copy from {@code inputStream} into the wrapped {@link IndexOutput}, reusing
 * this instance's own internal buffer as the staging area.
 *
 * @param inputStream The stream from which we are reading bytes to write
 * @throws IOException if reading the stream or writing the IndexOutput fails
 * @see IndexOutputWithBuffer#writeFromStreamWithBuffer(InputStream, byte[])
 */
public void writeFromStreamWithBuffer(InputStream inputStream) throws IOException {
    this.writeFromStreamWithBuffer(inputStream, buffer);
}

/**
 * Writes to the {@link IndexOutput} by buffering bytes with {@code outputBuffer}. This method allows
 * {@link org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy} to provide a separate, larger buffer as that buffer is for buffering
 * bytes downloaded from the repository, so it may be more performant to use a larger buffer.
 * We do not change the size of the existing buffer in case a fallback to the existing build strategy is needed.
 * TODO: Tune the size of the buffer used by RemoteIndexBuildStrategy based on benchmarking
 *
 * @param inputStream The stream from which we are reading bytes to write
 * @param outputBuffer The buffer used to buffer bytes
 * @throws IOException if reading the stream or writing the IndexOutput fails
 * @see IndexOutputWithBuffer#writeFromStreamWithBuffer(InputStream)
 */
private void writeFromStreamWithBuffer(InputStream inputStream, byte[] outputBuffer) throws IOException {
    // Read up to the FULL capacity of the supplied buffer on each pass. The original capped
    // every read at CHUNK_SIZE, which wasted the extra capacity of a larger caller-supplied
    // buffer (the documented purpose of this overload) and would throw
    // IndexOutOfBoundsException for a buffer smaller than CHUNK_SIZE.
    int bytesRead;
    // InputStream#read returns -1 once the stream is exhausted; any other value is the
    // actual number of bytes placed in the buffer, which may be less than its capacity.
    while ((bytesRead = inputStream.read(outputBuffer, 0, outputBuffer.length)) != -1) {
        indexOutput.writeBytes(outputBuffer, 0, bytesRead);
    }
}

@Override
public String toString() {
return "{indexOutput=" + indexOutput + ", len(buffer)=" + buffer.length + "}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package org.opensearch.knn.index.codec.nativeindex.remote;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.mockito.Mockito;
import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer;
import org.opensearch.common.blobstore.BlobContainer;
Expand All @@ -13,12 +17,18 @@
import org.opensearch.common.blobstore.fs.FsBlobStore;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.store.IndexOutputWithBuffer;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.blobstore.BlobStoreRepository;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.List;
import java.util.Random;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -129,4 +139,65 @@ public void testAsyncUploadThrowsException() throws InterruptedException, IOExce
verify(mockBlobStore).blobContainer(any());
verify(mockRepository).basePath();
}

/**
 * Verify the buffered read method in {@link DefaultVectorRepositoryAccessor#readFromRepository} produces the correct result
 */
public void testRepositoryRead() throws IOException {
    String blobFileName = randomAlphaOfLength(8) + KNNEngine.FAISS.getExtension();

    // Source payload: random bytes behind an InputStream that the accessor must copy verbatim.
    int payloadSize = 64 * 1024 * 10;
    byte[] expectedBytes = new byte[payloadSize];
    new Random().nextBytes(expectedBytes);
    InputStream payloadStream = new ByteArrayInputStream(expectedBytes);

    // Destination: a real on-disk Lucene segment file we can read back after the copy.
    Directory directory = newFSDirectory(createTempDir());
    String segmentName = "test-segment-name";
    IndexOutput segmentOutput = directory.createOutput(segmentName, IOContext.DEFAULT);
    IndexOutputWithBuffer segmentOutputWithBuffer = new IndexOutputWithBuffer(segmentOutput);

    // Mock the repository plumbing so blobContainer.readBlob(...) serves our payload stream.
    RepositoriesService repositoriesService = mock(RepositoriesService.class);
    BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
    BlobPath basePath = new BlobPath().add("testBasePath");
    BlobStore mockBlobStore = mock(BlobStore.class);
    AsyncMultiStreamBlobContainer mockBlobContainer = mock(AsyncMultiStreamBlobContainer.class);

    when(repositoriesService.repository(any())).thenReturn(mockRepository);
    when(mockRepository.basePath()).thenReturn(basePath);
    when(mockRepository.blobStore()).thenReturn(mockBlobStore);
    when(mockBlobStore.blobContainer(any())).thenReturn(mockBlobContainer);
    when(mockBlobContainer.readBlob(blobFileName)).thenReturn(payloadStream);

    VectorRepositoryAccessor objectUnderTest = new DefaultVectorRepositoryAccessor(mockRepository, mock(IndexSettings.class));

    // A path without the engine's file extension must be rejected before any repository access.
    assertThrows(IllegalArgumentException.class, () -> objectUnderTest.readFromRepository("test_file.txt", segmentOutputWithBuffer));

    // Valid shapes: nested under a subdirectory, directly under the base path, or bare.
    String downloadPath = randomFrom(
        List.of(
            "testBasePath/testDirectory/" + blobFileName,
            "testBasePath/" + blobFileName,
            blobFileName
        )
    );
    // This should stream payloadStream into segmentOutput via the buffered write path
    objectUnderTest.readFromRepository(downloadPath, segmentOutputWithBuffer);
    segmentOutput.close();

    // Read the segment back and confirm it matches the payload byte-for-byte.
    IndexInput segmentInput = directory.openInput(segmentName, IOContext.DEFAULT);
    byte[] actualBytes = new byte[payloadSize];
    segmentInput.readBytes(actualBytes, 0, payloadSize);
    assertArrayEquals(expectedBytes, actualBytes);

    // Test cleanup
    segmentInput.close();
    directory.close();
}
}

0 comments on commit 6187dfb

Please sign in to comment.