diff --git a/services/venice-server/src/main/java/com/linkedin/venice/stats/ServerReadQuotaUsageStats.java b/services/venice-server/src/main/java/com/linkedin/venice/stats/ServerReadQuotaUsageStats.java
index e2857ebdb9..b5e2b0b689 100644
--- a/services/venice-server/src/main/java/com/linkedin/venice/stats/ServerReadQuotaUsageStats.java
+++ b/services/venice-server/src/main/java/com/linkedin/venice/stats/ServerReadQuotaUsageStats.java
@@ -2,14 +2,13 @@
 
 import com.linkedin.venice.utils.SystemTime;
 import com.linkedin.venice.utils.Time;
+import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap;
 import io.tehuti.metrics.MetricConfig;
 import io.tehuti.metrics.MetricsRepository;
 import io.tehuti.metrics.Sensor;
 import io.tehuti.metrics.stats.AsyncGauge;
 import io.tehuti.metrics.stats.Count;
 import io.tehuti.metrics.stats.Rate;
-import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
-import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
@@ -31,7 +30,8 @@ public class ServerReadQuotaUsageStats extends AbstractVeniceStats {
   private final Sensor rejectedKPS; // rejected key per second
   private final Sensor allowedUnintentionallyKPS; // allowed KPS unintentionally due to error or insufficient info
   private final Sensor usageRatioSensor; // requested kps divided by nodes quota responsibility
-  private final Int2ObjectMap<ServerReadQuotaVersionedStats> versionedStats = new Int2ObjectOpenHashMap<>();
+  private final VeniceConcurrentHashMap<Integer, ServerReadQuotaVersionedStats> versionedStats =
+      new VeniceConcurrentHashMap<>();
   private final AtomicInteger currentVersion = new AtomicInteger(0);
   private final AtomicInteger backupVersion = new AtomicInteger(0);
   private final Time time;
@@ -144,10 +144,10 @@ final Double getVersionedRequestedKPS(int version) {
    */
   final Double getReadQuotaUsageRatio() {
     int version = currentVersion.get();
-    if (version < 1 || !versionedStats.containsKey(version)) {
+    ServerReadQuotaVersionedStats stats = versionedStats.get(version);
+    if (version < 1 || stats == null) {
       return Double.NaN;
     }
-    ServerReadQuotaVersionedStats stats = versionedStats.get(version);
     long nodeKpsResponsibility = stats.getNodeKpsResponsibility();
     if (nodeKpsResponsibility < 1) {
       return Double.NaN;
diff --git a/services/venice-server/src/test/java/com/linkedin/venice/stats/ServerReadQuotaUsageStatsTest.java b/services/venice-server/src/test/java/com/linkedin/venice/stats/ServerReadQuotaUsageStatsTest.java
index 1d5a91a526..b2c6c39ca3 100644
--- a/services/venice-server/src/test/java/com/linkedin/venice/stats/ServerReadQuotaUsageStatsTest.java
+++ b/services/venice-server/src/test/java/com/linkedin/venice/stats/ServerReadQuotaUsageStatsTest.java
@@ -1,6 +1,13 @@
 package com.linkedin.venice.stats;
 
+import com.linkedin.venice.utils.Time;
 import io.tehuti.metrics.MetricsRepository;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -50,4 +57,30 @@ public void testGetReadQuotaMetricsWithNoVersionOrRecordings() {
     Assert.assertTrue(stats.getVersionedRequestedQPS(currentVersion) > 0);
     Assert.assertTrue(stats.getVersionedRequestedKPS(currentVersion) > 0);
   }
+
+  /**
+   * A non-thread-safe map like Int2ObjectOpenHashMap could cause the threads to busy loop inside the internal find
+   * method when a race condition happens; the test timeout turns such a hang into a failure.
+   */
+  @Test(timeOut = 10 * Time.MS_PER_SECOND)
+  public void testVersionedStatsThreadSafe() throws ExecutionException, InterruptedException, TimeoutException {
+    MetricsRepository metricsRepository = new MetricsRepository();
+    String storeName = "test-store";
+    ServerReadQuotaUsageStats stats = new ServerReadQuotaUsageStats(metricsRepository, storeName);
+    ExecutorService service = Executors.newFixedThreadPool(100);
+    try {
+      CompletableFuture<?>[] completableFutures = new CompletableFuture<?>[100];
+      for (int j = 0; j < 100; j++) {
+        completableFutures[j] = CompletableFuture.runAsync(() -> {
+          for (int i = 0; i < 100000; i++) {
+            stats.recordAllowed(i, 1);
+          }
+        }, service);
+      }
+      CompletableFuture.allOf(completableFutures).get(10, TimeUnit.SECONDS);
+    } finally {
+      // Always release the 100 pooled threads, including when the wait above times out or fails.
+      service.shutdownNow();
+    }
+  }
 }