Skip to content

Commit

Permalink
Remove usage of cluster level setting for circuit breaker
Browse files Browse the repository at this point in the history
Simplification of circuit breaker management. Before, we were
periodically updating the circuit breaker cluster setting to have a
global circuit breaker. However, this is inconvenient and not actually
useful.

This change removes that logic and only trips based on node level
circuit breaker. For knn stats, in order to fetch the cluster level
circuit breaker, a transport call is made to check all of the nodes.
This isnt super efficient but its made for stats calls so its not on the
critical path.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Feb 27, 2025
1 parent 5873add commit 2318611
Show file tree
Hide file tree
Showing 20 changed files with 445 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class StatsIT extends AbstractRollingUpgradeTestCase {
@Before
public void setUp() throws Exception {
super.setUp();
this.knnStats = new KNNStats();
this.knnStats = new KNNStats(null);
}

// Validate if all the KNN Stats metrics from old version are present in new version
Expand Down
85 changes: 24 additions & 61 deletions src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,22 @@

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.plugin.stats.StatNames;
import org.opensearch.knn.plugin.transport.KNNStatsAction;
import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse;
import org.opensearch.knn.plugin.transport.KNNStatsRequest;
import org.opensearch.knn.plugin.transport.KNNStatsResponse;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.transport.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.List;

/**
* Runs the circuit breaker logic and updates the settings
*/
@Getter
public class KNNCircuitBreaker {
private static Logger logger = LogManager.getLogger(KNNCircuitBreaker.class);
public static final String KNN_CIRCUIT_BREAKER_TIER = "knn_cb_tier";
public static int CB_TIME_INTERVAL = 2 * 60; // seconds

private static KNNCircuitBreaker INSTANCE;
private ThreadPool threadPool;
private ClusterService clusterService;
private Client client;

private boolean isTripped = false;

private KNNCircuitBreaker() {}

Expand All @@ -52,60 +40,35 @@ public static synchronized void setInstance(KNNCircuitBreaker instance) {
INSTANCE = instance;
}

public void initialize(ThreadPool threadPool, ClusterService clusterService, Client client) {
this.threadPool = threadPool;
this.clusterService = clusterService;
this.client = client;
/**
* Initialize the circuit breaker
*
* @param threadPool ThreadPool instance
*/
public void initialize(ThreadPool threadPool) {
NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
Runnable runnable = () -> {
if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) {
if (isTripped) {
long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes();
long circuitBreakerLimitSizeKiloBytes = KNNSettings.state().getCircuitBreakerLimit().getKb();
long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage() / 100)
* circuitBreakerLimitSizeKiloBytes);
/**
* Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes
*/
if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) {
nativeMemoryCacheManager.setCacheCapacityReached(false);
}
}

// Leader node untriggers CB if all nodes have not reached their max capacity
if (KNNSettings.isCircuitBreakerTriggered() && clusterService.state().nodes().isLocalNodeElectedClusterManager()) {
KNNStatsRequest knnStatsRequest = new KNNStatsRequest();
knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName());
knnStatsRequest.timeout(new TimeValue(1000 * 10)); // 10 second timeout

try {
KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get();
List<KNNStatsNodeResponse> nodeResponses = knnStatsResponse.getNodes();

List<String> nodesAtMaxCapacity = new ArrayList<>();
for (KNNStatsNodeResponse nodeResponse : nodeResponses) {
if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) {
nodesAtMaxCapacity.add(nodeResponse.getNode().getId());
}
}

if (!nodesAtMaxCapacity.isEmpty()) {
logger.info(
"[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: "
+ String.join(",", nodesAtMaxCapacity)
+ "."
);
} else {
logger.info(
"[KNN] Cache capacity below 75% of the circuit breaker limit for all nodes."
+ " Unsetting knn.circuit_breaker.triggered flag."
);
KNNSettings.state().updateCircuitBreakerSettings(false);
}
} catch (Exception e) {
logger.error("[KNN] Exception getting stats: " + e);
// Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes
if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) {
setTripped(false);
}
}
};
this.threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC);
threadPool.scheduleWithFixedDelay(runnable, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC);
}

