diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java index 0f168be18d..276f3812ff 100644 --- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java +++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java @@ -65,6 +65,21 @@ public class TaskConfig extends MapConfig { public static final String COMMIT_TIMEOUT_MS = "task.commit.timeout.ms"; static final long DEFAULT_COMMIT_TIMEOUT_MS = Duration.ofMinutes(30).toMillis(); + // Flag to indicate whether to skip commit during failures (exceptions or timeouts) + // The number of allowed successive commit exceptions and timeouts are controlled by the following two configs. + public static final String SKIP_COMMIT_DURING_FAILURES_ENABLED = "task.commit.skip.commit.during.failures.enabled"; + private static final boolean DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED = false; + + // Maximum number of allowed successive commit exceptions. + // If the number of successive commit exceptions exceeds this limit, the task will be shut down. + public static final String SKIP_COMMIT_EXCEPTION_MAX_LIMIT = "task.commit.skip.commit.exception.max.limit"; + private static final int DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT = 5; + + // Maximum number of allowed successive commit timeouts. + // If the number of successive commit timeout exceeds this limit, the task will be shut down. + public static final String SKIP_COMMIT_TIMEOUT_MAX_LIMIT = "task.commit.skip.commit.timeout.max.limit"; + private static final int DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT = 2; + // how long to wait for a clean shutdown public static final String TASK_SHUTDOWN_MS = "task.shutdown.ms"; static final long DEFAULT_TASK_SHUTDOWN_MS = 30000L; @@ -418,4 +433,16 @@ public long getWatermarkIdleTimeoutMs() { public double getWatermarkQuorumSizePercentage() { return getDouble(WATERMARK_QUORUM_SIZE_PERCENTAGE, DEFAULT_WATERMARK_QUORUM_SIZE_PERCENTAGE); } + + public boolean getSkipCommitDuringFailuresEnabled() { + return getBoolean(SKIP_COMMIT_DURING_FAILURES_ENABLED, DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED); + } + + public int getSkipCommitExceptionMaxLimit() { + return getInt(SKIP_COMMIT_EXCEPTION_MAX_LIMIT, DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT); + } + + public int getSkipCommitTimeoutMaxLimit() { + return getInt(SKIP_COMMIT_TIMEOUT_MAX_LIMIT, DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT); + } } diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala index 70d9ca3800..f5d13106f3 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala @@ -38,7 +38,7 @@ import org.apache.samza.util.ScalaJavaUtil.JavaOptionals.toRichOptional import org.apache.samza.util.{Logging, ReflectionUtil, ScalaJavaUtil} import java.util -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import java.util.function.BiConsumer import java.util.function.Function import scala.collection.JavaConversions._ @@ -133,8 +133,13 @@ class TaskInstance( val checkpointWriteVersions = new TaskConfig(config).getCheckpointWriteVersions @volatile var lastCommitStartTimeMs = System.currentTimeMillis() + val commitExceptionCounter = new AtomicInteger(0) + val commitTimeoutCounter = new AtomicInteger(0) val commitMaxDelayMs = taskConfig.getCommitMaxDelayMs val commitTimeoutMs = taskConfig.getCommitTimeoutMs + val skipCommitDuringFailureEnabled = taskConfig.getSkipCommitDuringFailuresEnabled + val skipCommitExceptionMaxLimit = taskConfig.getSkipCommitExceptionMaxLimit + val skipCommitTimeoutMaxLimit = taskConfig.getSkipCommitTimeoutMaxLimit val commitInProgress = new Semaphore(1) val commitException = new AtomicReference[Exception]() @@ -312,10 +317,22 @@ class TaskInstance( val commitStartNs = System.nanoTime() // first check if there were any unrecoverable errors during the async stage of the pending commit - // and if so, shut down the container. + // If there is unrecoverable error, increment the metric and the counter. + // Shutdown the container in the following scenarios: + // 1. skipCommitDuringFailureEnabled is not enabled + // 2. skipCommitDuringFailureEnabled is enabled but the number of exceptions exceeded the max count + // Otherwise, ignore the exception. if (commitException.get() != null) { - throw new SamzaException("Unrecoverable error during pending commit for taskName: %s." format taskName, - commitException.get()) + metrics.commitExceptions.inc() + commitExceptionCounter.incrementAndGet() + if (!skipCommitDuringFailureEnabled || commitExceptionCounter.get() > skipCommitExceptionMaxLimit) { + throw new SamzaException("Unrecoverable error during pending commit for taskName: %s. Exception Counter: %s" + format (taskName, commitExceptionCounter.get()), commitException.get()) + } else { + warn("Ignored the commit failure for taskName %s. Exception Counter: %s." + format (taskName, commitExceptionCounter.get()), commitException.get()) + commitException.set(null) + } } // if no commit is in progress for this task, continue with this commit. @@ -328,7 +345,7 @@ class TaskInstance( if (timeSinceLastCommit < commitMaxDelayMs) { info("Skipping commit for taskName: %s since another commit is in progress. " + "%s ms have elapsed since the pending commit started." format (taskName, timeSinceLastCommit)) - metrics.commitsSkipped.set(metrics.commitsSkipped.getValue + 1) + metrics.commitsSkipped.inc() return } else { warn("Blocking processing for taskName: %s until in-flight commit is complete. " + @@ -336,13 +353,28 @@ class TaskInstance( "which is greater than the max allowed commit delay: %s." format (taskName, timeSinceLastCommit, commitMaxDelayMs)) + // Wait for the previous commit to complete within the timeout. + // If it doesn't complete within the timeout, increment metric and the counter. + // Shutdown the container in the following scenarios: + // 1. skipCommitDuringFailureEnabled is not enabled + // 2. skipCommitDuringFailureEnabled is enabled but the number of timeouts exceeded the max count + // Otherwise, ignore the timeout. if (!commitInProgress.tryAcquire(commitTimeoutMs, TimeUnit.MILLISECONDS)) { val timeSinceLastCommit = System.currentTimeMillis() - lastCommitStartTimeMs - metrics.commitsTimedOut.set(metrics.commitsTimedOut.getValue + 1) - throw new SamzaException("Timeout waiting for pending commit for taskName: %s to finish. " + - "%s ms have elapsed since the pending commit started. Max allowed commit delay is %s ms " + - "and commit timeout beyond that is %s ms" format (taskName, timeSinceLastCommit, - commitMaxDelayMs, commitTimeoutMs)) + metrics.commitsTimedOut.inc() + commitTimeoutCounter.incrementAndGet() + if (!skipCommitDuringFailureEnabled || commitTimeoutCounter.get() > skipCommitTimeoutMaxLimit) { + throw new SamzaException("Timeout waiting for pending commit for taskName: %s to finish. " + + "%s ms have elapsed since the pending commit started. Max allowed commit delay is %s ms " + + "and commit timeout beyond that is %s ms. Timeout Counter: %s" format (taskName, timeSinceLastCommit, + commitMaxDelayMs, commitTimeoutMs, commitTimeoutCounter.get())) + } else { + warn("Ignoring commit timeout for taskName: %s. %s ms have elapsed since another commit started. " + + "Max allowed commit delay is %s ms and commit timeout beyond that is %s ms. Timeout Counter: %s." + format (taskName, timeSinceLastCommit, commitMaxDelayMs, commitTimeoutMs, commitTimeoutCounter.get())) + commitInProgress.release() + return + } } } } @@ -426,7 +458,7 @@ class TaskInstance( } }) - metrics.lastCommitNs.set(System.nanoTime() - commitStartNs) + metrics.lastCommitNs.set(System.nanoTime()) metrics.commitSyncNs.update(System.nanoTime() - commitStartNs) debug("Finishing sync stage of commit for taskName: %s checkpointId: %s" format (taskName, checkpointId)) } @@ -531,8 +563,11 @@ class TaskInstance( "Saved exception under Caused By.", commitException.get()) } } else { + commitExceptionCounter.set(0) + commitTimeoutCounter.set(0) metrics.commitAsyncNs.update(System.nanoTime() - asyncStageStartNs) metrics.commitNs.update(System.nanoTime() - commitStartNs) + metrics.lastCommitAsyncTimestamp.set(System.nanoTime()) } } finally { // release the permit indicating that previous commit is complete. diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala index 54d3665253..02674fb7eb 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala @@ -38,10 +38,12 @@ class TaskInstanceMetrics( val pendingMessages = newGauge("pending-messages", 0) val messagesInFlight = newGauge("messages-in-flight", 0) val asyncCallbackCompleted = newCounter("async-callback-complete-calls") - val commitsTimedOut = newGauge("commits-timed-out", 0) - val commitsSkipped = newGauge("commits-skipped", 0) + val commitsTimedOut = newCounter("commits-timed-out") + val commitsSkipped = newCounter("commits-skipped") + val commitExceptions = newCounter("commit-exceptions") val commitNs = newTimer("commit-ns") val lastCommitNs = newGauge("last-commit-ns", 0L) + val lastCommitAsyncTimestamp = newGauge("last-async-commit-timestamp", 0L) val commitSyncNs = newTimer("commit-sync-ns") val commitAsyncNs = newTimer("commit-async-ns") val snapshotNs = newTimer("snapshot-ns") diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala index 6afec52e72..ff52b40062 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala @@ -277,7 +277,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -370,7 +370,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer) val uploadTimer = mock[Timer] when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val inputOffsets = Map(SYSTEM_STREAM_PARTITION -> "4").asJava @@ -431,7 +431,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer) val uploadTimer = mock[Timer] when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) @@ -504,10 +504,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionsGauge = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -556,10 +558,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionsGauge = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -608,10 +612,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionsGauge = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -661,10 +667,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionsGauge = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -714,10 +722,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionsGauge = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge) val inputOffsets = new util.HashMap[SystemStreamPartition, String]() inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") @@ -768,7 +778,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) @@ -828,7 +838,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) @@ -859,7 +869,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { taskInstance.commit - verify(skippedCounter).set(1) + verify(skippedCounter, times(1)).inc() verify(commitsCounter, times(1)).inc() // should only have been incremented once on the initial commit verify(snapshotTimer).update(anyLong()) @@ -884,7 +894,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) @@ -947,7 +957,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) val cleanUpTimer = mock[Timer] when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) - val skippedCounter = mock[Gauge[Int]] + val skippedCounter = mock[Counter] when(this.metrics.commitsSkipped).thenReturn(skippedCounter) val lastCommitGauge = mock[Gauge[Long]] when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) @@ -1004,6 +1014,208 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { verify(snapshotTimer, times(2)).update(anyLong()) } + @Test + def testSkipExceptionFromFirstCommitAndContinueSecondCommit(): Unit = { + val commitsCounter = mock[Counter] + when(this.metrics.commits).thenReturn(commitsCounter) + val snapshotTimer = mock[Timer] + when(this.metrics.snapshotNs).thenReturn(snapshotTimer) + val uploadTimer = mock[Timer] + when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) + val commitTimer = mock[Timer] + when(this.metrics.commitNs).thenReturn(commitTimer) + val commitSyncTimer = mock[Timer] + when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer) + val commitAsyncTimer = mock[Timer] + when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer) + val cleanUpTimer = mock[Timer] + when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) + val skippedCounter = mock[Counter] + when(this.metrics.commitsSkipped).thenReturn(skippedCounter) + val lastCommitGauge = mock[Gauge[Long]] + when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionCounter = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter) + + val taskConfigsMap = new util.HashMap[String, String]() + taskConfigsMap.put("task.commit.ms", "-1") + taskConfigsMap.put("task.commit.max.delay.ms", "-1") + taskConfigsMap.put("task.commit.timeout.ms", "2000000") + // skip commit if exception occurs during the commit + taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", "true") + // should throw exception if second commit exception occurs + taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1") + when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) + setupTaskInstance(None, ForkJoinPool.commonPool()) + + val inputOffsets = new util.HashMap[SystemStreamPartition, String]() + inputOffsets.put(SYSTEM_STREAM_PARTITION, "4") + val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]() + when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets) + // Ensure the second commit proceeds without exceptions + when(this.taskCommitManager.upload(any(), any())) + .thenReturn(CompletableFuture.completedFuture( + Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers))) + // exception during the first commit + when(this.taskCommitManager.upload(any(), any())) + .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, String]]](new RuntimeException)) + + // First commit fails but should not throw exception + taskInstance.commit + verify(commitsCounter).inc() + verify(snapshotTimer).update(anyLong()) + verifyZeroInteractions(uploadTimer) + verifyZeroInteractions(commitTimer) + verifyZeroInteractions(skippedCounter) + waitForCommitExceptionIsSet(100, 5) + // Second commit should succeed + taskInstance.commit + verify(commitsCounter, times(2)).inc() // should only have been incremented twice - once for each commit + verify(commitExceptionCounter).inc() + } + + @Test + def testCommitThrowsIfAllowSkipCommitButExceptionCountReachMaxLimit(): Unit = { + val commitsCounter = mock[Counter] + when(this.metrics.commits).thenReturn(commitsCounter) + val snapshotTimer = mock[Timer] + when(this.metrics.snapshotNs).thenReturn(snapshotTimer) + val uploadTimer = mock[Timer] + when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) + val commitTimer = mock[Timer] + when(this.metrics.commitNs).thenReturn(commitTimer) + val commitSyncTimer = mock[Timer] + when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer) + val commitAsyncTimer = mock[Timer] + when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer) + val cleanUpTimer = mock[Timer] + when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) + val skippedCounter = mock[Counter] + when(this.metrics.commitsSkipped).thenReturn(skippedCounter) + val lastCommitGauge = mock[Gauge[Long]] + when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionCounter = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter) + + val taskConfigsMap = new util.HashMap[String, String]() + taskConfigsMap.put("task.commit.ms", "-1") + taskConfigsMap.put("task.commit.max.delay.ms", "-1") + taskConfigsMap.put("task.commit.timeout.ms", "2000000") + // skip commit if exception occurs during the commit + taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", "true") + // should throw exception if second commit exception occurs + taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1") + when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) + setupTaskInstance(None, ForkJoinPool.commonPool()) + + val inputOffsets = new util.HashMap[SystemStreamPartition, String]() + inputOffsets.put(SYSTEM_STREAM_PARTITION, "4") + when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets) + // exception for commits + when(this.taskCommitManager.upload(any(), any())) + .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, String]]](new RuntimeException)) + + // First commit fails but should not throw exception + taskInstance.commit + waitForCommitExceptionIsSet(100, 5) + // Second commit fails but should not throw exception + taskInstance.commit + verify(commitExceptionCounter).inc() + verify(commitsCounter, times(2)).inc() + verify(snapshotTimer, times(2)).update(anyLong()) + verifyZeroInteractions(uploadTimer) + verifyZeroInteractions(commitTimer) + verifyZeroInteractions(skippedCounter) + waitForCommitExceptionIsSet(100, 5) + // third commit should fail as the the commit exception counter is greater than the max limit + try { + taskInstance.commit + fail("Should have thrown an exception if exception count reached the max limit.") + } catch { + case e: Exception => + // expected + } + verify(commitExceptionCounter, times(2)).inc() + verify(commitsCounter, times(2)).inc() + } + + @Test + def testCommitThrowsIfAllowSkipTimeoutButTimeoutCountReachMaxLimit(): Unit = { + val commitsCounter = mock[Counter] + when(this.metrics.commits).thenReturn(commitsCounter) + val snapshotTimer = mock[Timer] + when(this.metrics.snapshotNs).thenReturn(snapshotTimer) + val commitTimer = mock[Timer] + when(this.metrics.commitNs).thenReturn(commitTimer) + val commitSyncTimer = mock[Timer] + when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer) + val commitAsyncTimer = mock[Timer] + when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer) + val uploadTimer = mock[Timer] + when(this.metrics.asyncUploadNs).thenReturn(uploadTimer) + val cleanUpTimer = mock[Timer] + when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer) + val skippedCounter = mock[Counter] + when(this.metrics.commitsSkipped).thenReturn(skippedCounter) + val commitsTimedOutCounter = mock[Counter] + when(this.metrics.commitsTimedOut).thenReturn(commitsTimedOutCounter) + val lastCommitGauge = mock[Gauge[Long]] + when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge) + val commitExceptionCounter = mock[Counter] + when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter) + + val inputOffsets = new util.HashMap[SystemStreamPartition, String]() + inputOffsets.put(SYSTEM_STREAM_PARTITION,"4") + val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0)) + + val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]() + val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5")) + stateCheckpointMarkers.put("storeName", stateCheckpointMarker) + when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets) + + val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers) + when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs) + val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] = + CompletableFuture.completedFuture(snapshotSCMs) + + when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op + + val cleanUpFuture = new CompletableFuture[Void]() + when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(cleanUpFuture) + + // use a separate executor to perform async operations on to test caller thread blocking behavior + val taskConfigsMap = new util.HashMap[String, String]() + taskConfigsMap.put("task.commit.ms", "-1") + // "block" immediately if previous commit async stage not complete + taskConfigsMap.put("task.commit.max.delay.ms", "-1") + taskConfigsMap.put("task.commit.timeout.ms", "0") // throw exception immediately if blocked + taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", "true") + // should throw exception if second commit timeout occurs + taskConfigsMap.put("task.commit.skip.commit.timeout.max.limit", "1") + when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) // override default behavior + + setupTaskInstance(None, ForkJoinPool.commonPool()) + + taskInstance.commit // async stage will not complete until cleanUpFuture is completed + taskInstance.commit // second commit found commit timeout and release the semaphore + + verifyZeroInteractions(commitExceptionCounter) + verifyZeroInteractions(skippedCounter) + verify(commitsTimedOutCounter).inc() + verify(commitsCounter, times(1)).inc() // should only have been incremented once now - second commit was skipped + taskInstance.commit // third commit should proceed without any issues and acquire the semaphore + try { + taskInstance.commit // fourth commit should throw exception as the timeout count reached the max limit + fail("Should have thrown an exception due to exceeding timeout limit.") + } catch { + case e: Exception => + // expected + } + verify(commitsTimedOutCounter, times(2)).inc() // incremented twice (second and fourth commit) + verify(commitsCounter, times(2)).inc() // incremented twice (first and third commit) + cleanUpFuture.complete(null) // just to unblock shared executor + } + /** * Given that no application task context factory is provided, then no lifecycle calls should be made. @@ -1091,6 +1303,17 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar { externalContextOption = Some(this.externalContext), elasticityFactor = elasticityFactor) } + private def waitForCommitExceptionIsSet(sleepTimeInMs: Int, maxRetry: Int): Unit = { + var retries = 0 + while (taskInstance.commitException.get() == null && retries < maxRetry) { + retries += 1 + Thread.sleep(sleepTimeInMs) + } + if (taskInstance.commitException.get() == null) { + fail("Should have set the commit exception.") + } + } + /** * Task type which has all task traits, which can be mocked. */