From 23186115d911de63037a961ee7dc6c5814b3e6f0 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Thu, 27 Feb 2025 08:24:25 -0800 Subject: [PATCH] Remove usage of cluster level setting for circuit breaker Simplification of circuit breaker management. Before, we were periodically updating the circuit breaker cluster setting to have a global circuit breaker. However, this is inconvenient and not actually useful. This change removes that logic and only trips based on node level circuit breaker. For knn stats, in order to fetch the cluster level circuit breaker, a transport call is made to check all of the nodes. This isnt super efficient but its made for stats calls so its not on the critical path. Signed-off-by: John Mazanec --- .../java/org/opensearch/knn/bwc/StatsIT.java | 2 +- .../knn/index/KNNCircuitBreaker.java | 85 ++++++------------- .../org/opensearch/knn/index/KNNSettings.java | 49 +---------- .../mapper/KNNVectorFieldMapperUtil.java | 5 +- .../memory/NativeMemoryCacheManager.java | 22 +---- .../org/opensearch/knn/plugin/KNNPlugin.java | 11 ++- .../opensearch/knn/plugin/stats/KNNStats.java | 15 ++-- ...KNNClusterLevelCircuitBreakerSupplier.java | 34 ++++++++ ...> KNNNodeLevelCircuitBreakerSupplier.java} | 8 +- .../KNNCircuitBreakerTrippedAction.java | 37 ++++++++ .../KNNCircuitBreakerTrippedNodeRequest.java | 46 ++++++++++ .../KNNCircuitBreakerTrippedNodeResponse.java | 56 ++++++++++++ .../KNNCircuitBreakerTrippedRequest.java | 48 +++++++++++ .../KNNCircuitBreakerTrippedResponse.java | 84 ++++++++++++++++++ ...NCircuitBreakerTrippedTransportAction.java | 82 ++++++++++++++++++ .../transport/KNNStatsTransportAction.java | 2 +- .../memory/NativeMemoryCacheManagerTests.java | 12 --- .../plugin/action/RestKNNStatsHandlerIT.java | 5 +- .../action/RestLegacyKNNStatsHandlerIT.java | 5 +- .../QuantizationStateCacheTests.java | 10 +-- 20 files changed, 445 insertions(+), 173 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNClusterLevelCircuitBreakerSupplier.java rename src/main/java/org/opensearch/knn/plugin/stats/suppliers/{KNNCircuitBreakerSupplier.java => KNNNodeLevelCircuitBreakerSupplier.java} (55%) create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java index 50beb60d5f..c41b2a6e16 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java @@ -21,7 +21,7 @@ public class StatsIT extends AbstractRollingUpgradeTestCase { @Before public void setUp() throws Exception { super.setUp(); - this.knnStats = new KNNStats(); + this.knnStats = new KNNStats(null); } // Validate if all the KNN Stats metrics from old version are present in new version diff --git a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java index 4829777be0..56c4f96ca4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java +++ b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java @@ -5,34 +5,22 @@ package org.opensearch.knn.index; +import lombok.Getter; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.transport.client.Client; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; -import java.util.ArrayList; -import java.util.List; - /** * Runs the circuit breaker logic and updates the settings */ +@Getter public class KNNCircuitBreaker { - private static Logger logger = LogManager.getLogger(KNNCircuitBreaker.class); public static final String KNN_CIRCUIT_BREAKER_TIER = "knn_cb_tier"; public static int CB_TIME_INTERVAL = 2 * 60; // seconds private static KNNCircuitBreaker INSTANCE; - private ThreadPool threadPool; - private ClusterService clusterService; - private Client client; + + private boolean isTripped = false; private KNNCircuitBreaker() {} @@ -52,60 +40,35 @@ public static synchronized void setInstance(KNNCircuitBreaker instance) { INSTANCE = instance; } - public void initialize(ThreadPool threadPool, ClusterService clusterService, Client client) { - this.threadPool = threadPool; - this.clusterService = clusterService; - this.client = client; + /** + * Initialize the circuit breaker + * + * @param threadPool ThreadPool instance + */ + public void initialize(ThreadPool threadPool) { NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); Runnable runnable = () -> { - if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { + if (isTripped) { long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); long circuitBreakerLimitSizeKiloBytes = KNNSettings.state().getCircuitBreakerLimit().getKb(); long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage() / 100) * circuitBreakerLimitSizeKiloBytes); - /** - * Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes - */ - if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { - nativeMemoryCacheManager.setCacheCapacityReached(false); - } - } - - // Leader node untriggers CB if all nodes have not reached their max capacity - if (KNNSettings.isCircuitBreakerTriggered() && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { - KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); - knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); - knnStatsRequest.timeout(new TimeValue(1000 * 10)); // 10 second timeout - try { - KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); - List nodeResponses = knnStatsResponse.getNodes(); - - List nodesAtMaxCapacity = new ArrayList<>(); - for (KNNStatsNodeResponse nodeResponse : nodeResponses) { - if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { - nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); - } - } - - if (!nodesAtMaxCapacity.isEmpty()) { - logger.info( - "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " - + String.join(",", nodesAtMaxCapacity) - + "." - ); - } else { - logger.info( - "[KNN] Cache capacity below 75% of the circuit breaker limit for all nodes." - + " Unsetting knn.circuit_breaker.triggered flag." - ); - KNNSettings.state().updateCircuitBreakerSettings(false); - } - } catch (Exception e) { - logger.error("[KNN] Exception getting stats: " + e); + // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes + if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { + setTripped(false); } } }; - this.threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC); + threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC); + } + + /** + * Set the circuit breaker flag + * + * @param isTripped true if circuit breaker is tripped, false otherwise + */ + public synchronized void setTripped(boolean isTripped) { + this.isTripped = isTripped; } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c5d9a05748..f461826e53 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -10,15 +10,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchParseException; -import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; -import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.IndexModule; @@ -28,7 +25,6 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; -import org.opensearch.transport.client.Client; import java.security.InvalidParameterException; import java.util.Arrays; @@ -317,7 +313,8 @@ public class KNNSettings { KNN_CIRCUIT_BREAKER_TRIGGERED, false, NodeScope, - Dynamic + Dynamic, + Setting.Property.Deprecated ); public static final Setting KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING = Setting.doubleSetting( @@ -473,8 +470,8 @@ public class KNNSettings { private final static Map> FEATURE_FLAGS = getFeatureFlags().stream() .collect(toUnmodifiableMap(Setting::getKey, Function.identity())); + @Setter private ClusterService clusterService; - private Client client; @Setter private Optional nodeCbAttribute; @@ -638,10 +635,6 @@ public static boolean isKNNPluginEnabled() { return KNNSettings.state().getSettingValue(KNNSettings.KNN_PLUGIN_ENABLED); } - public static boolean isCircuitBreakerTriggered() { - return KNNSettings.state().getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); - } - /** * Retrieves the node-specific circuit breaker limit based on the existing settings. * @@ -806,8 +799,7 @@ public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String ind .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false); } - public void initialize(Client client, ClusterService clusterService) { - this.client = client; + public void initialize(ClusterService clusterService) { this.clusterService = clusterService; this.nodeCbAttribute = Optional.empty(); setSettingsUpdateConsumers(); @@ -841,35 +833,6 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Byt } } - /** - * Updates knn.circuit_breaker.triggered setting to true/false - * @param flag true/false - */ - public synchronized void updateCircuitBreakerSettings(boolean flag) { - ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder().put(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, flag).build(); - clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); - client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener() { - @Override - public void onResponse(ClusterUpdateSettingsResponse clusterUpdateSettingsResponse) { - logger.debug( - "Cluster setting {}, acknowledged: {} ", - clusterUpdateSettingsRequest.persistentSettings(), - clusterUpdateSettingsResponse.isAcknowledged() - ); - } - - @Override - public void onFailure(Exception e) { - logger.info( - "Exception while updating circuit breaker setting {} to {}", - clusterUpdateSettingsRequest.persistentSettings(), - e.getMessage() - ); - } - }); - } - public static ByteSizeValue getVectorStreamingMemoryLimit() { return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB); } @@ -888,10 +851,6 @@ public static int getEfSearchParam(String index) { ); } - public void setClusterService(ClusterService clusterService) { - this.clusterService = clusterService; - } - static class SpaceTypeValidator implements Setting.Validator { @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index d2f9c21b7e..fa2a361735 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.settings.Settings; +import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.SpaceType; @@ -114,9 +115,9 @@ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFiel * Validate if the circuit breaker is triggered */ static void validateIfCircuitBreakerIsNotTriggered() { - if (KNNSettings.isCircuitBreakerTriggered()) { + if (KNNCircuitBreaker.getInstance().isTripped()) { throw new KnnCircuitBreakerException( - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." + "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the node is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." ); } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index cc7edbf276..37c039debf 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -24,6 +24,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.plugin.stats.StatNames; import org.opensearch.threadpool.ThreadPool; @@ -413,24 +414,6 @@ public void invalidateAll() { cache.invalidateAll(); } - /** - * Returns whether or not the capacity of the cache has been reached - * - * @return Boolean of whether cache limit has been reached - */ - public Boolean isCacheCapacityReached() { - return cacheCapacityReached.get(); - } - - /** - * Sets cache capacity reached - * - * @param value Boolean value to set cache Capacity Reached to - */ - public void setCacheCapacityReached(Boolean value) { - cacheCapacityReached.set(value); - } - /** * Get the stats of all of the OpenSearch indices currently loaded into the cache * @@ -461,8 +444,7 @@ private void onRemoval(RemovalNotification remov nativeMemoryAllocation.close(); if (RemovalCause.SIZE == removalNotification.getCause()) { - KNNSettings.state().updateCircuitBreakerSettings(true); - setCacheCapacityReached(true); + KNNCircuitBreaker.getInstance().setTripped(true); } logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause()); diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index e8aa580f47..c5a1edc331 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -60,6 +60,8 @@ import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; import org.opensearch.knn.plugin.transport.GetModelTransportAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedTransportAction; import org.opensearch.knn.plugin.transport.KNNStatsAction; import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; @@ -202,7 +204,7 @@ public Collection createComponents( VectorReader vectorReader = new VectorReader(client); NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); KNNClusterUtil.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); @@ -210,14 +212,14 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); QuantizationStateCache.setThreadPool(threadPool); NativeMemoryCacheManager.setThreadPool(threadPool); - KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); + KNNCircuitBreaker.getInstance().initialize(threadPool); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); clusterService.addListener(TrainingJobClusterStateListener.getInstance()); - knnStats = new KNNStats(); + knnStats = new KNNStats(client); return ImmutableList.of(knnStats); } @@ -277,7 +279,8 @@ public List getRestHandlers( new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class), - new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class) + new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class), + new ActionHandler<>(KNNCircuitBreakerTrippedAction.INSTANCE, KNNCircuitBreakerTrippedTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index bcd419ea68..8f8a4a39eb 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -13,13 +13,15 @@ import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.plugin.stats.suppliers.EventOccurredWithinThresholdSupplier; -import org.opensearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier; +import org.opensearch.knn.plugin.stats.suppliers.KNNClusterLevelCircuitBreakerSupplier; +import org.opensearch.knn.plugin.stats.suppliers.KNNNodeLevelCircuitBreakerSupplier; import org.opensearch.knn.plugin.stats.suppliers.KNNCounterSupplier; import org.opensearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier; import org.opensearch.knn.plugin.stats.suppliers.LibraryInitializedSupplier; import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier; import org.opensearch.knn.plugin.stats.suppliers.ModelIndexingDegradingSupplier; import org.opensearch.knn.plugin.stats.suppliers.NativeMemoryCacheManagerSupplier; +import org.opensearch.transport.client.Client; import java.time.temporal.ChronoUnit; import java.util.HashMap; @@ -31,12 +33,14 @@ */ public class KNNStats { + private final Client client; private final Map> knnStats; /** * Constructor */ - public KNNStats() { + public KNNStats(Client client) { + this.client = client; this.knnStats = buildStatsMap(); } @@ -140,15 +144,12 @@ private void addNativeMemoryStats(ImmutableMap.Builder> build StatNames.INDICES_IN_CACHE.getName(), new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesCacheStats)) ) - .put( - StatNames.CACHE_CAPACITY_REACHED.getName(), - new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::isCacheCapacityReached)) - ) + .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, new KNNNodeLevelCircuitBreakerSupplier())) .put(StatNames.GRAPH_QUERY_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_ERRORS))) .put(StatNames.GRAPH_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_REQUESTS))) .put(StatNames.GRAPH_INDEX_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_ERRORS))) .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) - .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier())); + .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNClusterLevelCircuitBreakerSupplier(client))); } private void addEngineStats(ImmutableMap.Builder> builder) { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNClusterLevelCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNClusterLevelCircuitBreakerSupplier.java new file mode 100644 index 0000000000..2767aa3a82 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNClusterLevelCircuitBreakerSupplier.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.stats.suppliers; + +import lombok.AllArgsConstructor; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedAction; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedRequest; +import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedResponse; +import org.opensearch.transport.client.Client; + +import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; + +@AllArgsConstructor +public class KNNClusterLevelCircuitBreakerSupplier implements Supplier { + + private final Client client; + + @Override + public Boolean get() { + try { + KNNCircuitBreakerTrippedResponse knnCircuitBreakerTrippedResponse = client.execute( + KNNCircuitBreakerTrippedAction.INSTANCE, + new KNNCircuitBreakerTrippedRequest() + ).get(); + return knnCircuitBreakerTrippedResponse.isTripped(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java similarity index 55% rename from src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java rename to src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java index 32b78e7cc3..60c217d502 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNNodeLevelCircuitBreakerSupplier.java @@ -5,22 +5,22 @@ package org.opensearch.knn.plugin.stats.suppliers; -import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.KNNCircuitBreaker; import java.util.function.Supplier; /** * Supplier for circuit breaker stats */ -public class KNNCircuitBreakerSupplier implements Supplier { +public class KNNNodeLevelCircuitBreakerSupplier implements Supplier { /** * Constructor */ - public KNNCircuitBreakerSupplier() {} + public KNNNodeLevelCircuitBreakerSupplier() {} @Override public Boolean get() { - return KNNSettings.isCircuitBreakerTriggered(); + return KNNCircuitBreaker.getInstance().isTripped(); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java new file mode 100644 index 0000000000..45315f9e97 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedAction.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Action used to detect if the KNNCircuit has been tripped cluster wide + */ +public class KNNCircuitBreakerTrippedAction extends ActionType { + + public static final String NAME = "cluster:admin/knn_circuit_breaker_tripped_action"; + public static final KNNCircuitBreakerTrippedAction INSTANCE = new KNNCircuitBreakerTrippedAction( + NAME, + KNNCircuitBreakerTrippedResponse::new + ); + + /** + * Constructor + * + * @param name name of action + * @param responseReader reader for the KNNCircuitBreakerTrippedResponse + */ + public KNNCircuitBreakerTrippedAction(String name, Writeable.Reader responseReader) { + super(name, responseReader); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java new file mode 100644 index 0000000000..c43810aff3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeRequest.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * Node request to detect if the circuit breaker has been tripped + */ +public class KNNCircuitBreakerTrippedNodeRequest extends TransportRequest { + + /** + * Constructor. + */ + public KNNCircuitBreakerTrippedNodeRequest() { + super(); + } + + /** + * Constructor from stream + * + * @param in input stream + * @throws IOException thrown when reading from stream fails + */ + public KNNCircuitBreakerTrippedNodeRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java new file mode 100644 index 0000000000..b6deac2b36 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedNodeResponse.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Node response for if KNNCircuitBreaker is tripped or not + */ +@Getter +public class KNNCircuitBreakerTrippedNodeResponse extends BaseNodeResponse { + + private final boolean isTripped; + + /** + * Constructor from Stream. + * + * @param in stream input + * @throws IOException thrown when unable to read from stream + */ + public KNNCircuitBreakerTrippedNodeResponse(StreamInput in) throws IOException { + super(in); + isTripped = in.readBoolean(); + } + + /** + * Constructor + * + * @param node node + */ + public KNNCircuitBreakerTrippedNodeResponse(DiscoveryNode node, boolean isTripped) { + super(node); + this.isTripped = isTripped; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(isTripped); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java new file mode 100644 index 0000000000..e07a31b523 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedRequest.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Request to check if circuit breaker for cluster has been tripped + */ +public class KNNCircuitBreakerTrippedRequest extends BaseNodesRequest { + + /** + * Constructor. + * + * @param nodeIds Id's of nodes + */ + public KNNCircuitBreakerTrippedRequest(String... nodeIds) { + super(nodeIds); + } + + /** + * Constructor. + * + * @param in input stream + * @throws IOException thrown when reading input stream fails + */ + public KNNCircuitBreakerTrippedRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java new file mode 100644 index 0000000000..ff2626457f --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedResponse.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.Getter; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.List; + +/** + * Response indicating if circuit breaker has been tripped. Circuit breaker is said to be tripped if it is tripped + * on any nodes. + */ +@Getter +public class KNNCircuitBreakerTrippedResponse extends BaseNodesResponse { + + private final boolean isTripped; + + /** + * Constructor. + * + * @param clusterName cluster's name + * @param nodes list of responses from each node + * @param failures list of failures from each node. + */ + public KNNCircuitBreakerTrippedResponse( + ClusterName clusterName, + List nodes, + List failures + ) { + super(clusterName, nodes, failures); + this.isTripped = checkIfTripped(nodes); + } + + /** + * Constructor. + * + * @param in input stream + * @throws IOException thrown when input stream cannot be read + */ + public KNNCircuitBreakerTrippedResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(KNNCircuitBreakerTrippedNodeResponse::new), in.readList(FailedNodeException::new)); + this.isTripped = in.readBoolean(); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(KNNCircuitBreakerTrippedNodeResponse::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(isTripped); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + private boolean checkIfTripped(List nodeResponses) { + for (KNNCircuitBreakerTrippedNodeResponse nodeResponse : nodeResponses) { + if (nodeResponse.isTripped()) { + return true; + } + } + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java new file mode 100644 index 0000000000..48e7982071 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNCircuitBreakerTrippedTransportAction.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * Transport action to check if the KNN CB is tripped on any of the nodes in the cluster. + */ +@Log4j2 +public class KNNCircuitBreakerTrippedTransportAction extends TransportNodesAction< + KNNCircuitBreakerTrippedRequest, + KNNCircuitBreakerTrippedResponse, + KNNCircuitBreakerTrippedNodeRequest, + KNNCircuitBreakerTrippedNodeResponse> { + + @Inject + public KNNCircuitBreakerTrippedTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters + ) { + super( + KNNCircuitBreakerTrippedAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + KNNCircuitBreakerTrippedRequest::new, + KNNCircuitBreakerTrippedNodeRequest::new, + ThreadPool.Names.SAME, + KNNCircuitBreakerTrippedNodeResponse.class + ); + } + + @Override + protected KNNCircuitBreakerTrippedResponse newResponse( + KNNCircuitBreakerTrippedRequest nodesRequest, + List responses, + List failures + ) { + return new KNNCircuitBreakerTrippedResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected KNNCircuitBreakerTrippedNodeRequest newNodeRequest(KNNCircuitBreakerTrippedRequest request) { + return new KNNCircuitBreakerTrippedNodeRequest(); + } + + @Override + protected KNNCircuitBreakerTrippedNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new KNNCircuitBreakerTrippedNodeResponse(in); + } + + @Override + protected KNNCircuitBreakerTrippedNodeResponse nodeOperation(KNNCircuitBreakerTrippedNodeRequest nodeRequest) { + return new KNNCircuitBreakerTrippedNodeResponse(clusterService.localNode(), KNNCircuitBreaker.getInstance().isTripped()); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java index 7edc44894a..f97f21ff1c 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java @@ -31,7 +31,7 @@ public class KNNStatsTransportAction extends TransportNodesAction< KNNStatsNodeRequest, KNNStatsNodeResponse> { - private KNNStats knnStats; + private final KNNStats knnStats; /** * Constructor diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 45ccf0e9d0..a17dcdacd2 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -387,18 +387,6 @@ public void testInvalidateAll() throws ExecutionException { nativeMemoryCacheManager.close(); } - public void testCacheCapacity() { - NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); - assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); - - nativeMemoryCacheManager.setCacheCapacityReached(true); - assertTrue(nativeMemoryCacheManager.isCacheCapacityReached()); - - nativeMemoryCacheManager.setCacheCapacityReached(false); - assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); - nativeMemoryCacheManager.close(); - } - public void testGetIndicesCacheStats() throws IOException, ExecutionException { NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); Map> indicesStats = nativeMemoryCacheManager.getIndicesCacheStats(); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 1324616092..e620df5c8b 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -7,8 +7,6 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.rules.DisableOnDebug; import org.opensearch.client.Request; @@ -57,7 +55,6 @@ */ public class RestKNNStatsHandlerIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(RestKNNStatsHandlerIT.class); private static final String TRAINING_INDEX = "training-index"; private static final String TRAINING_FIELD = "training-field"; private static final String TEST_MODEL_ID = "model-id"; @@ -78,7 +75,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase { @Before public void setup() { - knnStats = new KNNStats(); + knnStats = new KNNStats(null); } /** diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java index e00d90056d..568df26f30 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java @@ -19,8 +19,6 @@ import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.stats.StatNames; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.rules.DisableOnDebug; import org.opensearch.client.Request; @@ -44,7 +42,6 @@ */ public class RestLegacyKNNStatsHandlerIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(RestLegacyKNNStatsHandlerIT.class); private boolean isDebuggingTest = new DisableOnDebug(null).isDebugging(); private boolean isDebuggingRemoteCluster = System.getProperty("cluster.debug", "false").equals("true"); @@ -52,7 +49,7 @@ public class RestLegacyKNNStatsHandlerIT extends KNNRestTestCase { @Before public void setup() { - knnStats = new KNNStats(); + knnStats = new KNNStats(null); } /** diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index 87cb57cdcb..f14486c19c 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -9,7 +9,6 @@ import lombok.SneakyThrows; import org.junit.After; import org.junit.Before; -import org.opensearch.transport.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -352,9 +351,7 @@ public void testRebuildOnCacheSizeSettingsChange() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - Client client = mock(Client.class); - - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); QuantizationStateCache cache = QuantizationStateCache.getInstance(); cache.rebuildCache(); @@ -404,10 +401,7 @@ public void testRebuildOnTimeExpirySettingsChange() { ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - - Client client = mock(Client.class); - - KNNSettings.state().initialize(client, clusterService); + KNNSettings.state().initialize(clusterService); QuantizationStateCache cache = QuantizationStateCache.getInstance(); cache.addQuantizationState(fieldName, state);