diff --git a/plugins/ingestion-kafka/src/main/java/org/opensearch/plugin/kafka/KafkaPartitionConsumer.java b/plugins/ingestion-kafka/src/main/java/org/opensearch/plugin/kafka/KafkaPartitionConsumer.java index 9461cfbc2de98..c749a887a2ccb 100644 --- a/plugins/ingestion-kafka/src/main/java/org/opensearch/plugin/kafka/KafkaPartitionConsumer.java +++ b/plugins/ingestion-kafka/src/main/java/org/opensearch/plugin/kafka/KafkaPartitionConsumer.java @@ -132,6 +132,11 @@ public KafkaOffset nextPointer() { return new KafkaOffset(lastFetchedOffset + 1); } + @Override + public KafkaOffset nextPointer(KafkaOffset pointer) { + return new KafkaOffset(pointer.getOffset() + 1); + } + @Override public IngestionShardPointer earliestPointer() { long startOffset = AccessController.doPrivileged( diff --git a/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java index d4fcadc4ac56d..bc718cf206a46 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java @@ -71,6 +71,7 @@ import org.opensearch.index.IndexModule; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.indices.pollingingest.IngestionErrorStrategy; import org.opensearch.indices.pollingingest.StreamPoller; import org.opensearch.indices.replication.SegmentReplicationSource; import org.opensearch.indices.replication.common.ReplicationType; @@ -770,6 +771,30 @@ public Iterator> settings() { Property.Final ); + public static final String SETTING_INGESTION_SOURCE_ERROR_STRATEGY = "index.ingestion_source.error.strategy"; + public static final Setting INGESTION_SOURCE_ERROR_STRATEGY_SETTING = Setting.simpleString( + SETTING_INGESTION_SOURCE_ERROR_STRATEGY, + IngestionErrorStrategy.ErrorStrategy.DROP.name(), + new Setting.Validator<>() { + + @Override + public void validate(final String value) { + try { + IngestionErrorStrategy.ErrorStrategy.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid value for " + SETTING_INGESTION_SOURCE_ERROR_STRATEGY + " [" + value + "]"); + } + } + + @Override + public void validate(final String value, final Map, Object> settings) { + validate(value); + } + }, + Property.IndexScope, + Property.Dynamic + ); + public static final Setting.AffixSetting INGESTION_SOURCE_PARAMS_SETTING = Setting.prefixKeySetting( "index.ingestion_source.param.", key -> new Setting<>(key, "", (value) -> { @@ -1001,8 +1026,13 @@ public IngestionSource getIngestionSource() { pointerInitResetType, pointerInitResetValue ); + + final String errorStrategyString = INGESTION_SOURCE_ERROR_STRATEGY_SETTING.get(settings); + IngestionErrorStrategy.ErrorStrategy errorStrategy = IngestionErrorStrategy.ErrorStrategy.valueOf( + errorStrategyString.toUpperCase(Locale.ROOT) + ); final Map ingestionSourceParams = INGESTION_SOURCE_PARAMS_SETTING.getAsMap(settings); - return new IngestionSource(ingestionSourceType, pointerInitReset, ingestionSourceParams); + return new IngestionSource(ingestionSourceType, pointerInitReset, errorStrategy, ingestionSourceParams); } return null; } diff --git a/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java b/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java index 9849c0a5f2ba9..fd28acf3246ad 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java @@ -9,6 +9,7 @@ package org.opensearch.cluster.metadata; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.indices.pollingingest.IngestionErrorStrategy; import org.opensearch.indices.pollingingest.StreamPoller; import java.util.Map; @@ -21,12 +22,19 @@ public class IngestionSource { private String type; private PointerInitReset pointerInitReset; + private IngestionErrorStrategy.ErrorStrategy errorStrategy; private Map params; - public IngestionSource(String type, PointerInitReset pointerInitReset, Map params) { + public IngestionSource( + String type, + PointerInitReset pointerInitReset, + IngestionErrorStrategy.ErrorStrategy errorStrategy, + Map params + ) { this.type = type; this.pointerInitReset = pointerInitReset; this.params = params; + this.errorStrategy = errorStrategy; } public String getType() { @@ -37,6 +45,10 @@ public PointerInitReset getPointerInitReset() { return pointerInitReset; } + public IngestionErrorStrategy.ErrorStrategy getErrorStrategy() { + return errorStrategy; + } + public Map params() { return params; } @@ -48,17 +60,30 @@ public boolean equals(Object o) { IngestionSource ingestionSource = (IngestionSource) o; return Objects.equals(type, ingestionSource.type) && Objects.equals(pointerInitReset, ingestionSource.pointerInitReset) + && Objects.equals(errorStrategy, ingestionSource.errorStrategy) && Objects.equals(params, ingestionSource.params); } @Override public int hashCode() { - return Objects.hash(type, pointerInitReset, params); + return Objects.hash(type, pointerInitReset, params, errorStrategy); } @Override public String toString() { - return "IngestionSource{" + "type='" + type + '\'' + ",pointer_init_reset='" + pointerInitReset + '\'' + ", params=" + params + '}'; + return "IngestionSource{" + + "type='" + + type + + '\'' + + ",pointer_init_reset='" + + pointerInitReset + + '\'' + + ",error_strategy='" + + errorStrategy + + '\'' + + ", params=" + + params + + '}'; } /** diff --git a/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java index 946d7fe734deb..6a926cdcf9119 100644 --- a/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java @@ -265,6 +265,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexMetadata.INGESTION_SOURCE_POINTER_INIT_RESET_SETTING, IndexMetadata.INGESTION_SOURCE_POINTER_INIT_RESET_VALUE_SETTING, IndexMetadata.INGESTION_SOURCE_PARAMS_SETTING, + IndexMetadata.INGESTION_SOURCE_ERROR_STRATEGY_SETTING, // validate that built-in similarities don't get redefined Setting.groupSetting("index.similarity.", (s) -> { diff --git a/server/src/main/java/org/opensearch/index/IngestionShardConsumer.java b/server/src/main/java/org/opensearch/index/IngestionShardConsumer.java index 41e659196a612..a9ffcaca850f2 100644 --- a/server/src/main/java/org/opensearch/index/IngestionShardConsumer.java +++ b/server/src/main/java/org/opensearch/index/IngestionShardConsumer.java @@ -72,6 +72,11 @@ public M getMessage() { */ T nextPointer(); + /** + * @return the immediate next pointer from the provided start pointer + */ + T nextPointer(T startPointer); + /** * @return the earliest pointer in the shard */ diff --git a/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java b/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java index b37281b9d1582..a6a1905bd39e0 100644 --- a/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java +++ b/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java @@ -55,6 +55,7 @@ import org.opensearch.index.translog.TranslogManager; import org.opensearch.index.translog.TranslogStats; import org.opensearch.indices.pollingingest.DefaultStreamPoller; +import org.opensearch.indices.pollingingest.IngestionErrorStrategy; import org.opensearch.indices.pollingingest.StreamPoller; import org.opensearch.search.suggest.completion.CompletionStats; import org.opensearch.threadpool.ThreadPool; @@ -189,7 +190,20 @@ public void start() { } String resetValue = ingestionSource.getPointerInitReset().getValue(); - streamPoller = new DefaultStreamPoller(startPointer, persistedPointers, ingestionShardConsumer, this, resetState, resetValue); + IngestionErrorStrategy ingestionErrorStrategy = IngestionErrorStrategy.create( + ingestionSource.getErrorStrategy(), + ingestionSource.getType() + ); + + streamPoller = new DefaultStreamPoller( + startPointer, + persistedPointers, + ingestionShardConsumer, + this, + resetState, + resetValue, + ingestionErrorStrategy + ); streamPoller.start(); } diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/BlockIngestionErrorStrategy.java b/server/src/main/java/org/opensearch/indices/pollingingest/BlockIngestionErrorStrategy.java new file mode 100644 index 0000000000000..d0febd0909be2 --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/BlockIngestionErrorStrategy.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * This error handling strategy blocks on failures preventing processing of remaining updates in the ingestion source. + */ +public class BlockIngestionErrorStrategy implements IngestionErrorStrategy { + private static final Logger logger = LogManager.getLogger(BlockIngestionErrorStrategy.class); + private final String ingestionSource; + + public BlockIngestionErrorStrategy(String ingestionSource) { + this.ingestionSource = ingestionSource; + } + + @Override + public void handleError(Throwable e, ErrorStage stage) { + logger.error("Error processing update from {}: {}", ingestionSource, e); + + // todo: record blocking update and emit metrics + } + + @Override + public boolean shouldPauseIngestion(Throwable e, ErrorStage stage) { + return true; + } +} diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java index 884cffec4aad5..f1ae00bdc98db 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java @@ -64,21 +64,25 @@ public class DefaultStreamPoller implements StreamPoller { @Nullable private IngestionShardPointer maxPersistedPointer; + private IngestionErrorStrategy errorStrategy; + public DefaultStreamPoller( IngestionShardPointer startPointer, Set persistedPointers, IngestionShardConsumer consumer, IngestionEngine ingestionEngine, ResetState resetState, - String resetValue + String resetValue, + IngestionErrorStrategy errorStrategy ) { this( startPointer, persistedPointers, consumer, - new MessageProcessorRunnable(new ArrayBlockingQueue<>(100), ingestionEngine), + new MessageProcessorRunnable(new ArrayBlockingQueue<>(100), ingestionEngine, errorStrategy), resetState, - resetValue + resetValue, + errorStrategy ); } @@ -88,7 +92,8 @@ public DefaultStreamPoller( IngestionShardConsumer consumer, MessageProcessorRunnable processorRunnable, ResetState resetState, - String resetValue + String resetValue, + IngestionErrorStrategy errorStrategy ) { this.consumer = Objects.requireNonNull(consumer); this.resetState = resetState; @@ -114,6 +119,7 @@ public DefaultStreamPoller( String.format(Locale.ROOT, "stream-poller-processor-%d-%d", consumer.getShardId(), System.currentTimeMillis()) ) ); + this.errorStrategy = errorStrategy; } @Override @@ -138,6 +144,9 @@ protected void startPoll() { } logger.info("Starting poller for shard {}", consumer.getShardId()); + // track the last record successfully written to the blocking queue + IngestionShardPointer lastSuccessfulPointer = null; + while (true) { try { if (closed) { @@ -205,6 +214,7 @@ protected void startPoll() { continue; } blockingQueue.put(result); + lastSuccessfulPointer = result.getPointer(); logger.debug( "Put message {} with pointer {} to the blocking queue", String.valueOf(result.getMessage().getPayload()), @@ -214,8 +224,18 @@ protected void startPoll() { // update the batch start pointer to the next batch batchStartPointer = consumer.nextPointer(); } catch (Throwable e) { - // TODO better error handling logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e); + errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.POLLING); + + if (errorStrategy.shouldPauseIngestion(e, IngestionErrorStrategy.ErrorStage.POLLING)) { + // Blocking error encountered. Pause poller to stop processing remaining updates. + pause(); + } else { + // Advance the batch start pointer to ignore the error and continue from next record + batchStartPointer = lastSuccessfulPointer == null + ? consumer.nextPointer(batchStartPointer) + : consumer.nextPointer(lastSuccessfulPointer); + } } } } diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/DropIngestionErrorStrategy.java b/server/src/main/java/org/opensearch/indices/pollingingest/DropIngestionErrorStrategy.java new file mode 100644 index 0000000000000..4598bf1248cfd --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/DropIngestionErrorStrategy.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * This error handling strategy drops failures and proceeds with remaining updates in the ingestion source. + */ +public class DropIngestionErrorStrategy implements IngestionErrorStrategy { + private static final Logger logger = LogManager.getLogger(DropIngestionErrorStrategy.class); + private final String ingestionSource; + + public DropIngestionErrorStrategy(String ingestionSource) { + this.ingestionSource = ingestionSource; + } + + @Override + public void handleError(Throwable e, ErrorStage stage) { + logger.error("Error processing update from {}: {}", ingestionSource, e); + + // todo: record failed update stats and emit metrics + } + + @Override + public boolean shouldPauseIngestion(Throwable e, ErrorStage stage) { + return false; + } + +} diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/IngestionErrorStrategy.java b/server/src/main/java/org/opensearch/indices/pollingingest/IngestionErrorStrategy.java new file mode 100644 index 0000000000000..79da988280c74 --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/IngestionErrorStrategy.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Defines the error handling strategy when an error is encountered either during polling records from ingestion source + * or during processing the polled records. + */ +@ExperimentalApi +public interface IngestionErrorStrategy { + + /** + * Process and record the error. + */ + void handleError(Throwable e, ErrorStage stage); + + /** + * Indicates if ingestion must be paused, blocking further writes. + */ + boolean shouldPauseIngestion(Throwable e, ErrorStage stage); + + static IngestionErrorStrategy create(ErrorStrategy errorStrategy, String ingestionSource) { + switch (errorStrategy) { + case BLOCK: + return new BlockIngestionErrorStrategy(ingestionSource); + case DROP: + default: + return new DropIngestionErrorStrategy(ingestionSource); + } + } + + /** + * Indicates available error handling strategies + */ + @ExperimentalApi + enum ErrorStrategy { + DROP, + BLOCK + } + + /** + * Indicates different stages of encountered errors + */ + @ExperimentalApi + enum ErrorStage { + POLLING, + PROCESSING + } + +} diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java index 53f9353477869..f783e210cd80f 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java @@ -48,6 +48,7 @@ public class MessageProcessorRunnable implements Runnable { private final BlockingQueue> blockingQueue; private final MessageProcessor messageProcessor; + private IngestionErrorStrategy errorStrategy; private static final String ID = "_id"; private static final String OP_TYPE = "_op_type"; @@ -61,9 +62,10 @@ public class MessageProcessorRunnable implements Runnable { */ public MessageProcessorRunnable( BlockingQueue> blockingQueue, - IngestionEngine engine + IngestionEngine engine, + IngestionErrorStrategy errorStrategy ) { - this(blockingQueue, new MessageProcessor(engine)); + this(blockingQueue, new MessageProcessor(engine), errorStrategy); } /** @@ -73,10 +75,12 @@ public MessageProcessorRunnable( */ MessageProcessorRunnable( BlockingQueue> blockingQueue, - MessageProcessor messageProcessor + MessageProcessor messageProcessor, + IngestionErrorStrategy errorStrategy ) { this.blockingQueue = Objects.requireNonNull(blockingQueue); this.messageProcessor = messageProcessor; + this.errorStrategy = errorStrategy; } static class MessageProcessor { @@ -229,7 +233,14 @@ public void run() { Thread.currentThread().interrupt(); // Restore interrupt status } if (result != null) { - messageProcessor.process(result.getMessage(), result.getPointer()); + try { + messageProcessor.process(result.getMessage(), result.getPointer()); + } catch (Exception e) { + errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.PROCESSING); + if (errorStrategy.shouldPauseIngestion(e, IngestionErrorStrategy.ErrorStage.PROCESSING)) { + Thread.currentThread().interrupt(); + } + } } } } diff --git a/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java b/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java index 0afe67002517b..05037f33c3965 100644 --- a/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java +++ b/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java @@ -14,6 +14,8 @@ import java.util.HashMap; import java.util.Map; +import static org.opensearch.indices.pollingingest.IngestionErrorStrategy.ErrorStrategy.DROP; + public class IngestionSourceTests extends OpenSearchTestCase { private final IngestionSource.PointerInitReset pointerInitReset = new IngestionSource.PointerInitReset( @@ -24,52 +26,50 @@ public class IngestionSourceTests extends OpenSearchTestCase { public void testConstructorAndGetters() { Map params = new HashMap<>(); params.put("key", "value"); - IngestionSource source = new IngestionSource("type", pointerInitReset, params); + IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, params); assertEquals("type", source.getType()); assertEquals(StreamPoller.ResetState.REWIND_BY_OFFSET, source.getPointerInitReset().getType()); assertEquals("1000", source.getPointerInitReset().getValue()); + assertEquals(DROP, source.getErrorStrategy()); assertEquals(params, source.params()); } public void testEquals() { Map params1 = new HashMap<>(); params1.put("key", "value"); - IngestionSource source1 = new IngestionSource("type", pointerInitReset, params1); + IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, params1); Map params2 = new HashMap<>(); params2.put("key", "value"); - IngestionSource source2 = new IngestionSource("type", pointerInitReset, params2); - + IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, params2); assertTrue(source1.equals(source2)); assertTrue(source2.equals(source1)); - IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, params1); + IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, params1); assertFalse(source1.equals(source3)); } public void testHashCode() { Map params1 = new HashMap<>(); params1.put("key", "value"); - IngestionSource source1 = new IngestionSource("type", pointerInitReset, params1); + IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, params1); Map params2 = new HashMap<>(); params2.put("key", "value"); - IngestionSource source2 = new IngestionSource("type", pointerInitReset, params2); - + IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, params2); assertEquals(source1.hashCode(), source2.hashCode()); - IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, params1); + IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, params1); assertNotEquals(source1.hashCode(), source3.hashCode()); } public void testToString() { Map params = new HashMap<>(); params.put("key", "value"); - IngestionSource source = new IngestionSource("type", pointerInitReset, params); - + IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, params); String expected = - "IngestionSource{type='type',pointer_init_reset='PointerInitReset{type='REWIND_BY_OFFSET', value=1000}', params={key=value}}"; + "IngestionSource{type='type',pointer_init_reset='PointerInitReset{type='REWIND_BY_OFFSET', value=1000}',error_strategy='DROP', params={key=value}}"; assertEquals(expected, source.toString()); } } diff --git a/server/src/test/java/org/opensearch/index/engine/FakeIngestionSource.java b/server/src/test/java/org/opensearch/index/engine/FakeIngestionSource.java index 1d81a22e94e9c..6233a65664d0b 100644 --- a/server/src/test/java/org/opensearch/index/engine/FakeIngestionSource.java +++ b/server/src/test/java/org/opensearch/index/engine/FakeIngestionSource.java @@ -83,6 +83,11 @@ public FakeIngestionShardPointer nextPointer() { return new FakeIngestionShardPointer(lastFetchedOffset + 1); } + @Override + public FakeIngestionShardPointer nextPointer(FakeIngestionShardPointer startPointer) { + return new FakeIngestionShardPointer(startPointer.offset + 1); + } + @Override public FakeIngestionShardPointer earliestPointer() { return new FakeIngestionShardPointer(0); diff --git a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java index c17b11791af09..0f0f90f392242 100644 --- a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java +++ b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java @@ -8,6 +8,7 @@ package org.opensearch.indices.pollingingest; +import org.opensearch.index.IngestionShardConsumer; import org.opensearch.index.IngestionShardPointer; import org.opensearch.index.engine.FakeIngestionSource; import org.opensearch.test.OpenSearchTestCase; @@ -16,19 +17,27 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class DefaultStreamPollerTests extends OpenSearchTestCase { private DefaultStreamPoller poller; @@ -38,6 +47,8 @@ public class DefaultStreamPollerTests extends OpenSearchTestCase { private List messages; private Set persistedPointers; private final int awaitTime = 300; + private final int sleepTime = 300; + private DropIngestionErrorStrategy errorStrategy; @Before public void setUp() throws Exception { @@ -48,7 +59,8 @@ public void setUp() throws Exception { messages.add("{\"_id\":\"2\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); fakeConsumer = new FakeIngestionSource.FakeIngestionConsumer(messages, 0); processor = mock(MessageProcessorRunnable.MessageProcessor.class); - processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor); + errorStrategy = new DropIngestionErrorStrategy("ingestion_source"); + processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor, errorStrategy); persistedPointers = new HashSet<>(); poller = new DefaultStreamPoller( new FakeIngestionSource.FakeIngestionShardPointer(0), @@ -56,7 +68,8 @@ public void setUp() throws Exception { fakeConsumer, processorRunnable, StreamPoller.ResetState.NONE, - "" + "", + errorStrategy ); } @@ -111,7 +124,8 @@ public void testSkipProcessed() throws InterruptedException { fakeConsumer, processorRunnable, StreamPoller.ResetState.NONE, - "" + "", + errorStrategy ); CountDownLatch latch = new CountDownLatch(2); @@ -147,7 +161,8 @@ public void testResetStateEarliest() throws InterruptedException { fakeConsumer, processorRunnable, StreamPoller.ResetState.EARLIEST, - "" + "", + errorStrategy ); CountDownLatch latch = new CountDownLatch(2); doAnswer(invocation -> { @@ -169,7 +184,8 @@ public void testResetStateLatest() throws InterruptedException { fakeConsumer, processorRunnable, StreamPoller.ResetState.LATEST, - "" + "", + errorStrategy ); poller.start(); @@ -187,7 +203,8 @@ public void testResetStateRewindByOffset() throws InterruptedException { fakeConsumer, processorRunnable, StreamPoller.ResetState.REWIND_BY_OFFSET, - "1" + "1", + errorStrategy ); CountDownLatch latch = new CountDownLatch(1); doAnswer(invocation -> { @@ -221,4 +238,112 @@ public void testStartClosedPoller() throws InterruptedException { assertEquals("poller is closed!", e.getMessage()); } } + + public void testDropErrorIngestionStrategy() throws TimeoutException, InterruptedException { + messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch1 = fakeConsumer.readNext( + fakeConsumer.earliestPointer(), + 2, + 100 + ); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch2 = fakeConsumer.readNext(fakeConsumer.nextPointer(), 2, 100); + IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class); + when(mockConsumer.getShardId()).thenReturn(0); + when(mockConsumer.readNext(any(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed")) + .thenReturn(readResultsBatch1) + .thenThrow(new RuntimeException("message3 poll failed")) + .thenReturn(readResultsBatch2) + .thenReturn(Collections.emptyList()); + + IngestionErrorStrategy errorStrategy = spy(new DropIngestionErrorStrategy("ingestion_source")); + poller = new DefaultStreamPoller( + new FakeIngestionSource.FakeIngestionShardPointer(0), + persistedPointers, + mockConsumer, + processorRunnable, + StreamPoller.ResetState.NONE, + "", + errorStrategy + ); + poller.start(); + Thread.sleep(sleepTime); + + verify(errorStrategy, times(2)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); + verify(processor, times(4)).process(any(), any()); + } + + public void testBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException { + messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch1 = fakeConsumer.readNext( + fakeConsumer.earliestPointer(), + 2, + 100 + ); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch2 = fakeConsumer.readNext(fakeConsumer.nextPointer(), 2, 100); + IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class); + when(mockConsumer.getShardId()).thenReturn(0); + when(mockConsumer.readNext(any(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed")) + .thenReturn(readResultsBatch1) + .thenReturn(readResultsBatch2) + .thenReturn(Collections.emptyList()); + + IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source")); + poller = new DefaultStreamPoller( + new FakeIngestionSource.FakeIngestionShardPointer(0), + persistedPointers, + mockConsumer, + processorRunnable, + StreamPoller.ResetState.NONE, + "", + errorStrategy + ); + poller.start(); + Thread.sleep(sleepTime); + + verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); + verify(processor, never()).process(any(), any()); + assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState()); + assertTrue(poller.isPaused()); + } + + public void testProcessingErrorWithBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException { + messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + + doThrow(new RuntimeException("Error processing update")).when(processor).process(any(), any()); + BlockIngestionErrorStrategy mockErrorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source")); + processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor, mockErrorStrategy); + + poller = new DefaultStreamPoller( + new FakeIngestionSource.FakeIngestionShardPointer(0), + persistedPointers, + fakeConsumer, + processorRunnable, + StreamPoller.ResetState.NONE, + "", + mockErrorStrategy + ); + poller.start(); + Thread.sleep(sleepTime); + + verify(mockErrorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.PROCESSING)); + verify(processor, times(1)).process(any(), any()); + // poller will continue to poll if an error is encountered during message processing but will be blocked by + // the write to blockingQueue + assertEquals(DefaultStreamPoller.State.POLLING, poller.getState()); + } }