/**
* Set the circuit breaker flag
*
* @param isTripped true if circuit breaker is tripped, false otherwise
*/
public synchronized void setTripped(boolean isTripped) {
this.isTripped = isTripped;
}
}
49 changes: 4 additions & 45 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Booleans;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.index.IndexModule;
Expand All @@ -28,7 +25,6 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.monitor.jvm.JvmInfo;
import org.opensearch.monitor.os.OsProbe;
import org.opensearch.transport.client.Client;

import java.security.InvalidParameterException;
import java.util.Arrays;
Expand Down Expand Up @@ -317,7 +313,8 @@ public class KNNSettings {
KNN_CIRCUIT_BREAKER_TRIGGERED,
false,
NodeScope,
Dynamic
Dynamic,
Setting.Property.Deprecated
);

public static final Setting<Double> KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING = Setting.doubleSetting(
Expand Down Expand Up @@ -473,8 +470,8 @@ public class KNNSettings {
private final static Map<String, Setting<?>> FEATURE_FLAGS = getFeatureFlags().stream()
.collect(toUnmodifiableMap(Setting::getKey, Function.identity()));

@Setter
private ClusterService clusterService;
private Client client;
@Setter
private Optional<String> nodeCbAttribute;

Expand Down Expand Up @@ -638,10 +635,6 @@ public static boolean isKNNPluginEnabled() {
return KNNSettings.state().getSettingValue(KNNSettings.KNN_PLUGIN_ENABLED);
}

public static boolean isCircuitBreakerTriggered() {
return KNNSettings.state().getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED);
}

/**
* Retrieves the node-specific circuit breaker limit based on the existing settings.
*
Expand Down Expand Up @@ -806,8 +799,7 @@ public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String ind
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false);
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
public void initialize(ClusterService clusterService) {
this.clusterService = clusterService;
this.nodeCbAttribute = Optional.empty();
setSettingsUpdateConsumers();
Expand Down Expand Up @@ -841,35 +833,6 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Byt
}
}

/**
* Updates knn.circuit_breaker.triggered setting to true/false
* @param flag true/false
*/
public synchronized void updateCircuitBreakerSettings(boolean flag) {
ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest();
Settings circuitBreakerSettings = Settings.builder().put(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, flag).build();
clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings);
client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener<ClusterUpdateSettingsResponse>() {
@Override
public void onResponse(ClusterUpdateSettingsResponse clusterUpdateSettingsResponse) {
logger.debug(
"Cluster setting {}, acknowledged: {} ",
clusterUpdateSettingsRequest.persistentSettings(),
clusterUpdateSettingsResponse.isAcknowledged()
);
}

@Override
public void onFailure(Exception e) {
logger.info(
"Exception while updating circuit breaker setting {} to {}",
clusterUpdateSettingsRequest.persistentSettings(),
e.getMessage()
);
}
});
}

public static ByteSizeValue getVectorStreamingMemoryLimit() {
return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB);
}
Expand All @@ -888,10 +851,6 @@ public static int getEfSearchParam(String index) {
);
}

public void setClusterService(ClusterService clusterService) {
this.clusterService = clusterService;
}

