diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c6252c46c..8fd361c48e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Refactoring * Small Refactor Post Lucene 10.0.1 upgrade [#2541](https://github.com/opensearch-project/k-NN/pull/2541) * Refactor codec to leverage backwards_codecs [#2546](https://github.com/opensearch-project/k-NN/pull/2546) +* Remove usage of cluster level setting for circuit breaker [#2567](https://github.com/opensearch-project/k-NN/pull/2567) ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x) ### Features diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java index 50beb60d5f..5c7d1573be 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java @@ -7,6 +7,7 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.Before; +import org.opensearch.Version; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.knn.plugin.stats.KNNStats; @@ -21,7 +22,7 @@ public class StatsIT extends AbstractRollingUpgradeTestCase { @Before public void setUp() throws Exception { super.setUp(); - this.knnStats = new KNNStats(); + this.knnStats = new KNNStats(null, () -> Version.CURRENT); } // Validate if all the KNN Stats metrics from old version are present in new version diff --git a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java index 4829777be0..d4dab54df0 100644 --- a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java +++ b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java @@ -5,34 +5,18 @@ package org.opensearch.knn.index; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.transport.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.threadpool.ThreadPool; - -import java.util.ArrayList; -import java.util.List; +import lombok.Getter; /** * Runs the circuit breaker logic and updates the settings */ +@Getter public class KNNCircuitBreaker { - private static Logger logger = LogManager.getLogger(KNNCircuitBreaker.class); public static final String KNN_CIRCUIT_BREAKER_TIER = "knn_cb_tier"; - public static int CB_TIME_INTERVAL = 2 * 60; // seconds private static KNNCircuitBreaker INSTANCE; - private ThreadPool threadPool; - private ClusterService clusterService; - private Client client; + + private boolean isTripped = false; private KNNCircuitBreaker() {} @@ -44,68 +28,11 @@ public static synchronized KNNCircuitBreaker getInstance() { } /** - * SetInstance of Circuit Breaker + * Set the circuit breaker flag * - * @param instance KNNCircuitBreaker instance + * @param isTripped true if circuit breaker is tripped, false otherwise */ - public static synchronized void setInstance(KNNCircuitBreaker instance) { - INSTANCE = instance; - } - - public void initialize(ThreadPool threadPool, ClusterService clusterService, Client client) { - this.threadPool = threadPool; - this.clusterService = clusterService; - this.client = client; - NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); - Runnable runnable = () -> { - if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { - long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); - long circuitBreakerLimitSizeKiloBytes = KNNSettings.state().getCircuitBreakerLimit().getKb(); - long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage() / 100) - * circuitBreakerLimitSizeKiloBytes); - /** - * Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes - */ - if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { - nativeMemoryCacheManager.setCacheCapacityReached(false); - } - } - - // Leader node untriggers CB if all nodes have not reached their max capacity - if (KNNSettings.isCircuitBreakerTriggered() && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { - KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); - knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); - knnStatsRequest.timeout(new TimeValue(1000 * 10)); // 10 second timeout - - try { - KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); - List nodeResponses = knnStatsResponse.getNodes(); - - List nodesAtMaxCapacity = new ArrayList<>(); - for (KNNStatsNodeResponse nodeResponse : nodeResponses) { - if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { - nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); - } - } - - if (!nodesAtMaxCapacity.isEmpty()) { - logger.info( - "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " - + String.join(",", nodesAtMaxCapacity) - + "." - ); - } else { - logger.info( - "[KNN] Cache capacity below 75% of the circuit breaker limit for all nodes." - + " Unsetting knn.circuit_breaker.triggered flag." - ); - KNNSettings.state().updateCircuitBreakerSettings(false); - } - } catch (Exception e) { - logger.error("[KNN] Exception getting stats: " + e); - } - } - }; - this.threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC); + public synchronized void setTripped(boolean isTripped) { + this.isTripped = isTripped; } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c5d9a05748..f461826e53 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -10,15 +10,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchParseException; -import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; -import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.IndexModule; @@ -28,7 +25,6 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; -import org.opensearch.transport.client.Client; import java.security.InvalidParameterException; import java.util.Arrays; @@ -317,7 +313,8 @@ public class KNNSettings { KNN_CIRCUIT_BREAKER_TRIGGERED, false, NodeScope, - Dynamic + Dynamic, + Setting.Property.Deprecated ); public static final Setting KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING = Setting.doubleSetting( @@ -473,8 +470,8 @@ public class KNNSettings { private final static Map> FEATURE_FLAGS = getFeatureFlags().stream() .collect(toUnmodifiableMap(Setting::getKey, Function.identity())); + @Setter private ClusterService clusterService; - private Client client; @Setter private Optional nodeCbAttribute; @@ -638,10 +635,6 @@ public static boolean isKNNPluginEnabled() { return KNNSettings.state().getSettingValue(KNNSettings.KNN_PLUGIN_ENABLED); } - public static boolean isCircuitBreakerTriggered() { - return KNNSettings.state().getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); - } - /** * Retrieves the node-specific circuit breaker limit based on the existing settings. * @@ -806,8 +799,7 @@ public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String ind .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false); } - public void initialize(Client client, ClusterService clusterService) { - this.client = client; + public void initialize(ClusterService clusterService) { this.clusterService = clusterService; this.nodeCbAttribute = Optional.empty(); setSettingsUpdateConsumers(); @@ -841,35 +833,6 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Byt } } - /** - * Updates knn.circuit_breaker.triggered setting to true/false - * @param flag true/false - */ - public synchronized void updateCircuitBreakerSettings(boolean flag) { - ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder().put(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, flag).build(); - clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); - client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener() { - @Override - public void onResponse(ClusterUpdateSettingsResponse clusterUpdateSettingsResponse) { - logger.debug( - "Cluster setting {}, acknowledged: {} ", - clusterUpdateSettingsRequest.persistentSettings(), - clusterUpdateSettingsResponse.isAcknowledged() - ); - } - - @Override - public void onFailure(Exception e) { - logger.info( - "Exception while updating circuit breaker setting {} to {}", - clusterUpdateSettingsRequest.persistentSettings(), - e.getMessage() - ); - } - }); - } - public static ByteSizeValue getVectorStreamingMemoryLimit() { return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB); } @@ -888,10 +851,6 @@ public static int getEfSearchParam(String index) { ); } - public void setClusterService(ClusterService clusterService) { - this.clusterService = clusterService; - } - static class SpaceTypeValidator implements Setting.Validator { @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index d2f9c21b7e..fa2a361735 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.settings.Settings; +import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.SpaceType; @@ -114,9 +115,9 @@ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFiel * Validate if the circuit breaker is triggered */ static void validateIfCircuitBreakerIsNotTriggered() { - if (KNNSettings.isCircuitBreakerTriggered()) { + if (KNNCircuitBreaker.getInstance().isTripped()) { throw new KnnCircuitBreakerException( - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." + "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the node is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." ); } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index cc7edbf276..117acacbb3 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -24,6 +24,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.plugin.stats.StatNames; import org.opensearch.threadpool.ThreadPool; @@ -53,6 +54,7 @@ public class NativeMemoryCacheManager implements Closeable { private static NativeMemoryCacheManager INSTANCE; @Setter private static ThreadPool threadPool; + public static int CB_TIME_INTERVAL = 20; // seconds private Cache cache; private Deque accessRecencyQueue; @@ -121,9 +123,12 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { cacheCapacityReached = new AtomicBoolean(false); accessRecencyQueue = new ConcurrentLinkedDeque<>(); cache = cacheBuilder.build(); + // Set to false when initialized. This will ensure that we dont have to wait for the maintenance job + KNNCircuitBreaker.getInstance().setTripped(false); if (threadPool != null) { startMaintenance(cache); + circuitBreakerUpdater(); } else { logger.warn("ThreadPool is null during NativeMemoryCacheManager initialization. Maintenance will not start."); } @@ -413,24 +418,6 @@ public void invalidateAll() { cache.invalidateAll(); } - /** - * Returns whether or not the capacity of the cache has been reached - * - * @return Boolean of whether cache limit has been reached - */ - public Boolean isCacheCapacityReached() { - return cacheCapacityReached.get(); - } - - /** - * Sets cache capacity reached - * - * @param value Boolean value to set cache Capacity Reached to - */ - public void setCacheCapacityReached(Boolean value) { - cacheCapacityReached.set(value); - } - /** * Get the stats of all of the OpenSearch indices currently loaded into the cache * @@ -461,8 +448,7 @@ private void onRemoval(RemovalNotification remov nativeMemoryAllocation.close(); if (RemovalCause.SIZE == removalNotification.getCause()) { - KNNSettings.state().updateCircuitBreakerSettings(true); - setCacheCapacityReached(true); + KNNCircuitBreaker.getInstance().setTripped(true); } logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause()); @@ -500,4 +486,21 @@ private void startMaintenance(Cache cacheInstanc maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, ThreadPool.Names.MANAGEMENT); } + + private void circuitBreakerUpdater() { + Runnable runnable = () -> { + if (KNNCircuitBreaker.getInstance().isTripped()) { + long currentSizeKiloBytes = getCacheSizeInKilobytes(); + long circuitBreakerLimitSizeKiloBytes = KNNSettings.state().getCircuitBreakerLimit().getKb(); + long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage() / 100) + * circuitBreakerLimitSizeKiloBytes); + + // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes + if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { + KNNCircuitBreaker.getInstance().setTripped(false); + } + } + }; + threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index e8aa580f47..c56a617330 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -27,10 +27,8 @@ import org.opensearch.index.IndexModule; import org.opensearch.index.IndexSettings; import org.opensearch.index.codec.CodecServiceFactory; -import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; import org.opensearch.indices.SystemIndexDescriptor; -import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNNCodecService; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory; @@ -60,6 +58,8 @@ import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; import org.opensearch.knn.plugin.transport.GetModelTransportAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedTransportAction; import org.opensearch.knn.plugin.transport.KNNStatsAction; import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; @@ -164,7 +164,6 @@ public class KNNPlugin extends Plugin public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn"; public static final String KNN_BASE_URI = "/_plugins/_knn"; - private KNNStats knnStats; private ClusterService clusterService; private Supplier repositoriesServiceSupplier; @@ -202,7 +201,7 @@ public Collection createComponents( VectorReader vectorReader = new VectorReader(client); NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); KNNClusterUtil.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); @@ -210,14 +209,13 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); QuantizationStateCache.setThreadPool(threadPool); NativeMemoryCacheManager.setThreadPool(threadPool); - KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); clusterService.addListener(TrainingJobClusterStateListener.getInstance()); - knnStats = new KNNStats(); + KNNStats knnStats = new KNNStats(client, () -> clusterService.getClusterManagerService().getMinNodeVersion()); return ImmutableList.of(knnStats); } @@ -277,15 +275,11 @@ public List getRestHandlers( new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class), - new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class) + new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class), + new ActionHandler<>(KNNCircuitBreakerTrippedAction.INSTANCE, KNNCircuitBreakerTrippedTransportAction.class) ); } - @Override - public Optional getEngineFactory(IndexSettings indexSettings) { - return Optional.empty(); - } - @Override public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/CircuitBreakerStat.java b/src/main/java/org/opensearch/knn/plugin/stats/CircuitBreakerStat.java new file mode 100644 index 0000000000..ff321558c7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/stats/CircuitBreakerStat.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.stats; + +import org.opensearch.Version; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedRequest; +import org.opensearch.transport.client.Client; + +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED; + +public class CircuitBreakerStat extends KNNStat { + + public static final String CONTEXT_CB_TRIPPED = "is_cb_tripped"; + + private static final Function FETCHER = context -> { + if (context == null) { + return false; + } + return (Boolean) context.getContext(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName()).get(CONTEXT_CB_TRIPPED); + }; + + private final Client client; + private final Supplier minVersionSupplier; + + public CircuitBreakerStat(Client client, Supplier minVersionSupplier) { + super(true, FETCHER); + this.client = client; + this.minVersionSupplier = minVersionSupplier; + } + + @Override + public ActionListener setupContext(KNNStatFetchContext knnStatFetchContext, ActionListener actionListener) { + // If there are any nodes in the cluster before 3.0, then we need to fall back to checking the CB + if (minVersionSupplier.get().compareTo(Version.V_3_0_0) < 0) { + return ActionListener.wrap(knnCircuitBreakerTrippedResponse -> { + knnStatFetchContext.addContext( + StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), + Map.of(CONTEXT_CB_TRIPPED, KNNSettings.state().getSettingValue(KNN_CIRCUIT_BREAKER_TRIGGERED)) + ); + actionListener.onResponse(null); + }, actionListener::onFailure); + } + return ActionListener.wrap( + response -> client.execute( + KNNCircuitBreakerTrippedAction.INSTANCE, + new KNNCircuitBreakerTrippedRequest(), + ActionListener.wrap(knnCircuitBreakerTrippedResponse -> { + knnStatFetchContext.addContext( + StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), + Map.of(CONTEXT_CB_TRIPPED, knnCircuitBreakerTrippedResponse.isTripped()) + ); + actionListener.onResponse(null); + }, actionListener::onFailure) + ), + actionListener::onFailure + ); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java index 230b558816..f735122a3d 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java @@ -5,24 +5,33 @@ package org.opensearch.knn.plugin.stats; +import lombok.Getter; +import org.opensearch.core.action.ActionListener; + +import java.util.function.Function; import java.util.function.Supplier; /** * Class represents a stat the plugin keeps track of */ public class KNNStat { - private Boolean clusterLevel; - private Supplier supplier; + @Getter + private Boolean isClusterLevel; + private final Function statFetcher; /** * Constructor * - * @param clusterLevel the scope of the stat + * @param isClusterLevel the scope of the stat * @param supplier supplier that returns the stat's value */ - public KNNStat(Boolean clusterLevel, Supplier supplier) { - this.clusterLevel = clusterLevel; - this.supplier = supplier; + public KNNStat(Boolean isClusterLevel, Supplier supplier) { + this(isClusterLevel, context -> supplier.get()); + } + + public KNNStat(Boolean isClusterLevel, Function statFetcher) { + this.isClusterLevel = isClusterLevel; + this.statFetcher = statFetcher; } /** @@ -31,7 +40,11 @@ public KNNStat(Boolean clusterLevel, Supplier supplier) { * @return boolean that is true if the stat is clusterLevel; false otherwise */ public Boolean isClusterLevel() { - return clusterLevel; + return isClusterLevel; + } + + public ActionListener setupContext(KNNStatFetchContext knnStatFetchContext, ActionListener actionListener) { + return actionListener; } /** @@ -40,6 +53,16 @@ public Boolean isClusterLevel() { * @return value of the stat */ public T getValue() { - return supplier.get(); + return getValue(null); + } + + /** + * Get the value of the statistic + * + * @param statFetchContext context for fetching the stat + * @return value of the stat + */ + public T getValue(KNNStatFetchContext statFetchContext) { + return statFetcher.apply(statFetchContext); } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatFetchContext.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatFetchContext.java new file mode 100644 index 0000000000..6364081250 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatFetchContext.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.stats; + +import java.util.HashMap; +import java.util.Map; + +/** + * Additional context needed for fetching KNN stats. + */ +public class KNNStatFetchContext { + private final Map> contexts; + + public KNNStatFetchContext() { + this.contexts = new HashMap<>(); + } + + public void addContext(String statName, Map context) { + this.contexts.put(statName, context); + } + + public Map getContext(String statName) { + return this.contexts.get(statName); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index bcd419ea68..6e25258247 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -7,19 +7,21 @@ import com.google.common.cache.CacheStats; import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.plugin.stats.suppliers.EventOccurredWithinThresholdSupplier; -import org.opensearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier; +import org.opensearch.knn.plugin.stats.suppliers.KNNNodeLevelCircuitBreakerSupplier; import org.opensearch.knn.plugin.stats.suppliers.KNNCounterSupplier; import org.opensearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier; import org.opensearch.knn.plugin.stats.suppliers.LibraryInitializedSupplier; import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier; import org.opensearch.knn.plugin.stats.suppliers.ModelIndexingDegradingSupplier; import org.opensearch.knn.plugin.stats.suppliers.NativeMemoryCacheManagerSupplier; +import org.opensearch.transport.client.Client; import java.time.temporal.ChronoUnit; import java.util.HashMap; @@ -31,12 +33,16 @@ */ public class KNNStats { + private final Client client; + private final Supplier minVersionSupplier; private final Map> knnStats; /** * Constructor */ - public KNNStats() { + public KNNStats(Client client, Supplier minVersionSupplier) { + this.client = client; + this.minVersionSupplier = minVersionSupplier; this.knnStats = buildStatsMap(); } @@ -140,15 +146,12 @@ private void addNativeMemoryStats(ImmutableMap.Builder> build StatNames.INDICES_IN_CACHE.getName(), new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesCacheStats)) ) - .put( - StatNames.CACHE_CAPACITY_REACHED.getName(), - new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::isCacheCapacityReached)) - ) + .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, new KNNNodeLevelCircuitBreakerSupplier())) .put(StatNames.GRAPH_QUERY_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_ERRORS))) .put(StatNames.GRAPH_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_REQUESTS))) .put(StatNames.GRAPH_INDEX_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_ERRORS))) .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) - .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier())); + .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new CircuitBreakerStat(client, minVersionSupplier)); } private void addEngineStats(ImmutableMap.Builder> builder) { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java similarity index 55% rename from src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java rename to src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java index 32b78e7cc3..60c217d502 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java @@ -5,22 +5,22 @@ package org.opensearch.knn.plugin.stats.suppliers; -import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.KNNCircuitBreaker; import java.util.function.Supplier; /** * Supplier for circuit breaker stats */ -public class KNNCircuitBreakerSupplier implements Supplier { +public class KNNNodeLevelCircuitBreakerSupplier implements Supplier { /** * Constructor */ - public KNNCircuitBreakerSupplier() {} + public KNNNodeLevelCircuitBreakerSupplier() {} @Override public Boolean get() { - return KNNSettings.isCircuitBreakerTriggered(); + return KNNCircuitBreaker.getInstance().isTripped(); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java new file mode 100644 index 0000000000..45315f9e97 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Action used to detect if the KNNCircuit has been tripped cluster wide + */ +public class KNNCircuitBreakerTrippedAction extends ActionType { + + public static final String NAME = "cluster:admin/knn_circuit_breaker_tripped_action"; + public static final KNNCircuitBreakerTrippedAction INSTANCE = new KNNCircuitBreakerTrippedAction( + NAME, + KNNCircuitBreakerTrippedResponse::new + ); + + /** + * Constructor + * + * @param name name of action + * @param responseReader reader for the KNNCircuitBreakerTrippedResponse + */ + public KNNCircuitBreakerTrippedAction(String name, Writeable.Reader responseReader) { + super(name, responseReader); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java new file mode 100644 index 0000000000..c43810aff3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * Node request to detect if the circuit breaker has been tripped + */ +public class KNNCircuitBreakerTrippedNodeRequest extends TransportRequest { + + /** + * Constructor. + */ + public KNNCircuitBreakerTrippedNodeRequest() { + super(); + } + + /** + * Constructor from stream + * + * @param in input stream + * @throws IOException thrown when reading from stream fails + */ + public KNNCircuitBreakerTrippedNodeRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java new file mode 100644 index 0000000000..b6deac2b36 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Node response for if KNNCircuitBreaker is tripped or not + */ +@Getter +public class KNNCircuitBreakerTrippedNodeResponse extends BaseNodeResponse { + + private final boolean isTripped; + + /** + * Constructor from Stream. + * + * @param in stream input + * @throws IOException thrown when unable to read from stream + */ + public KNNCircuitBreakerTrippedNodeResponse(StreamInput in) throws IOException { + super(in); + isTripped = in.readBoolean(); + } + + /** + * Constructor + * + * @param node node + */ + public KNNCircuitBreakerTrippedNodeResponse(DiscoveryNode node, boolean isTripped) { + super(node); + this.isTripped = isTripped; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(isTripped); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java new file mode 100644 index 0000000000..e07a31b523 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Request to check if circuit breaker for cluster has been tripped + */ +public class KNNCircuitBreakerTrippedRequest extends BaseNodesRequest { + + /** + * Constructor. + * + * @param nodeIds Id's of nodes + */ + public KNNCircuitBreakerTrippedRequest(String... nodeIds) { + super(nodeIds); + } + + /** + * Constructor. + * + * @param in input stream + * @throws IOException thrown when reading input stream fails + */ + public KNNCircuitBreakerTrippedRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java new file mode 100644 index 0000000000..ff2626457f --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.Getter; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.List; + +/** + * Response indicating if circuit breaker has been tripped. Circuit breaker is said to be tripped if it is tripped + * on any nodes. + */ +@Getter +public class KNNCircuitBreakerTrippedResponse extends BaseNodesResponse { + + private final boolean isTripped; + + /** + * Constructor. + * + * @param clusterName cluster's name + * @param nodes list of responses from each node + * @param failures list of failures from each node. + */ + public KNNCircuitBreakerTrippedResponse( + ClusterName clusterName, + List nodes, + List failures + ) { + super(clusterName, nodes, failures); + this.isTripped = checkIfTripped(nodes); + } + + /** + * Constructor. + * + * @param in input stream + * @throws IOException thrown when input stream cannot be read + */ + public KNNCircuitBreakerTrippedResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(KNNCircuitBreakerTrippedNodeResponse::new), in.readList(FailedNodeException::new)); + this.isTripped = in.readBoolean(); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(KNNCircuitBreakerTrippedNodeResponse::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(isTripped); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + private boolean checkIfTripped(List nodeResponses) { + for (KNNCircuitBreakerTrippedNodeResponse nodeResponse : nodeResponses) { + if (nodeResponse.isTripped()) { + return true; + } + } + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java new file mode 100644 index 0000000000..48e7982071 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * Transport action to check if the KNN CB is tripped on any of the nodes in the cluster. + */ +@Log4j2 +public class KNNCircuitBreakerTrippedTransportAction extends TransportNodesAction< + KNNCircuitBreakerTrippedRequest, + KNNCircuitBreakerTrippedResponse, + KNNCircuitBreakerTrippedNodeRequest, + KNNCircuitBreakerTrippedNodeResponse> { + + @Inject + public KNNCircuitBreakerTrippedTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters + ) { + super( + KNNCircuitBreakerTrippedAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + KNNCircuitBreakerTrippedRequest::new, + KNNCircuitBreakerTrippedNodeRequest::new, + ThreadPool.Names.SAME, + KNNCircuitBreakerTrippedNodeResponse.class + ); + } + + @Override + protected KNNCircuitBreakerTrippedResponse newResponse( + KNNCircuitBreakerTrippedRequest nodesRequest, + List responses, + List failures + ) { + return new KNNCircuitBreakerTrippedResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected KNNCircuitBreakerTrippedNodeRequest newNodeRequest(KNNCircuitBreakerTrippedRequest request) { + return new KNNCircuitBreakerTrippedNodeRequest(); + } + + @Override + protected KNNCircuitBreakerTrippedNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new KNNCircuitBreakerTrippedNodeResponse(in); + } + + @Override + protected KNNCircuitBreakerTrippedNodeResponse nodeOperation(KNNCircuitBreakerTrippedNodeRequest nodeRequest) { + return new KNNCircuitBreakerTrippedNodeResponse(clusterService.localNode(), KNNCircuitBreaker.getInstance().isTripped()); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java index 7edc44894a..1d0d81137a 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java @@ -5,6 +5,8 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.plugin.stats.KNNStatFetchContext; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.action.FailedNodeException; @@ -13,6 +15,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.threadpool.ThreadPool; @@ -31,7 +34,8 @@ public class KNNStatsTransportAction extends TransportNodesAction< KNNStatsNodeRequest, KNNStatsNodeResponse> { - private KNNStats knnStats; + private final KNNStats knnStats; + private final KNNStatFetchContext knnStatFetchContext; /** * Constructor @@ -62,6 +66,18 @@ public KNNStatsTransportAction( KNNStatsNodeResponse.class ); this.knnStats = knnStats; + this.knnStatFetchContext = new KNNStatFetchContext(); + } + + protected void doExecute(Task task, KNNStatsRequest request, ActionListener listener) { + ActionListener contextListener = ActionListener.wrap(none -> super.doExecute(task, request, listener), listener::onFailure); + Set statsToBeRetrieved = request.getStatsToBeRetrieved(); + for (String statName : knnStats.getClusterStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + contextListener = knnStats.getClusterStats().get(statName).setupContext(knnStatFetchContext, contextListener); + } + } + contextListener.onResponse(null); } @Override @@ -76,7 +92,7 @@ protected KNNStatsResponse newResponse( for (String statName : knnStats.getClusterStats().keySet()) { if (statsToBeRetrieved.contains(statName)) { - clusterStats.put(statName, knnStats.getStats().get(statName).getValue()); + clusterStats.put(statName, knnStats.getStats().get(statName).getValue(knnStatFetchContext)); } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 2ec9ce6b50..cb79cf28aa 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -15,6 +15,7 @@ import org.opensearch.cluster.block.ClusterBlocks; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -91,6 +92,7 @@ public void tearDown() throws Exception { NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close(); + KNNCircuitBreaker.getInstance().setTripped(false); super.tearDown(); } diff --git a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java index b9e0118113..c169a64379 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java @@ -18,7 +18,7 @@ import java.util.List; import java.util.Map; -import static org.opensearch.knn.index.KNNCircuitBreaker.CB_TIME_INTERVAL; +import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.CB_TIME_INTERVAL; /** * Integration tests to test Circuit Breaker functionality @@ -144,13 +144,6 @@ private double getGraphMemoryPercentage() throws Exception { return Double.parseDouble(nodeStatsResponse.getFirst().get(StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName()).toString()); } - public boolean isCbTripped() throws Exception { - Response response = getKnnStats(Collections.emptyList(), Collections.singletonList("circuit_breaker_triggered")); - String responseBody = EntityUtils.toString(response.getEntity()); - Map clusterStats = parseClusterStatsResponse(responseBody); - return Boolean.parseBoolean(clusterStats.get("circuit_breaker_triggered").toString()); - } - public void testCbTripped() throws Exception { setupIndices(); testClusterLevelCircuitBreaker(); @@ -158,12 +151,6 @@ public void testCbTripped() throws Exception { } public void verifyCbUntrips() throws Exception { - - if (!isCbTripped()) { - updateClusterSettings("knn.circuit_breaker.triggered", "true"); - - } - int backOffInterval = 5; // seconds for (int i = 0; i < CB_TIME_INTERVAL; i += backOffInterval) { if (!isCbTripped()) { diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index ca07cd0f34..98a030883b 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -185,70 +185,6 @@ public void testEndToEnd() throws Exception { fail("Graphs are not getting evicted"); } - public void testAddDoc_blockedWhenCbTrips() throws Exception { - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - updateClusterSettings("knn.circuit_breaker.triggered", "true"); - - Float[] vector = { 6.0f, 6.0f }; - ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); - String expMessage = - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status."; - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString(expMessage)); - - // reset - updateClusterSettings("knn.circuit_breaker.triggered", "false"); - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - } - - public void testUpdateDoc_blockedWhenCbTrips() throws Exception { - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - - // update - updateClusterSettings("knn.circuit_breaker.triggered", "true"); - Float[] updatedVector = { 8.0f, 8.0f }; - ResponseException ex = expectThrows(ResponseException.class, () -> updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); - String expMessage = - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status."; - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString(expMessage)); - - // reset - updateClusterSettings("knn.circuit_breaker.triggered", "false"); - updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - } - - public void testAddAndSearchIndex_whenCBTrips() throws Exception { - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - for (int i = 1; i <= 4; i++) { - Float[] vector = { (float) i, (float) (i + 1) }; - addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector); - } - - float[] queryVector = { 1.0f, 1.0f }; // vector to be queried - int k = 10; // nearest 10 neighbor - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - assertEquals(4, results.size()); - - updateClusterSettings("knn.circuit_breaker.triggered", "true"); - // Try add another doc - Float[] vector = { 1.0f, 2.0f }; - ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "5", FIELD_NAME, vector)); - - // Still get 4 docs - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); - results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - assertEquals(4, results.size()); - - updateClusterSettings("knn.circuit_breaker.triggered", "false"); - addKnnDoc(INDEX_NAME, "5", FIELD_NAME, vector); - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); - results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - assertEquals(5, results.size()); - } - public void testIndexingVectorValidation_differentSizes() throws Exception { Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 45ccf0e9d0..c43fc5ca03 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -13,7 +13,6 @@ import com.google.common.cache.CacheStats; import org.junit.Before; -import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.common.settings.Settings; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; @@ -47,11 +46,6 @@ public void setThreadPool() { @Override public void tearDown() throws Exception { - // Clear out persistent metadata - ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build(); - clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); - client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get(); NativeMemoryCacheManager.getInstance().close(); terminate(threadPool); super.tearDown(); @@ -387,18 +381,6 @@ public void testInvalidateAll() throws ExecutionException { nativeMemoryCacheManager.close(); } - public void testCacheCapacity() { - NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); - assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); - - nativeMemoryCacheManager.setCacheCapacityReached(true); - assertTrue(nativeMemoryCacheManager.isCacheCapacityReached()); - - nativeMemoryCacheManager.setCacheCapacityReached(false); - assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); - nativeMemoryCacheManager.close(); - } - public void testGetIndicesCacheStats() throws IOException, ExecutionException { NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); Map> indicesStats = nativeMemoryCacheManager.getIndicesCacheStats(); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 1324616092..9a1be98561 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -7,10 +7,9 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.rules.DisableOnDebug; +import org.opensearch.Version; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -57,7 +56,6 @@ */ public class RestKNNStatsHandlerIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(RestKNNStatsHandlerIT.class); private static final String TRAINING_INDEX = "training-index"; private static final String TRAINING_FIELD = "training-field"; private static final String TEST_MODEL_ID = "model-id"; @@ -78,7 +76,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase { @Before public void setup() { - knnStats = new KNNStats(); + knnStats = new KNNStats(null, () -> Version.CURRENT); } /** diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java index e00d90056d..6689467354 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.action; +import org.opensearch.Version; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -19,8 +20,6 @@ import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.stats.StatNames; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.rules.DisableOnDebug; import org.opensearch.client.Request; @@ -44,7 +43,6 @@ */ public class RestLegacyKNNStatsHandlerIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(RestLegacyKNNStatsHandlerIT.class); private boolean isDebuggingTest = new DisableOnDebug(null).isDebugging(); private boolean isDebuggingRemoteCluster = System.getProperty("cluster.debug", "false").equals("true"); @@ -52,7 +50,7 @@ public class RestLegacyKNNStatsHandlerIT extends KNNRestTestCase { @Before public void setup() { - knnStats = new KNNStats(); + knnStats = new KNNStats(null, () -> Version.CURRENT); } /** diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index 87cb57cdcb..f14486c19c 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -9,7 +9,6 @@ import lombok.SneakyThrows; import org.junit.After; import org.junit.Before; -import org.opensearch.transport.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -352,9 +351,7 @@ public void testRebuildOnCacheSizeSettingsChange() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - Client client = mock(Client.class); - - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); QuantizationStateCache cache = QuantizationStateCache.getInstance(); cache.rebuildCache(); @@ -404,10 +401,7 @@ public void testRebuildOnTimeExpirySettingsChange() { ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - - Client client = mock(Client.class); - - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); QuantizationStateCache cache = QuantizationStateCache.getInstance(); cache.addQuantizationState(fieldName, state); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index e0992f38ee..7cdfe1181e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -2401,4 +2401,15 @@ protected void setupSnapshotRestore(String index, String snapshot, String reposi // create snapshot createSnapshot(repository, snapshot, true); } + + protected boolean isCbTripped() throws Exception { + Response response = getKnnStats(Collections.emptyList(), Collections.singletonList("circuit_breaker_triggered")); + String responseBody = EntityUtils.toString(response.getEntity()); + Map clusterStats = parseClusterStatsResponse(responseBody); + return Boolean.parseBoolean(clusterStats.get("circuit_breaker_triggered").toString()); + } + + protected void tripCB(String index, String snapshot) { + setupSnapshotRestore(index, snapshot, "repo-" + randomLowerCaseString()); + } }