diff --git a/CHANGELOG.md b/CHANGELOG.md index 70ec3ae510..9e7384e0d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/build.gradle b/build.gradle index b5f7158471..4ee818c270 100644 --- a/build.gradle +++ b/build.gradle @@ -321,7 +321,11 @@ dependencies { api "net.java.dev.jna:jna-platform:5.13.0" // OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here. implementation 'org.slf4j:slf4j-api:1.7.36' - + api "org.apache.httpcomponents.client5:httpclient5:${versions.httpclient5}" + api "org.apache.httpcomponents.core5:httpcore5:${versions.httpcore5}" + api "org.apache.httpcomponents.core5:httpcore5-h2:${versions.httpcore5}" + api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" + api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c33f3ea63c..dcfca234a2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -15,10 +15,12 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; +import org.opensearch.common.settings.SecureSetting; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.settings.SecureString; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.IndexModule; @@ -96,6 +98,11 @@ public class KNNSettings { public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; + public static final String KNN_REMOTE_BUILD_SERVICE_ENDPOINT = "knn.remote_build_service.endpoint"; + public static final String KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL = "knn.remote_build_service.poll_interval"; + public static final String KNN_REMOTE_BUILD_SERVICE_TIMEOUT = "knn.remote_build_service.timeout"; + public static final String KNN_REMOTE_BUILD_SERVICE_USERNAME = "knn.remote_build_service.username"; + public static final String KNN_REMOTE_BUILD_SERVICE_PASSWORD = "knn.remote_build_service.password"; /** * Default setting values @@ -127,6 +134,10 @@ public class KNNSettings { public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false; + // TODO: Tune these default values based on benchmarking + public static final Integer KNN_DEFAULT_REMOTE_BUILD_SERVICE_TIMEOUT_MINUTES = 60; + public static final Integer KNN_DEFAULT_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SECONDS = 30; + /** * Settings Definition */ @@ -388,6 +399,47 @@ public class KNNSettings { */ public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** + * Remote build service endpoint to be used for remote index build. //TODO we can add String validators on these endpoint settings + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING = Setting.simpleString( + KNN_REMOTE_BUILD_SERVICE_ENDPOINT, + NodeScope, + Dynamic + ); + + /** + * Time the remote build service client will wait before falling back to CPU index build. + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING = Setting.timeSetting( + KNN_REMOTE_BUILD_SERVICE_TIMEOUT, + TimeValue.timeValueMinutes(KNN_DEFAULT_REMOTE_BUILD_SERVICE_TIMEOUT_MINUTES), + NodeScope, + Dynamic + ); + + /** + * Setting to control how often the remote build service client polls the build service for the status of the job. + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING = Setting.timeSetting( + KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL, + TimeValue.timeValueSeconds(KNN_DEFAULT_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SECONDS), + NodeScope, + Dynamic + ); + + /** + * Keystore settings for build service HTTP authorization + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING = SecureSetting.secureString( + KNN_REMOTE_BUILD_SERVICE_USERNAME, + null + ); + public static final Setting KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING = SecureSetting.secureString( + KNN_REMOTE_BUILD_SERVICE_PASSWORD, + null + ); + /** * Dynamic settings */ @@ -550,6 +602,26 @@ private Setting getSetting(String key) { return KNN_REMOTE_VECTOR_REPO_SETTING; } + if (KNN_REMOTE_BUILD_SERVICE_ENDPOINT.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_TIMEOUT.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_USERNAME.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_PASSWORD.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -577,7 +649,12 @@ public List> getSettings() { KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, KNN_DERIVED_SOURCE_ENABLED_SETTING, KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, - KNN_REMOTE_VECTOR_REPO_SETTING + KNN_REMOTE_VECTOR_REPO_SETTING, + KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING, + KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING, + KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING, + KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING, + KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java index 4b991fa507..cf997c5ab8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java @@ -51,7 +51,7 @@ public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) { && indexSettings != null && knnEngine.supportsRemoteIndexBuild() && RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings)) { - return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy); + return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy, indexSettings); } else { return strategy; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index 8555e2ad68..6046a09b39 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -8,13 +8,18 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.NotImplementedException; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.cluster.ClusterName; import org.opensearch.common.StopWatch; +import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.remote.RemoteBuildRequest; +import org.opensearch.knn.index.remote.RemoteIndexClient; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; @@ -22,6 +27,8 @@ import org.opensearch.repositories.blobstore.BlobStoreRepository; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; @@ -37,17 +44,25 @@ public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy { private final Supplier repositoriesServiceSupplier; private final NativeIndexBuildStrategy fallbackStrategy; + private final IndexSettings indexSettings; + private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; private static final String DOC_ID_FILE_EXTENSION = ".knndid"; + private static final String VECTORS_PATH = "_vectors"; /** * Public constructor * * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository */ - public RemoteIndexBuildStrategy(Supplier repositoriesServiceSupplier, NativeIndexBuildStrategy fallbackStrategy) { + public RemoteIndexBuildStrategy( + Supplier repositoriesServiceSupplier, + NativeIndexBuildStrategy fallbackStrategy, + IndexSettings indexSettings + ) { this.repositoriesServiceSupplier = repositoriesServiceSupplier; this.fallbackStrategy = fallbackStrategy; + this.indexSettings = indexSettings; } /** @@ -78,6 +93,7 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { long time_in_millis; try { stopWatch = new StopWatch().start(); + String blobName = UUIDs.base64UUID() + "_" + indexInfo.getFieldName() + "_" + indexInfo.getSegmentWriteState().segmentInfo.name; writeToRepository( indexInfo.getFieldName(), indexInfo.getKnnVectorValuesSupplier(), @@ -88,17 +104,18 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - submitVectorBuild(); + RemoteBuildRequest buildRequest = constructBuildRequest(indexInfo, blobName); + String jobId = RemoteIndexClient.getInstance().submitVectorBuild(buildRequest); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - awaitVectorBuild(); + String indexPath = awaitVectorBuild(jobId); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - readFromRepository(); + readFromRepository(indexPath); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); } catch (Exception e) { @@ -145,6 +162,47 @@ private void writeToRepository( throw new NotImplementedException(); } + /** + * Construct the RemoteBuildRequest object for the index build request + * @return RemoteBuildRequest with parameters set + */ + protected RemoteBuildRequest constructBuildRequest(BuildIndexParams indexInfo, String blobName) throws IOException { + String repositoryType = getRepository().getMetadata().type(); + if (!repositoryType.equals("s3")) { + throw new IllegalArgumentException("Repository type " + repositoryType + " is not supported by the remote build service"); + } + String containerName = getRepository().getMetadata().settings().get("bucket"); + String vectorPath = blobName + VECTOR_BLOB_FILE_EXTENSION; + String docIdPath = blobName + DOC_ID_FILE_EXTENSION; + String tenantId = indexSettings.getSettings().get(ClusterName.CLUSTER_NAME_SETTING.getKey()); + int docCount = indexInfo.getTotalLiveDocs(); + String spaceType = indexInfo.getParameters().get(KNNConstants.SPACE_TYPE).toString(); + String engine = indexInfo.getKnnEngine().getName(); + + String dataType = indexInfo.getVectorDataType().getValue(); // TODO need to fetch encoder param to get fp16 vs fp32 + int dimension = 0; // TODO + Map algorithmParams = new HashMap<>(); // TODO fetch the below from index mapping + algorithmParams.put("ef_construction", 100); + algorithmParams.put("m", 16); + + Map indexParameters = new HashMap<>(); + indexParameters.put("algorithm", "hnsw"); + indexParameters.put("algorithm_parameters", algorithmParams); + + return RemoteBuildRequest.builder() + .repositoryType(repositoryType) + .containerName(containerName) + .vectorPath(vectorPath) + .docIdPath(docIdPath) + .tenantId(tenantId) + .dimension(dimension) + .docCount(docCount) + .dataType(dataType) + .engine(engine) + .indexParameters(indexParameters) + .build(); + } + /** * Submit vector build request to remote vector build service * @@ -156,14 +214,14 @@ private void submitVectorBuild() { /** * Wait on remote vector build to complete */ - private void awaitVectorBuild() { + private String awaitVectorBuild(String jobId) { throw new NotImplementedException(); } /** * Read constructed vector file from remote repository and write to IndexOutput */ - private void readFromRepository() { + private void readFromRepository(String indexPath) { throw new NotImplementedException(); } } diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java new file mode 100644 index 0000000000..9cd19173ff --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import org.opensearch.common.xcontent.json.JsonXContent; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +@Builder +@Getter +public class RemoteBuildRequest { + private final String repositoryType; + private final String containerName; + private final String vectorPath; + private final String docIdPath; + private final String tenantId; + private final int dimension; + private final int docCount; + private final String dataType; + private final String engine; + @Builder.Default + private final Map indexParameters = new HashMap<>(); + + public String toJson() throws IOException { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + builder.field("repository_type", repositoryType); + builder.field("container_name", containerName); + builder.field("vector_path", vectorPath); + builder.field("doc_id_path", docIdPath); + builder.field("tenant_id", tenantId); + builder.field("dimension", dimension); + builder.field("doc_count", docCount); + builder.field("data_type", dataType); + builder.field("engine", engine); + builder.field("index_parameters", indexParameters); + builder.endObject(); + return builder.toString(); + } + } + +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java new file mode 100644 index 0000000000..11b556d042 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java @@ -0,0 +1,158 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.NotImplementedException; +import org.apache.hc.client5.http.classic.methods.HttpGet; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; +import org.apache.hc.client5.http.impl.classic.BasicHttpClientResponseHandler; +import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.client5.http.utils.Base64; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.opensearch.core.common.settings.SecureString; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; + +/** + * Class to handle all interactions with the remote vector build service. + * InterruptedExceptions will cause a fallback to local CPU build. + */ +@Log4j2 +public class RemoteIndexClient { + private static RemoteIndexClient INSTANCE; + private volatile CloseableHttpClient httpClient; + protected static final int MAX_RETRIES = 1; // 2 total attempts + protected static final long BASE_DELAY_MS = 1000; + private String BUILD_ENDPOINT = "/_build"; + private String STATUS_ENDPOINT = "/_status"; + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + RemoteIndexClient() { + this.httpClient = createHttpClient(); + } + + /** + * Return the Singleton instance of the node's RemoteIndexClient + * @return RemoteIndexClient instance + */ + public static synchronized RemoteIndexClient getInstance() { + if (INSTANCE == null) { + INSTANCE = new RemoteIndexClient(); + } + return INSTANCE; + } + + /** + * Initialize the httpClient to be used + * @return The HTTP Client + */ + private CloseableHttpClient createHttpClient() { + return HttpClients.custom().setRetryStrategy(new RemoteIndexClientRetryStrategy()).build(); + } + + /** + * Submit a build to the Remote Vector Build Service endpoint. + * @return job_id from the server response used to track the job + */ + public String submitVectorBuild(RemoteBuildRequest request) throws IOException { + URI endpoint = URI.create(KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT)); + HttpPost buildRequest = new HttpPost(endpoint + BUILD_ENDPOINT); + buildRequest.setHeader("Content-Type", "application/json"); + buildRequest.setEntity(new StringEntity(request.toJson())); + authenticateRequest(buildRequest); + + String response = httpClient.execute(buildRequest, body -> { + if (body.getCode() != 200) { + throw new IOException("Failed to submit build request, got status code: " + body.getCode()); + } + return EntityUtils.toString(body.getEntity()); + }); + + if (response == null) { + throw new IOException("Received 200 status code but response is null."); + } + + return getValueFromResponse(response, "job_id"); + } + + /** + * Await the completion of the index build by polling periodically and handling the returned statuses. + * @param jobId identifier from the server to track the job + * @return the path to the completed index + */ + public String awaitVectorBuild(String jobId) { + throw new NotImplementedException(); + } + + /** + * Helper method to directly get the status response for a given job ID + * @param jobId to check + * @return The entire response for the status request + */ + private String getBuildStatus(String jobId) throws IOException { + URI endpoint = URI.create(KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT)); + HttpGet request = new HttpGet(endpoint + STATUS_ENDPOINT + "/" + jobId); + authenticateRequest(request); + return httpClient.execute(request, new BasicHttpClientResponseHandler()); + } + + /** + * Given a JSON response string, get a value for a specific key. Converts json {@literal } to Java null. + * @param responseBody The response to read + * @param key The key to lookup + * @return The value for the key, or null if not found + */ + protected static String getValueFromResponse(String responseBody, String key) throws JsonProcessingException { + // TODO See if I can use OpenSearch XContent tools here to avoid Jackson dependency + ObjectNode jsonResponse = (ObjectNode) objectMapper.readTree(responseBody); + if (jsonResponse.has(key)) { + if (jsonResponse.get(key).isNull()) { + return null; + } + return jsonResponse.get(key).asText(); + } + throw new IllegalArgumentException("Key " + key + " not found in response"); + } + + /** + * Authenticate the HTTP request by manually setting the auth header. + * This is favored over setting a global auth scheme to allow for dynamic credential updates. + * @param request to be authenticated + */ + private void authenticateRequest(HttpUriRequestBase request) { + // TODO test secure setting retrieval/usage + SecureString username = KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_USERNAME); + SecureString password = KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_PASSWORD); + + if (password != null) { + final String auth = username + ":" + password.clone(); + final byte[] encodedAuth = Base64.encodeBase64(auth.getBytes(StandardCharsets.ISO_8859_1)); + final String authHeader = "Basic " + new String(encodedAuth); + request.setHeader(HttpHeaders.AUTHORIZATION, authHeader); + } + } + + /** + * Close the httpClient + */ + public void close() throws IOException { + if (httpClient != null) { + httpClient.close(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java new file mode 100644 index 0000000000..b3f36624b9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy; +import org.apache.hc.core5.http.ConnectionClosedException; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.util.TimeValue; + +import javax.net.ssl.SSLException; +import java.io.InterruptedIOException; +import java.net.ConnectException; +import java.net.NoRouteToHostException; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.List; + +/** + * The public constructors for the Apache HTTP client default retry strategies allow customization of max retries + * and retry interval, but not retryable status codes. + * In order to add the other retryable status codes from our Remote Build API Contract, we must extend this class. + * @see org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy + */ +public class RemoteIndexClientRetryStrategy extends DefaultHttpRequestRetryStrategy { + private static final List retryableCodes = Arrays.asList(408, 429, 500, 502, 503, 504, 509); + private static final List backoffCodes = Arrays.asList(429, 503); + + public RemoteIndexClientRetryStrategy() { + super( + RemoteIndexClient.MAX_RETRIES, + TimeValue.ofMilliseconds(RemoteIndexClient.BASE_DELAY_MS), + Arrays.asList( + InterruptedIOException.class, + UnknownHostException.class, + ConnectException.class, + ConnectionClosedException.class, + NoRouteToHostException.class, + SSLException.class + ), + retryableCodes + ); + } + + /** + * Override retry interval setting to implement backoff strategy for throttling codes. + * These codes may be returned with their own 'Retry-After' header which will take precedent over the below. + * This is only relevant for future implementations where we may increase the retry count from 1 max retry. + */ + @Override + public TimeValue getRetryInterval(HttpResponse response, int execCount, HttpContext context) { + if (backoffCodes.contains(response.getCode())) { + long delay = RemoteIndexClient.BASE_DELAY_MS; + long backoffDelay = delay * (long) Math.pow(2, execCount - 1); + return TimeValue.ofMilliseconds(Math.min(backoffDelay, TimeValue.ofMinutes(1).toMilliseconds())); + } + return super.getRetryInterval(response, execCount, context); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java index 1589021f67..359f04dac8 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java @@ -5,17 +5,27 @@ package org.opensearch.knn.index.codec.nativeindex.remote; +import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.SetOnce; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.remote.RemoteBuildRequest; import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.RepositoryMissingException; +import org.opensearch.repositories.blobstore.BlobStoreRepository; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -25,6 +35,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; public class RemoteIndexBuildStrategyTests extends OpenSearchTestCase { @@ -48,7 +59,11 @@ public void testFallback() throws IOException { RepositoriesService repositoriesService = mock(RepositoriesService.class); when(repositoriesService.repository(any())).thenThrow(new RepositoryMissingException("Fallback")); - RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(() -> repositoriesService, new TestIndexBuildStrategy()); + RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy( + () -> repositoriesService, + new TestIndexBuildStrategy(), + null + ); IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); @@ -64,4 +79,62 @@ public void testFallback() throws IOException { objectUnderTest.buildAndWriteIndex(buildIndexParams); assertEquals(1, fallbackCounter); } + + public void testBuildRequest() throws IOException { + RepositoriesService repositoriesService = mock(RepositoriesService.class); + BlobStoreRepository blobStoreRepository = mock(BlobStoreRepository.class); + RepositoryMetadata metadata = mock(RepositoryMetadata.class); + Settings repoSettings = Settings.builder().put("bucket", "test-bucket").build(); + + when(metadata.type()).thenReturn("s3"); + when(metadata.settings()).thenReturn(repoSettings); + when(blobStoreRepository.getMetadata()).thenReturn(metadata); + when(repositoriesService.repository("test-repo")).thenReturn(blobStoreRepository); + + KNNSettings knnSettingsMock = mock(KNNSettings.class); + when(knnSettingsMock.getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey())).thenReturn("test-repo"); + + IndexSettings mockIndexSettings = mock(IndexSettings.class); + Settings indexSettingsSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "test-cluster").build(); + when(mockIndexSettings.getSettings()).thenReturn(indexSettingsSettings); + + try (MockedStatic knnSettingsStaticMock = Mockito.mockStatic(KNNSettings.class)) { + knnSettingsStaticMock.when(KNNSettings::state).thenReturn(knnSettingsMock); + + final SetOnce fallback = new SetOnce<>(); + RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy( + () -> repositoriesService, + new TestIndexBuildStrategy(), + mockIndexSettings + ); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of(KNNConstants.SPACE_TYPE, "l2")) + .knnVectorValuesSupplier(() -> knnVectorValues) + .totalLiveDocs(vectorValues.size()) + .build(); + + RemoteBuildRequest request = objectUnderTest.constructBuildRequest(buildIndexParams, "blob"); + + assertEquals("s3", request.getRepositoryType()); + assertEquals("test-bucket", request.getContainerName()); + assertEquals("faiss", request.getEngine()); + assertEquals("float", request.getDataType()); // TODO this will be in {fp16, fp32, byte, binary} + assertEquals("blob.knnvec", request.getVectorPath()); + assertEquals("blob.knndid", request.getDocIdPath()); + assertEquals("test-cluster", request.getTenantId()); + assertEquals(vectorValues.size(), request.getDocCount()); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java new file mode 100644 index 0000000000..5f719ecc25 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RemoteIndexClientTests extends KNNTestCase { + + @Mock + protected ClusterService clusterService; + + protected AutoCloseable openMocks; + + private ObjectMapper mapper; + + @Before + public void setup() { + this.mapper = new ObjectMapper(); + openMocks = MockitoAnnotations.openMocks(this); + clusterService = mock(ClusterService.class); + Set> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + KNNSettings.state().setClusterService(clusterService); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); + } + + public void testGetHttpClient_success() throws IOException { + RemoteIndexClient client = RemoteIndexClient.getInstance(); + assertNotNull(client); + client.close(); + } + + public void testConstructBuildRequest() throws IOException { + Map algorithmParams = new HashMap<>(); + algorithmParams.put("ef_construction", 100); + algorithmParams.put("m", 16); + + Map indexParameters = new HashMap<>(); + indexParameters.put("algorithm", "hnsw"); + indexParameters.put("space_type", "l2"); + indexParameters.put("algorithm_parameters", algorithmParams); + + RemoteBuildRequest request = RemoteBuildRequest.builder() + .repositoryType("S3") + .containerName("MyVectorStore") + .vectorPath("MyVectorPath") + .docIdPath("MyDocIdPath") + .tenantId("MyTenant") + .dimension(256) + .docCount(1_000_000) + .dataType("fp32") + .engine("faiss") + .indexParameters(indexParameters) + .build(); + + String expectedJson = "{" + + "\"repository_type\":\"S3\"," + + "\"container_name\":\"MyVectorStore\"," + + "\"vector_path\":\"MyVectorPath\"," + + "\"doc_id_path\":\"MyDocIdPath\"," + + "\"tenant_id\":\"MyTenant\"," + + "\"dimension\":256," + + "\"doc_count\":1000000," + + "\"data_type\":\"fp32\"," + + "\"engine\":\"faiss\"," + + "\"index_parameters\":{" + + "\"space_type\":\"l2\"," + + "\"algorithm\":\"hnsw\"," + + "\"algorithm_parameters\":{" + + "\"ef_construction\":100," + + "\"m\":16" + + "}" + + "}" + + "}"; + assertEquals(mapper.readTree(expectedJson), mapper.readTree(request.toJson())); + } + + public void testGetValueFromResponse() throws JsonProcessingException { + String jobID = "{\"job_id\": \"job-1739930402\"}"; + assertEquals("job-1739930402", RemoteIndexClient.getValueFromResponse(jobID, "job_id")); + String failedIndexBuild = "{" + + "\"task_status\":\"FAILED_INDEX_BUILD\"," + + "\"error\":\"Index build process interrupted.\"," + + "\"index_path\": null" + + "}"; + String error = RemoteIndexClient.getValueFromResponse(failedIndexBuild, "error"); + assertEquals("Index build process interrupted.", error); + assertNull(RemoteIndexClient.getValueFromResponse(failedIndexBuild, "index_path")); + } +}