From a0baf2ebea953e92e58512d08aa3aa87556151dc Mon Sep 17 00:00:00 2001
From: Weijie Guo
Date: Wed, 17 Apr 2024 17:57:25 +0800
Subject: [PATCH] [FLINK-34549][API] Implement applyToAllPartitions for non-partitioned context

---
 .../context/DefaultNonPartitionedContext.java |  44 +++++-
 ...DefaultTwoOutputNonPartitionedContext.java |  52 ++++++-
 .../impl/operators/KeyedProcessOperator.java  |  31 ++++
 ...KeyedTwoInputBroadcastProcessOperator.java |  32 ++++
 ...edTwoInputNonBroadcastProcessOperator.java |  37 +++++
 .../KeyedTwoOutputProcessOperator.java        |  31 ++++
 .../impl/operators/ProcessOperator.java       |  10 +-
 .../TwoInputBroadcastProcessOperator.java     |  10 +-
 .../TwoInputNonBroadcastProcessOperator.java  |  10 +-
 .../operators/TwoOutputProcessOperator.java   |   7 +-
 .../impl/context/ContextTestUtils.java        |  44 ++++++
 .../DefaultNonPartitionedContextTest.java     | 126 ++++++++++++++++
 ...ultTwoOutputNonPartitionedContextTest.java | 142 ++++++++++++++++++
 .../operators/KeyedProcessOperatorTest.java   |  25 ++-
 ...dTwoInputBroadcastProcessOperatorTest.java |  43 +++++-
 ...oInputNonBroadcastProcessOperatorTest.java |  45 +++++-
 .../KeyedTwoOutputProcessOperatorTest.java    |  27 +++-
 .../impl/operators/ProcessOperatorTest.java   |  19 ++-
 .../TwoInputBroadcastProcessOperatorTest.java |  34 ++++-
 ...oInputNonBroadcastProcessOperatorTest.java |  33 +++-
 .../TwoOutputProcessOperatorTest.java         |  22 ++-
 .../api/operators/AbstractStreamOperator.java |   8 +
 22 files changed, 774 insertions(+), 58 deletions(-)
 create mode 100644 flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java
 create mode 100644 flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java
 create mode 100644 flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java

diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
index 34e6df8b38f40d..0bbba016f40f94 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
@@ -18,23 +18,61 @@
 package org.apache.flink.datastream.impl.context;

+import org.apache.flink.datastream.api.common.Collector;
 import org.apache.flink.datastream.api.context.JobInfo;
 import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.TaskInfo;
 import org.apache.flink.datastream.api.function.ApplyPartitionFunction;
 import org.apache.flink.metrics.MetricGroup;

+import java.util.Set;
+
 /** The default implementation of {@link NonPartitionedContext}. */
 public class DefaultNonPartitionedContext<OUT> implements NonPartitionedContext<OUT> {

     private final DefaultRuntimeContext context;

-    public DefaultNonPartitionedContext(DefaultRuntimeContext context) {
+    private final DefaultPartitionedContext partitionedContext;
+
+    private final Collector<OUT> collector;
+
+    private final boolean isKeyed;
+
+    private final Set<Object> keySet;
+
+    public DefaultNonPartitionedContext(
+            DefaultRuntimeContext context,
+            DefaultPartitionedContext partitionedContext,
+            Collector<OUT> collector,
+            boolean isKeyed,
+            Set<Object> keySet) {
         this.context = context;
+        this.partitionedContext = partitionedContext;
+        this.collector = collector;
+        this.isKeyed = isKeyed;
+        this.keySet = keySet;
     }

     @Override
-    public void applyToAllPartitions(ApplyPartitionFunction<OUT> applyPartitionFunction) {
-        // TODO implements this method.
+    public void applyToAllPartitions(ApplyPartitionFunction<OUT> applyPartitionFunction)
+            throws Exception {
+        if (isKeyed) {
+            for (Object key : keySet) {
+                partitionedContext
+                        .getStateManager()
+                        .executeInKeyContext(
+                                () -> {
+                                    try {
+                                        applyPartitionFunction.apply(collector, partitionedContext);
+                                    } catch (Exception e) {
+                                        throw new RuntimeException(e);
+                                    }
+                                },
+                                key);
+            }
+        } else {
+            // non-keyed operator has only one partition.
+            applyPartitionFunction.apply(collector, partitionedContext);
+        }
     }

     @Override
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
index 9b604379bbd99e..1a72476839df5f 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
@@ -18,25 +18,69 @@
 package org.apache.flink.datastream.impl.context;

+import org.apache.flink.datastream.api.common.Collector;
 import org.apache.flink.datastream.api.context.JobInfo;
 import org.apache.flink.datastream.api.context.TaskInfo;
 import org.apache.flink.datastream.api.context.TwoOutputNonPartitionedContext;
 import org.apache.flink.datastream.api.function.TwoOutputApplyPartitionFunction;
 import org.apache.flink.metrics.MetricGroup;

+import java.util.Set;
+
 /** The default implementation of {@link TwoOutputNonPartitionedContext}. */
 public class DefaultTwoOutputNonPartitionedContext<OUT1, OUT2>
         implements TwoOutputNonPartitionedContext<OUT1, OUT2> {
-    private final DefaultRuntimeContext context;
+    protected final DefaultRuntimeContext context;
+
+    private final DefaultPartitionedContext partitionedContext;
+
+    protected final Collector<OUT1> firstCollector;
+
+    protected final Collector<OUT2> secondCollector;
+
+    private final boolean isKeyed;
+
+    private final Set<Object> keySet;

-    public DefaultTwoOutputNonPartitionedContext(DefaultRuntimeContext context) {
+    public DefaultTwoOutputNonPartitionedContext(
+            DefaultRuntimeContext context,
+            DefaultPartitionedContext partitionedContext,
+            Collector<OUT1> firstCollector,
+            Collector<OUT2> secondCollector,
+            boolean isKeyed,
+            Set<Object> keySet) {
         this.context = context;
+        this.partitionedContext = partitionedContext;
+        this.firstCollector = firstCollector;
+        this.secondCollector = secondCollector;
+        this.isKeyed = isKeyed;
+        this.keySet = keySet;
     }

     @Override
     public void applyToAllPartitions(
-            TwoOutputApplyPartitionFunction<OUT1, OUT2> applyPartitionFunction) {
-        // TODO implements this method.
+ TwoOutputApplyPartitionFunction applyPartitionFunction) throws Exception { + if (isKeyed) { + for (Object key : keySet) { + partitionedContext + .getStateManager() + .executeInKeyContext( + () -> { + try { + applyPartitionFunction.apply( + firstCollector, + secondCollector, + partitionedContext); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + key); + } + } else { + // non-keyed operator has only one partition. + applyPartitionFunction.apply(firstCollector, secondCollector, partitionedContext); + } } @Override diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java index 1b729f8c21fe00..9f9c0bb6020ea4 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java @@ -19,26 +19,36 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction; import org.apache.flink.datastream.api.stream.KeyedPartitionStream; import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector; import org.apache.flink.datastream.impl.common.OutputCollector; import org.apache.flink.datastream.impl.common.TimestampCollector; +import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext; import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; + /** Operator for {@link OneInputStreamProcessFunction} in {@link KeyedPartitionStream}. 
*/ public class KeyedProcessOperator extends ProcessOperator implements Triggerable { private transient InternalTimerService timerService; + private transient Set keySet; + @Nullable private final KeySelector outKeySelector; public KeyedProcessOperator(OneInputStreamProcessFunction userFunction) { @@ -56,6 +66,7 @@ public KeyedProcessOperator( public void open() throws Exception { this.timerService = getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this); + this.keySet = new HashSet<>(); super.open(); } @@ -95,4 +106,24 @@ public void onProcessingTime(InternalTimer timer) throws Exc protected ProcessingTimeManager getProcessingTimeManager() { return new DefaultProcessingTimeManager(timerService); } + + @Override + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, outputCollector, true, keySet); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement1(StreamRecord record) throws Exception { + setKeyContextElement(record, getStateKeySelector1()); + } + + private void setKeyContextElement(StreamRecord record, KeySelector selector) + throws Exception { + checkNotNull(selector); + Object key = selector.getKey(record.getValue()); + setCurrentKey(key); + keySet.add(key); + } } diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java index d303d0cf186c83..d46da49e26ae0a 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java @@ -19,27 +19,35 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.TwoInputBroadcastStreamProcessFunction; import org.apache.flink.datastream.api.stream.KeyedPartitionStream; import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector; import org.apache.flink.datastream.impl.common.OutputCollector; import org.apache.flink.datastream.impl.common.TimestampCollector; +import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext; import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; + /** Operator for {@link TwoInputBroadcastStreamProcessFunction} in {@link KeyedPartitionStream}. 
*/ public class KeyedTwoInputBroadcastProcessOperator extends TwoInputBroadcastProcessOperator implements Triggerable { private transient InternalTimerService timerService; + private transient Set keySet; + @Nullable private final KeySelector outKeySelector; public KeyedTwoInputBroadcastProcessOperator( @@ -58,6 +66,7 @@ public KeyedTwoInputBroadcastProcessOperator( public void open() throws Exception { this.timerService = getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this); + this.keySet = new HashSet<>(); super.open(); } @@ -96,4 +105,27 @@ public void onProcessingTime(InternalTimer timer) throws Exc partitionedContext), timer.getKey()); } + + @Override + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, collector, true, keySet); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + // Only element from input1 should be considered as the other side is broadcast input. + public void setKeyContextElement1(StreamRecord record) throws Exception { + setKeyContextElement(record, getStateKeySelector1()); + } + + private void setKeyContextElement(StreamRecord record, KeySelector selector) + throws Exception { + if (selector == null) { + return; + } + Object key = selector.getKey(record.getValue()); + setCurrentKey(key); + keySet.add(key); + } } diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java index 36ef95835999f8..d646c2b934994a 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java @@ -19,21 +19,27 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.TwoInputNonBroadcastStreamProcessFunction; import org.apache.flink.datastream.api.stream.KeyedPartitionStream; import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector; import org.apache.flink.datastream.impl.common.OutputCollector; import org.apache.flink.datastream.impl.common.TimestampCollector; +import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext; import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; + /** * Operator for {@link TwoInputNonBroadcastStreamProcessFunction} in {@link KeyedPartitionStream}. 
*/ @@ -42,6 +48,8 @@ public class KeyedTwoInputNonBroadcastProcessOperator implements Triggerable { private transient InternalTimerService timerService; + private transient Set keySet; + @Nullable private final KeySelector outKeySelector; public KeyedTwoInputNonBroadcastProcessOperator( @@ -60,6 +68,7 @@ public KeyedTwoInputNonBroadcastProcessOperator( public void open() throws Exception { this.timerService = getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this); + this.keySet = new HashSet<>(); super.open(); } @@ -98,4 +107,32 @@ public void onProcessingTime(InternalTimer timer) throws Exc partitionedContext), timer.getKey()); } + + @Override + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, collector, true, keySet); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement1(StreamRecord record) throws Exception { + setKeyContextElement(record, getStateKeySelector1()); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement2(StreamRecord record) throws Exception { + setKeyContextElement(record, getStateKeySelector2()); + } + + private void setKeyContextElement(StreamRecord record, KeySelector selector) + throws Exception { + if (selector == null) { + return; + } + Object key = selector.getKey(record.getValue()); + setCurrentKey(key); + keySet.add(key); + } } diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java index aa7de642bb1968..61035d1ebe4e7a 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java @@ -20,27 +20,37 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.datastream.api.context.ProcessingTimeManager; +import org.apache.flink.datastream.api.context.TwoOutputNonPartitionedContext; import org.apache.flink.datastream.api.function.TwoOutputStreamProcessFunction; import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector; import org.apache.flink.datastream.impl.common.OutputCollector; import org.apache.flink.datastream.impl.common.TimestampCollector; import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager; +import org.apache.flink.datastream.impl.context.DefaultTwoOutputNonPartitionedContext; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; + /** */ public class KeyedTwoOutputProcessOperator extends TwoOutputProcessOperator implements Triggerable { private transient InternalTimerService timerService; + private transient Set keySet; + @Nullable private final KeySelector mainOutKeySelector; @Nullable private 
final KeySelector sideOutKeySelector; @@ -69,6 +79,7 @@ public KeyedTwoOutputProcessOperator( public void open() throws Exception { this.timerService = getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this); + this.keySet = new HashSet<>(); super.open(); } @@ -120,4 +131,24 @@ public void onProcessingTime(InternalTimer timer) throws Exc partitionedContext), timer.getKey()); } + + @Override + protected TwoOutputNonPartitionedContext getNonPartitionedContext() { + return new DefaultTwoOutputNonPartitionedContext<>( + context, partitionedContext, mainCollector, sideCollector, true, keySet); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement1(StreamRecord record) throws Exception { + setKeyContextElement(record, getStateKeySelector1()); + } + + private void setKeyContextElement(StreamRecord record, KeySelector selector) + throws Exception { + checkNotNull(selector); + Object key = selector.getKey(record.getValue()); + setCurrentKey(key); + keySet.add(key); + } } diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java index be0fa04dd507e1..368b5bed247377 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java @@ -19,6 +19,7 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction; import org.apache.flink.datastream.impl.common.OutputCollector; @@ -43,7 +44,7 @@ public class ProcessOperator protected transient DefaultPartitionedContext partitionedContext; - protected transient DefaultNonPartitionedContext nonPartitionedContext; + protected transient NonPartitionedContext nonPartitionedContext; protected transient TimestampCollector outputCollector; @@ -69,8 +70,8 @@ public void open() throws Exception { partitionedContext = new DefaultPartitionedContext( context, this::currentKey, this::setCurrentKey, getProcessingTimeManager()); - nonPartitionedContext = new DefaultNonPartitionedContext<>(context); outputCollector = getOutputCollector(); + nonPartitionedContext = getNonPartitionedContext(); } @Override @@ -95,4 +96,9 @@ protected Object currentKey() { protected ProcessingTimeManager getProcessingTimeManager() { return UnsupportedProcessingTimeManager.INSTANCE; } + + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, outputCollector, false, null); + } } diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java index 98a76edd5e0f39..4e6f663adf89ad 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java @@ -19,6 +19,7 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.common.TaskInfo; 
+import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.TwoInputBroadcastStreamProcessFunction; import org.apache.flink.datastream.impl.common.OutputCollector; @@ -48,7 +49,7 @@ public class TwoInputBroadcastProcessOperator protected transient DefaultPartitionedContext partitionedContext; - protected transient DefaultNonPartitionedContext nonPartitionedContext; + protected transient NonPartitionedContext nonPartitionedContext; public TwoInputBroadcastProcessOperator( TwoInputBroadcastStreamProcessFunction userFunction) { @@ -73,7 +74,7 @@ public void open() throws Exception { this.partitionedContext = new DefaultPartitionedContext( context, this::currentKey, this::setCurrentKey, getProcessingTimeManager()); - this.nonPartitionedContext = new DefaultNonPartitionedContext<>(context); + this.nonPartitionedContext = getNonPartitionedContext(); } @Override @@ -93,6 +94,11 @@ protected TimestampCollector getOutputCollector() { return new OutputCollector<>(output); } + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, collector, false, null); + } + @Override public void endInput(int inputId) throws Exception { // sanity check. diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java index 982edcaa4dc5db..14eed8ad419938 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java @@ -19,6 +19,7 @@ package org.apache.flink.datastream.impl.operators; import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.datastream.api.context.NonPartitionedContext; import org.apache.flink.datastream.api.context.ProcessingTimeManager; import org.apache.flink.datastream.api.function.TwoInputNonBroadcastStreamProcessFunction; import org.apache.flink.datastream.impl.common.OutputCollector; @@ -48,7 +49,7 @@ public class TwoInputNonBroadcastProcessOperator protected transient DefaultPartitionedContext partitionedContext; - protected transient DefaultNonPartitionedContext nonPartitionedContext; + protected transient NonPartitionedContext nonPartitionedContext; public TwoInputNonBroadcastProcessOperator( TwoInputNonBroadcastStreamProcessFunction userFunction) { @@ -73,7 +74,7 @@ public void open() throws Exception { this.partitionedContext = new DefaultPartitionedContext( context, this::currentKey, this::setCurrentKey, getProcessingTimeManager()); - this.nonPartitionedContext = new DefaultNonPartitionedContext<>(context); + this.nonPartitionedContext = getNonPartitionedContext(); } @Override @@ -93,6 +94,11 @@ protected TimestampCollector getOutputCollector() { return new OutputCollector<>(output); } + protected NonPartitionedContext getNonPartitionedContext() { + return new DefaultNonPartitionedContext<>( + context, partitionedContext, collector, false, null); + } + @Override public void endInput(int inputId) throws Exception { // sanity check. 
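
With the operator wiring above in place, every operator builds its non-partitioned context through getNonPartitionedContext(), and applyToAllPartitions() becomes usable from a process function's end-of-input hooks. The sketch below shows the intended call pattern for a keyed one-input function, mirroring the new tests later in this patch; the class name and the Integer types are illustrative, and the PartitionedContext import path is inferred from the sibling context classes this patch imports:

    import org.apache.flink.datastream.api.common.Collector;
    import org.apache.flink.datastream.api.context.NonPartitionedContext;
    import org.apache.flink.datastream.api.context.PartitionedContext;
    import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction;

    /** Hypothetical user function: emits each partition's key once its input ends. */
    public class EmitKeyOnEndFunction
            implements OneInputStreamProcessFunction<Integer, Integer> {

        @Override
        public void processRecord(
                Integer record, Collector<Integer> output, PartitionedContext ctx)
                throws Exception {
            // Regular per-record processing; keyed state is reachable through
            // ctx.getStateManager().
        }

        @Override
        public void endInput(NonPartitionedContext<Integer> ctx) {
            try {
                // For a keyed operator the callback runs once per key this subtask
                // has seen, with that key's context restored; for a non-keyed
                // operator it runs exactly once.
                ctx.applyToAllPartitions(
                        (out, context) -> {
                            Integer currentKey = context.getStateManager().getCurrentKey();
                            out.collect(currentKey);
                        });
            } catch (Exception e) {
                // applyToAllPartitions is declared to throw Exception, but endInput
                // is not, so this wraps it the same way the new tests do.
                throw new RuntimeException(e);
            }
        }
    }
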
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java index 4f58db50ad6c53..43681f4a1b514e 100644 --- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java +++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java @@ -84,7 +84,7 @@ public void open() throws Exception { this.partitionedContext = new DefaultPartitionedContext( context, this::currentKey, this::setCurrentKey, getProcessingTimeManager()); - this.nonPartitionedContext = new DefaultTwoOutputNonPartitionedContext<>(context); + this.nonPartitionedContext = getNonPartitionedContext(); } @Override @@ -112,6 +112,11 @@ protected Object currentKey() { throw new UnsupportedOperationException("The key is only defined for keyed operator"); } + protected TwoOutputNonPartitionedContext getNonPartitionedContext() { + return new DefaultTwoOutputNonPartitionedContext<>( + context, partitionedContext, mainCollector, sideCollector, false, null); + } + protected ProcessingTimeManager getProcessingTimeManager() { return UnsupportedProcessingTimeManager.INSTANCE; } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java new file mode 100644 index 00000000000000..5c8311e6de6e3c --- /dev/null +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.datastream.impl.context; + +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.util.MockStreamingRuntimeContext; + +/** Test utils for things related to context. 
*/ +public final class ContextTestUtils { + public static StreamingRuntimeContext createStreamingRuntimeContext() { + return new MockStreamingRuntimeContext( + false, + 2, + 1, + new MockEnvironmentBuilder() + .setTaskName("mockTask") + .setManagedMemorySize(4 * MemoryManager.DEFAULT_PAGE_SIZE) + .setParallelism(2) + .setMaxParallelism(2) + .setSubtaskIndex(1) + .setJobType(JobType.STREAMING) + .setJobName("mockJob") + .build()); + } +} diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java new file mode 100644 index 00000000000000..1bd740f645323f --- /dev/null +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.datastream.impl.context; + +import org.apache.flink.datastream.impl.common.TestingTimestampCollector; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link DefaultNonPartitionedContext}. 
*/ +class DefaultNonPartitionedContextTest { + @Test + void testApplyToAllPartitions() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + List collectedData = new ArrayList<>(); + + TestingTimestampCollector collector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedData::add) + .build(); + CompletableFuture cf = new CompletableFuture<>(); + StreamingRuntimeContext operatorRuntimeContext = + ContextTestUtils.createStreamingRuntimeContext(); + DefaultRuntimeContext runtimeContext = + new DefaultRuntimeContext( + operatorRuntimeContext.getJobInfo().getJobName(), + operatorRuntimeContext.getJobType(), + 1, + 2, + "mock-task", + operatorRuntimeContext.getMetricGroup()); + DefaultNonPartitionedContext nonPartitionedContext = + new DefaultNonPartitionedContext<>( + runtimeContext, + new DefaultPartitionedContext( + runtimeContext, + Optional::empty, + (key) -> cf.complete(null), + UnsupportedProcessingTimeManager.INSTANCE), + collector, + false, + null); + nonPartitionedContext.applyToAllPartitions( + (out, ctx) -> { + counter.incrementAndGet(); + out.collect(10); + }); + assertThat(counter.get()).isEqualTo(1); + assertThat(cf).isNotCompleted(); + assertThat(collectedData).containsExactly(10); + } + + @Test + void testKeyedApplyToAllPartitions() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + List collectedData = new ArrayList<>(); + + TestingTimestampCollector collector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedData::add) + .build(); + // put all keys + Set allKeys = new HashSet<>(); + allKeys.add(1); + allKeys.add(2); + allKeys.add(3); + + AtomicInteger currentKey = new AtomicInteger(-1); + StreamingRuntimeContext operatorRuntimeContext = + ContextTestUtils.createStreamingRuntimeContext(); + DefaultRuntimeContext runtimeContext = + new DefaultRuntimeContext( + operatorRuntimeContext.getJobInfo().getJobName(), + operatorRuntimeContext.getJobType(), + 1, + 2, + "mock-task", + operatorRuntimeContext.getMetricGroup()); + DefaultNonPartitionedContext nonPartitionedContext = + new DefaultNonPartitionedContext<>( + runtimeContext, + new DefaultPartitionedContext( + runtimeContext, + currentKey::get, + (key) -> currentKey.set((Integer) key), + UnsupportedProcessingTimeManager.INSTANCE), + collector, + true, + allKeys); + nonPartitionedContext.applyToAllPartitions( + (out, ctx) -> { + counter.incrementAndGet(); + Integer key = ctx.getStateManager().getCurrentKey(); + assertThat(key).isIn(allKeys); + out.collect(key); + }); + assertThat(counter.get()).isEqualTo(allKeys.size()); + assertThat(collectedData).containsExactlyInAnyOrder(1, 2, 3); + } +} diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java new file mode 100644 index 00000000000000..fbad9f4bcba43c --- /dev/null +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.datastream.impl.context; + +import org.apache.flink.datastream.impl.common.TestingTimestampCollector; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link DefaultTwoOutputNonPartitionedContext}. */ +class DefaultTwoOutputNonPartitionedContextTest { + @Test + void testApplyToAllPartitions() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + List collectedFromFirstOutput = new ArrayList<>(); + List collectedFromSecondOutput = new ArrayList<>(); + + TestingTimestampCollector firstCollector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedFromFirstOutput::add) + .build(); + TestingTimestampCollector secondCollector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedFromSecondOutput::add) + .build(); + CompletableFuture cf = new CompletableFuture<>(); + StreamingRuntimeContext operatorRuntimeContext = + ContextTestUtils.createStreamingRuntimeContext(); + DefaultRuntimeContext runtimeContext = + new DefaultRuntimeContext( + operatorRuntimeContext.getJobInfo().getJobName(), + operatorRuntimeContext.getJobType(), + 1, + 2, + "mock-task", + operatorRuntimeContext.getMetricGroup()); + DefaultTwoOutputNonPartitionedContext nonPartitionedContext = + new DefaultTwoOutputNonPartitionedContext<>( + runtimeContext, + new DefaultPartitionedContext( + runtimeContext, + Optional::empty, + (key) -> cf.complete(null), + UnsupportedProcessingTimeManager.INSTANCE), + firstCollector, + secondCollector, + false, + null); + nonPartitionedContext.applyToAllPartitions( + (firstOutput, secondOutput, ctx) -> { + counter.incrementAndGet(); + firstOutput.collect(10); + secondOutput.collect(20L); + }); + assertThat(counter.get()).isEqualTo(1); + assertThat(cf).isNotCompleted(); + assertThat(collectedFromFirstOutput).containsExactly(10); + assertThat(collectedFromSecondOutput).containsExactly(20L); + } + + @Test + void testKeyedApplyToAllPartitions() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + List collectedFromFirstOutput = new ArrayList<>(); + List collectedFromSecondOutput = new ArrayList<>(); + + TestingTimestampCollector firstCollector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedFromFirstOutput::add) + .build(); + TestingTimestampCollector secondCollector = + TestingTimestampCollector.builder() + .setCollectConsumer(collectedFromSecondOutput::add) + .build(); + // put all keys + Set allKeys = new HashSet<>(); + allKeys.add(1); + allKeys.add(2); + allKeys.add(3); + + AtomicInteger currentKey = new AtomicInteger(-1); + StreamingRuntimeContext operatorRuntimeContext = + ContextTestUtils.createStreamingRuntimeContext(); + DefaultRuntimeContext runtimeContext = + new DefaultRuntimeContext( + 
operatorRuntimeContext.getJobInfo().getJobName(), + operatorRuntimeContext.getJobType(), + 1, + 2, + "mock-task", + operatorRuntimeContext.getMetricGroup()); + DefaultTwoOutputNonPartitionedContext nonPartitionedContext = + new DefaultTwoOutputNonPartitionedContext<>( + runtimeContext, + new DefaultPartitionedContext( + runtimeContext, + currentKey::get, + (key) -> currentKey.set((Integer) key), + UnsupportedProcessingTimeManager.INSTANCE), + firstCollector, + secondCollector, + true, + allKeys); + nonPartitionedContext.applyToAllPartitions( + (firstOut, secondOut, ctx) -> { + counter.incrementAndGet(); + Integer key = ctx.getStateManager().getCurrentKey(); + assertThat(key).isIn(allKeys); + firstOut.collect(key); + secondOut.collect(Long.valueOf(key)); + }); + assertThat(counter.get()).isEqualTo(allKeys.size()); + assertThat(collectedFromFirstOutput).containsExactlyInAnyOrder(1, 2, 3); + assertThat(collectedFromSecondOutput).containsExactlyInAnyOrder(1L, 2L, 3L); + } +} diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java index 163ba106368b1a..d744ba003a1edf 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java @@ -30,7 +30,7 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -70,7 +70,7 @@ public void processRecord( @Test void testEndInput() throws Exception { - CompletableFuture future = new CompletableFuture<>(); + AtomicInteger counter = new AtomicInteger(); KeyedProcessOperator processOperator = new KeyedProcessOperator<>( new OneInputStreamProcessFunction() { @@ -84,7 +84,17 @@ public void processRecord( @Override public void endInput(NonPartitionedContext ctx) { - future.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + counter.incrementAndGet(); + Integer currentKey = + context.getStateManager().getCurrentKey(); + out.collect(currentKey); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -94,8 +104,15 @@ public void endInput(NonPartitionedContext ctx) { (KeySelector) value -> value, Types.INT)) { testHarness.open(); + testHarness.processElement(new StreamRecord<>(1)); // key is 1 + testHarness.processElement(new StreamRecord<>(2)); // key is 2 + testHarness.processElement(new StreamRecord<>(3)); // key is 3 testHarness.endInput(); - assertThat(future).isCompleted(); + assertThat(counter).hasValue(3); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly( + new StreamRecord<>(1), new StreamRecord<>(2), new StreamRecord<>(3)); } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java index c56a71c559dbc8..7ad6070073b3d8 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java +++ 
b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java @@ -30,8 +30,9 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Collection; import java.util.List; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -79,8 +80,8 @@ public void processRecordFromBroadcastInput( @Test void testEndInput() throws Exception { - CompletableFuture nonBroadcastInputEnd = new CompletableFuture<>(); - CompletableFuture broadcastInputEnd = new CompletableFuture<>(); + AtomicInteger nonBroadcastInputCounter = new AtomicInteger(); + AtomicInteger broadcastInputCounter = new AtomicInteger(); KeyedTwoInputBroadcastProcessOperator processOperator = new KeyedTwoInputBroadcastProcessOperator<>( new TwoInputBroadcastStreamProcessFunction() { @@ -100,12 +101,32 @@ public void processRecordFromBroadcastInput( @Override public void endNonBroadcastInput(NonPartitionedContext ctx) { - nonBroadcastInputEnd.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + nonBroadcastInputCounter.incrementAndGet(); + Long currentKey = + context.getStateManager().getCurrentKey(); + out.collect(currentKey); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override public void endBroadcastInput(NonPartitionedContext ctx) { - broadcastInputEnd.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + broadcastInputCounter.incrementAndGet(); + Long currentKey = + context.getStateManager().getCurrentKey(); + out.collect(currentKey); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -116,10 +137,18 @@ public void endBroadcastInput(NonPartitionedContext ctx) { (KeySelector) value -> value, Types.LONG)) { testHarness.open(); + testHarness.processElement1(new StreamRecord<>(1)); // key is 1L + testHarness.processElement2(new StreamRecord<>(2L)); // broadcast input is not keyed testHarness.endInput1(); - assertThat(nonBroadcastInputEnd).isCompleted(); + assertThat(nonBroadcastInputCounter).hasValue(1); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput).containsExactly(new StreamRecord<>(1L)); + testHarness.processElement2(new StreamRecord<>(3L)); // broadcast input is not keyed testHarness.endInput2(); - assertThat(broadcastInputEnd).isCompleted(); + assertThat(broadcastInputCounter).hasValue(1); + recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(1L)); } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java index 324e91eddde555..75d9b616033281 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java @@ -30,7 +30,7 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static 
org.assertj.core.api.Assertions.assertThatThrownBy; @@ -80,8 +80,8 @@ public void processRecordFromSecondInput( @Test void testEndInput() throws Exception { - CompletableFuture firstFuture = new CompletableFuture<>(); - CompletableFuture secondFuture = new CompletableFuture<>(); + AtomicInteger firstInputCounter = new AtomicInteger(); + AtomicInteger secondInputCounter = new AtomicInteger(); KeyedTwoInputNonBroadcastProcessOperator processOperator = new KeyedTwoInputNonBroadcastProcessOperator<>( new TwoInputNonBroadcastStreamProcessFunction() { @@ -101,12 +101,32 @@ public void processRecordFromSecondInput( @Override public void endFirstInput(NonPartitionedContext ctx) { - firstFuture.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + firstInputCounter.incrementAndGet(); + Long currentKey = + context.getStateManager().getCurrentKey(); + out.collect(currentKey); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override public void endSecondInput(NonPartitionedContext ctx) { - secondFuture.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + secondInputCounter.incrementAndGet(); + Long currentKey = + context.getStateManager().getCurrentKey(); + out.collect(currentKey); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -117,10 +137,21 @@ public void endSecondInput(NonPartitionedContext ctx) { (KeySelector) value -> value, Types.LONG)) { testHarness.open(); + testHarness.processElement1(new StreamRecord<>(1)); // key is 1L + testHarness.processElement2(new StreamRecord<>(2L)); // key is 2L testHarness.endInput1(); - assertThat(firstFuture).isCompleted(); + assertThat(firstInputCounter).hasValue(2); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L)); + testHarness.processElement2(new StreamRecord<>(3L)); // key is 3L + testHarness.getOutput().clear(); testHarness.endInput2(); - assertThat(secondFuture).isCompleted(); + assertThat(secondInputCounter).hasValue(3); + recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly( + new StreamRecord<>(1L), new StreamRecord<>(2L), new StreamRecord<>(3L)); } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java index b40a0e1289b379..875898d29e64ab 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java @@ -31,9 +31,9 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -82,7 +82,7 @@ public void processRecord( @Test void testEndInput() throws Exception { - CompletableFuture future = new CompletableFuture<>(); + AtomicInteger counter = new AtomicInteger(); OutputTag sideOutputTag = new OutputTag("side-output") {}; KeyedTwoOutputProcessOperator processOperator = @@ -100,7 +100,18 @@ public void processRecord( @Override 
public void endInput( TwoOutputNonPartitionedContext ctx) { - future.complete(null); + try { + ctx.applyToAllPartitions( + (firstOutput, secondOutput, context) -> { + counter.incrementAndGet(); + Integer currentKey = + context.getStateManager().getCurrentKey(); + firstOutput.collect(currentKey); + secondOutput.collect(Long.valueOf(currentKey)); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }, sideOutputTag); @@ -111,8 +122,16 @@ public void endInput( (KeySelector) value -> value, Types.INT)) { testHarness.open(); + testHarness.processElement(new StreamRecord<>(1)); // key is 1 + testHarness.processElement(new StreamRecord<>(2)); // key is 2 testHarness.endInput(); - assertThat(future).isCompleted(); + assertThat(counter).hasValue(2); + Collection> firstOutput = testHarness.getRecordOutput(); + ConcurrentLinkedQueue> secondOutput = + testHarness.getSideOutput(sideOutputTag); + assertThat(firstOutput).containsExactly(new StreamRecord<>(1), new StreamRecord<>(2)); + assertThat(secondOutput) + .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L)); } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java index e33b92b913980b..fe4419b14e7ef8 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java @@ -28,7 +28,7 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; @@ -57,7 +57,7 @@ void testProcessRecord() throws Exception { @Test void testEndInput() throws Exception { - CompletableFuture future = new CompletableFuture<>(); + AtomicInteger counter = new AtomicInteger(); ProcessOperator processOperator = new ProcessOperator<>( new OneInputStreamProcessFunction() { @@ -71,7 +71,16 @@ public void processRecord( @Override public void endInput(NonPartitionedContext ctx) { - future.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + counter.incrementAndGet(); + out.collect("end"); + }); + + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -79,7 +88,9 @@ public void endInput(NonPartitionedContext ctx) { new OneInputStreamOperatorTestHarness<>(processOperator)) { testHarness.open(); testHarness.endInput(); - assertThat(future).isCompleted(); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput).containsExactly(new StreamRecord<>("end")); + assertThat(counter).hasValue(1); } } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java index 6bbcd33506affc..1e50d7ff86476f 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java @@ -28,8 +28,9 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Collection; import java.util.List; -import java.util.concurrent.CompletableFuture; +import 
java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; @@ -73,8 +74,8 @@ public void processRecordFromBroadcastInput( @Test void testEndInput() throws Exception { - CompletableFuture nonBroadcastInputEnd = new CompletableFuture<>(); - CompletableFuture broadcastInputEnd = new CompletableFuture<>(); + AtomicInteger nonBroadcastInputCounter = new AtomicInteger(); + AtomicInteger broadcastInputCounter = new AtomicInteger(); TwoInputBroadcastProcessOperator processOperator = new TwoInputBroadcastProcessOperator<>( new TwoInputBroadcastStreamProcessFunction() { @@ -95,12 +96,28 @@ public void processRecordFromBroadcastInput( @Override public void endNonBroadcastInput(NonPartitionedContext ctx) { - nonBroadcastInputEnd.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + nonBroadcastInputCounter.incrementAndGet(); + out.collect(1L); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override public void endBroadcastInput(NonPartitionedContext ctx) { - broadcastInputEnd.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + broadcastInputCounter.incrementAndGet(); + out.collect(2L); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -108,9 +125,12 @@ public void endBroadcastInput(NonPartitionedContext ctx) { new TwoInputStreamOperatorTestHarness<>(processOperator)) { testHarness.open(); testHarness.endInput1(); - assertThat(nonBroadcastInputEnd).isCompleted(); + assertThat(nonBroadcastInputCounter).hasValue(1); testHarness.endInput2(); - assertThat(broadcastInputEnd).isCompleted(); + assertThat(broadcastInputCounter).hasValue(1); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L)); } } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java index a6c674fbb6d39f..e4774f3de8c832 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java @@ -28,7 +28,7 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; @@ -75,8 +75,8 @@ public void processRecordFromSecondInput( @Test void testEndInput() throws Exception { - CompletableFuture firstFuture = new CompletableFuture<>(); - CompletableFuture secondFuture = new CompletableFuture<>(); + AtomicInteger firstInputCounter = new AtomicInteger(); + AtomicInteger secondInputCounter = new AtomicInteger(); TwoInputNonBroadcastProcessOperator processOperator = new TwoInputNonBroadcastProcessOperator<>( new TwoInputNonBroadcastStreamProcessFunction() { @@ -96,12 +96,28 @@ public void processRecordFromSecondInput( @Override public void endFirstInput(NonPartitionedContext ctx) { - firstFuture.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + firstInputCounter.incrementAndGet(); + out.collect(1L); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override public void endSecondInput(NonPartitionedContext ctx) { - 
secondFuture.complete(null); + try { + ctx.applyToAllPartitions( + (out, context) -> { + secondInputCounter.incrementAndGet(); + out.collect(2L); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }); @@ -109,9 +125,12 @@ public void endSecondInput(NonPartitionedContext ctx) { new TwoInputStreamOperatorTestHarness<>(processOperator)) { testHarness.open(); testHarness.endInput1(); - assertThat(firstFuture).isCompleted(); + assertThat(firstInputCounter).hasValue(1); testHarness.endInput2(); - assertThat(secondFuture).isCompleted(); + assertThat(secondInputCounter).hasValue(1); + Collection> recordOutput = testHarness.getRecordOutput(); + assertThat(recordOutput) + .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L)); } } } diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java index e3c6394990ce14..273259ae37dec8 100644 --- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java +++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java @@ -29,8 +29,8 @@ import org.junit.jupiter.api.Test; import java.util.Collection; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; @@ -75,7 +75,7 @@ public void processRecord( @Test void testEndInput() throws Exception { - CompletableFuture future = new CompletableFuture<>(); + AtomicInteger counter = new AtomicInteger(); OutputTag sideOutputTag = new OutputTag("side-output") {}; TwoOutputProcessOperator processOperator = @@ -93,7 +93,16 @@ public void processRecord( @Override public void endInput( TwoOutputNonPartitionedContext ctx) { - future.complete(null); + try { + ctx.applyToAllPartitions( + (firstOutput, secondOutput, context) -> { + counter.incrementAndGet(); + firstOutput.collect(1); + secondOutput.collect(2L); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } } }, sideOutputTag); @@ -102,7 +111,12 @@ public void endInput( new OneInputStreamOperatorTestHarness<>(processOperator)) { testHarness.open(); testHarness.endInput(); - assertThat(future).isCompleted(); + assertThat(counter).hasValue(1); + Collection> firstOutput = testHarness.getRecordOutput(); + ConcurrentLinkedQueue> secondOutput = + testHarness.getSideOutput(sideOutputTag); + assertThat(firstOutput).containsExactly(new StreamRecord<>(1)); + assertThat(secondOutput).containsExactly(new StreamRecord<>(2L)); } } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index eda7c254fd8a05..ca82984af6a70c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -531,6 +531,14 @@ public KeyedStateStore getKeyedStateStore() { return stateHandler.getKeyedStateStore().orElse(null); } + protected KeySelector getStateKeySelector1() { + return stateKeySelector1; + } + + protected KeySelector getStateKeySelector2() { + return stateKeySelector2; + } + // 
------------------------------------------------------------------------
 //  Context and chaining properties
 // ------------------------------------------------------------------------
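
The two-output context follows the same pattern with a three-argument partition function. A minimal sketch of the endInput hook, matching the shape exercised by KeyedTwoOutputProcessOperatorTest above (the Integer/Long output types are illustrative):

    @Override
    public void endInput(TwoOutputNonPartitionedContext<Integer, Long> ctx) {
        try {
            ctx.applyToAllPartitions(
                    (firstOutput, secondOutput, context) -> {
                        Integer currentKey = context.getStateManager().getCurrentKey();
                        firstOutput.collect(currentKey);
                        secondOutput.collect(Long.valueOf(currentKey));
                    });
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

A note on the key-tracking design: each keyed operator records every key it sees by overriding setKeyContextElement1/setKeyContextElement2 and adding the extracted key to a transient HashSet created in open(); that set is handed to the non-partitioned context, whose keyed applyToAllPartitions() path restores each key's context via executeInKeyContext() before invoking the user function. KeyedTwoInputBroadcastProcessOperator only registers keys from input1, since the broadcast side carries no key.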