static class SpaceTypeValidator implements Setting.Validator<String> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KnnCircuitBreakerException;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -114,9 +115,9 @@ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFiel
* Validate if the circuit breaker is triggered
*/
static void validateIfCircuitBreakerIsNotTriggered() {
if (KNNSettings.isCircuitBreakerTriggered()) {
if (KNNCircuitBreaker.getInstance().isTripped()) {
throw new KnnCircuitBreakerException(
"Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status."
"Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the node is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status."
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.plugin.stats.StatNames;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -413,24 +414,6 @@ public void invalidateAll() {
cache.invalidateAll();
}

/**
* Returns whether or not the capacity of the cache has been reached
*
* @return Boolean of whether cache limit has been reached
*/
public Boolean isCacheCapacityReached() {
return cacheCapacityReached.get();
}

/**
* Sets cache capacity reached
*
* @param value Boolean value to set cache Capacity Reached to
*/
public void setCacheCapacityReached(Boolean value) {
cacheCapacityReached.set(value);
}

/**
* Get the stats of all of the OpenSearch indices currently loaded into the cache
*
Expand Down Expand Up @@ -461,8 +444,7 @@ private void onRemoval(RemovalNotification<String, NativeMemoryAllocation> remov
nativeMemoryAllocation.close();

if (RemovalCause.SIZE == removalNotification.getCause()) {
KNNSettings.state().updateCircuitBreakerSettings(true);
setCacheCapacityReached(true);
KNNCircuitBreaker.getInstance().setTripped(true);
}

logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause());
Expand Down
11 changes: 7 additions & 4 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
import org.opensearch.knn.plugin.transport.GetModelAction;
import org.opensearch.knn.plugin.transport.GetModelTransportAction;
import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedAction;
import org.opensearch.knn.plugin.transport.KNNCircuitBreakerTrippedTransportAction;
import org.opensearch.knn.plugin.transport.KNNStatsAction;
import org.opensearch.knn.plugin.transport.KNNStatsTransportAction;
import org.opensearch.knn.plugin.transport.KNNWarmupAction;
Expand Down Expand Up @@ -202,22 +204,22 @@ public Collection<Object> createComponents(
VectorReader vectorReader = new VectorReader(client);
NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader);

KNNSettings.state().initialize(client, clusterService);
KNNSettings.state().initialize(clusterService);
KNNClusterUtil.instance().initialize(clusterService);
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
QuantizationStateCache.setThreadPool(threadPool);
NativeMemoryCacheManager.setThreadPool(threadPool);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNCircuitBreaker.getInstance().initialize(threadPool);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

clusterService.addListener(TrainingJobClusterStateListener.getInstance());

knnStats = new KNNStats();
knnStats = new KNNStats(client);
return ImmutableList.of(knnStats);
}

Expand Down Expand Up @@ -277,7 +279,8 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class),
new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class),
new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class),
new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class)
new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class),
new ActionHandler<>(KNNCircuitBreakerTrippedAction.INSTANCE, KNNCircuitBreakerTrippedTransportAction.class)
);
}

Expand Down
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.plugin.stats.suppliers.EventOccurredWithinThresholdSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNClusterLevelCircuitBreakerSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNNodeLevelCircuitBreakerSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNCounterSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier;
import org.opensearch.knn.plugin.stats.suppliers.LibraryInitializedSupplier;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexingDegradingSupplier;
import org.opensearch.knn.plugin.stats.suppliers.NativeMemoryCacheManagerSupplier;
import org.opensearch.transport.client.Client;

import java.time.temporal.ChronoUnit;
import java.util.HashMap;
Expand All @@ -31,12 +33,14 @@
*/
public class KNNStats {

private final Client client;
private final Map<String, KNNStat<?>> knnStats;

/**
* Constructor
*/
public KNNStats() {
public KNNStats(Client client) {
this.client = client;
this.knnStats = buildStatsMap();
}

Expand Down Expand Up @@ -140,15 +144,12 @@ private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> build
StatNames.INDICES_IN_CACHE.getName(),
new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesCacheStats))
)
.put(
StatNames.CACHE_CAPACITY_REACHED.getName(),
new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::isCacheCapacityReached))
)
.put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, new KNNNodeLevelCircuitBreakerSupplier()))
.put(StatNames.GRAPH_QUERY_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_ERRORS)))
.put(StatNames.GRAPH_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_REQUESTS)))
.put(StatNames.GRAPH_INDEX_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_ERRORS)))
.put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS)))
.put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier()));
.put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNClusterLevelCircuitBreakerSupplier(client)));
}

private void addEngineStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
Expand Down
Loading

0 comments on commit 2318611

Please sign in to comment.