Skip to content

Commit 360b97a

Browse files
committed
[FLINK-18445][table] Add pre-filtering optimization for lookup join
1 parent ee110aa commit 360b97a

File tree

24 files changed

+1917
-55
lines changed

24 files changed

+1917
-55
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecLookupJoin.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ public class BatchExecLookupJoin extends CommonExecLookupJoin
5050
public BatchExecLookupJoin(
5151
ReadableConfig tableConfig,
5252
FlinkJoinType joinType,
53-
@Nullable RexNode joinCondition,
53+
@Nullable RexNode preFilterCondition,
54+
@Nullable RexNode remainingJoinCondition,
5455
TemporalTableSourceSpec temporalTableSourceSpec,
5556
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
5657
@Nullable List<RexNode> projectionOnTemporalTable,
@@ -64,7 +65,8 @@ public BatchExecLookupJoin(
6465
ExecNodeContext.newContext(BatchExecLookupJoin.class),
6566
ExecNodeContext.newPersistedConfig(BatchExecLookupJoin.class, tableConfig),
6667
joinType,
67-
joinCondition,
68+
preFilterCondition,
69+
remainingJoinCondition,
6870
temporalTableSourceSpec,
6971
lookupKeys,
7072
projectionOnTemporalTable,

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.apache.flink.table.runtime.collector.ListenableCollector;
6161
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
6262
import org.apache.flink.table.runtime.generated.GeneratedCollector;
63+
import org.apache.flink.table.runtime.generated.GeneratedFilterCondition;
6364
import org.apache.flink.table.runtime.generated.GeneratedFunction;
6465
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
6566
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
@@ -144,7 +145,8 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
144145
public static final String LOOKUP_JOIN_MATERIALIZE_TRANSFORMATION = "lookup-join-materialize";
145146

146147
public static final String FIELD_NAME_JOIN_TYPE = "joinType";
147-
public static final String FIELD_NAME_JOIN_CONDITION = "joinCondition";
148+
public static final String FIELD_NAME_PRE_FILTER_CONDITION = "preFilterCondition";
149+
public static final String FIELD_NAME_REMAINING_JOIN_CONDITION = "joinCondition";
148150
public static final String FIELD_NAME_TEMPORAL_TABLE = "temporalTable";
149151
public static final String FIELD_NAME_LOOKUP_KEYS = "lookupKeys";
150152
public static final String FIELD_NAME_PROJECTION_ON_TEMPORAL_TABLE =
@@ -175,9 +177,14 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
175177
@JsonProperty(FIELD_NAME_FILTER_ON_TEMPORAL_TABLE)
176178
private final @Nullable RexNode filterOnTemporalTable;
177179

178-
/** join condition except equi-conditions extracted as lookup keys. */
179-
@JsonProperty(FIELD_NAME_JOIN_CONDITION)
180-
private final @Nullable RexNode joinCondition;
180+
/** Pre-filter condition on the left input, excluding the lookup keys. */
181+
@JsonProperty(FIELD_NAME_PRE_FILTER_CONDITION)
182+
@JsonInclude(JsonInclude.Include.NON_NULL)
183+
private final @Nullable RexNode preFilterCondition;
184+
185+
/** Remaining join condition, excluding the pre-filter and the equi-conditions extracted as lookup keys. */
186+
@JsonProperty(FIELD_NAME_REMAINING_JOIN_CONDITION)
187+
private final @Nullable RexNode remainingJoinCondition;
181188

182189
@JsonProperty(FIELD_NAME_INPUT_CHANGELOG_MODE)
183190
private final ChangelogMode inputChangelogMode;
@@ -195,7 +202,8 @@ protected CommonExecLookupJoin(
195202
ExecNodeContext context,
196203
ReadableConfig persistedConfig,
197204
FlinkJoinType joinType,
198-
@Nullable RexNode joinCondition,
205+
@Nullable RexNode preFilterCondition,
206+
@Nullable RexNode remainingJoinCondition,
199207
// TODO: refactor this into TableSourceTable, once legacy TableSource is removed
200208
TemporalTableSourceSpec temporalTableSourceSpec,
201209
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
@@ -210,7 +218,8 @@ protected CommonExecLookupJoin(
210218
super(id, context, persistedConfig, inputProperties, outputType, description);
211219
checkArgument(inputProperties.size() == 1);
212220
this.joinType = checkNotNull(joinType);
213-
this.joinCondition = joinCondition;
221+
this.preFilterCondition = preFilterCondition;
222+
this.remainingJoinCondition = remainingJoinCondition;
214223
this.lookupKeys = Collections.unmodifiableMap(checkNotNull(lookupKeys));
215224
this.temporalTableSourceSpec = checkNotNull(temporalTableSourceSpec);
216225
this.projectionOnTemporalTable = projectionOnTemporalTable;
@@ -410,7 +419,11 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
410419
"TableFunctionResultFuture",
411420
inputRowType,
412421
rightRowType,
413-
JavaScalaConversionUtil.toScala(Optional.ofNullable(joinCondition)));
422+
JavaScalaConversionUtil.toScala(
423+
Optional.ofNullable(remainingJoinCondition)));
424+
GeneratedFilterCondition generatedPreFilterCondition =
425+
LookupJoinCodeGenerator.generatePreFilterCondition(
426+
config, classLoader, preFilterCondition, inputRowType);
414427

415428
DataStructureConverter<?, ?> fetcherConverter =
416429
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
@@ -431,6 +444,7 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
431444
(DataStructureConverter<RowData, Object>) fetcherConverter,
432445
generatedCalc,
433446
generatedResultFuture,
447+
generatedPreFilterCondition,
434448
InternalSerializers.create(rightRowType),
435449
isLeftOuterJoin,
436450
asyncLookupOptions.asyncBufferCapacity);
@@ -441,6 +455,7 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
441455
generatedFuncWithType.tableFunc(),
442456
(DataStructureConverter<RowData, Object>) fetcherConverter,
443457
generatedResultFuture,
458+
generatedPreFilterCondition,
444459
InternalSerializers.create(rightRowType),
445460
isLeftOuterJoin,
446461
asyncLookupOptions.asyncBufferCapacity);
@@ -540,9 +555,14 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
540555
inputRowType,
541556
rightRowType,
542557
resultRowType,
543-
JavaScalaConversionUtil.toScala(Optional.ofNullable(joinCondition)),
558+
JavaScalaConversionUtil.toScala(
559+
Optional.ofNullable(remainingJoinCondition)),
544560
JavaScalaConversionUtil.toScala(Optional.empty()),
545561
true);
562+
563+
GeneratedFilterCondition generatedPreFilterCondition =
564+
LookupJoinCodeGenerator.generatePreFilterCondition(
565+
config, classLoader, preFilterCondition, inputRowType);
546566
ProcessFunction<RowData, RowData> processFunc;
547567
if (projectionOnTemporalTable != null) {
548568
// a projection or filter after table source scan
@@ -560,6 +580,7 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
560580
generatedFetcher,
561581
generatedCalc,
562582
generatedCollector,
583+
generatedPreFilterCondition,
563584
isLeftOuterJoin,
564585
rightRowType.getFieldCount());
565586
} else {
@@ -568,6 +589,7 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
568589
new LookupJoinRunner(
569590
generatedFetcher,
570591
generatedCollector,
592+
generatedPreFilterCondition,
571593
isLeftOuterJoin,
572594
rightRowType.getFieldCount());
573595
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLookupJoin.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ public class StreamExecLookupJoin extends CommonExecLookupJoin
100100
public StreamExecLookupJoin(
101101
ReadableConfig tableConfig,
102102
FlinkJoinType joinType,
103-
@Nullable RexNode joinCondition,
103+
@Nullable RexNode preFilterCondition,
104+
@Nullable RexNode remainingJoinCondition,
104105
TemporalTableSourceSpec temporalTableSourceSpec,
105106
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
106107
@Nullable List<RexNode> projectionOnTemporalTable,
@@ -118,7 +119,8 @@ public StreamExecLookupJoin(
118119
ExecNodeContext.newContext(StreamExecLookupJoin.class),
119120
ExecNodeContext.newPersistedConfig(StreamExecLookupJoin.class, tableConfig),
120121
joinType,
121-
joinCondition,
122+
preFilterCondition,
123+
remainingJoinCondition,
122124
temporalTableSourceSpec,
123125
lookupKeys,
124126
projectionOnTemporalTable,
@@ -143,7 +145,9 @@ public StreamExecLookupJoin(
143145
@JsonProperty(FIELD_NAME_TYPE) ExecNodeContext context,
144146
@JsonProperty(FIELD_NAME_CONFIGURATION) ReadableConfig persistedConfig,
145147
@JsonProperty(FIELD_NAME_JOIN_TYPE) FlinkJoinType joinType,
146-
@JsonProperty(FIELD_NAME_JOIN_CONDITION) @Nullable RexNode joinCondition,
148+
@JsonProperty(FIELD_NAME_PRE_FILTER_CONDITION) @Nullable RexNode preFilterCondition,
149+
@JsonProperty(FIELD_NAME_REMAINING_JOIN_CONDITION) @Nullable
150+
RexNode remainingJoinCondition,
147151
@JsonProperty(FIELD_NAME_TEMPORAL_TABLE)
148152
TemporalTableSourceSpec temporalTableSourceSpec,
149153
@JsonProperty(FIELD_NAME_LOOKUP_KEYS) Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
@@ -169,7 +173,8 @@ public StreamExecLookupJoin(
169173
context,
170174
persistedConfig,
171175
joinType,
172-
joinCondition,
176+
preFilterCondition,
177+
remainingJoinCondition,
173178
temporalTableSourceSpec,
174179
lookupKeys,
175180
projectionOnTemporalTable,

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,10 @@ public StreamPhysicalRel visit(
301301

302302
// required determinism cannot be satisfied even upsert materialize was enabled if:
303303
// 1. pre-filter or remaining join condition contains non-deterministic call
304-
JavaScalaConversionUtil.toJava(lookupJoin.remainingCondition())
305-
.ifPresent(condi -> checkNonDeterministicCondition(condi, lookupJoin));
304+
JavaScalaConversionUtil.toJava(lookupJoin.finalPreFilterCondition())
305+
.ifPresent(cond -> checkNonDeterministicCondition(cond, lookupJoin));
306+
JavaScalaConversionUtil.toJava(lookupJoin.finalRemainingCondition())
307+
.ifPresent(cond -> checkNonDeterministicCondition(cond, lookupJoin));
306308

307309
// 2. inner calc in lookupJoin contains either non-deterministic condition or calls
308310
JavaScalaConversionUtil.toJava(lookupJoin.calcOnTemporalTable())

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeys.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
237237
val rightUniqueKeys = FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOfTemporalTable(join)
238238

239239
val remainingConditionNonDeterministic =
240-
join.remainingCondition.exists(c => !RexUtil.isDeterministic(c))
240+
join.finalPreFilterCondition.exists(c => !RexUtil.isDeterministic(c)) ||
241+
join.finalRemainingCondition.exists(c => !RexUtil.isDeterministic(c))
241242
lazy val calcOnTemporalTableNonDeterministic =
242243
join.calcOnTemporalTable.exists(p => !FlinkRexUtil.isDeterministic(p))
243244

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLookupJoin.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ class BatchPhysicalLookupJoin(
7676
new BatchExecLookupJoin(
7777
tableConfig,
7878
JoinTypeUtil.getFlinkJoinType(joinType),
79-
remainingCondition.orNull,
79+
finalPreFilterCondition.orNull,
80+
finalRemainingCondition.orNull,
8081
new TemporalTableSourceSpec(temporalTable),
8182
allLookupKeys.map(item => (Int.box(item._1), item._2)).asJava,
8283
projectionOnTemporalTable,

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalLookupJoin.scala

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
2424
import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
2525
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel
2626
import org.apache.flink.table.planner.plan.schema.{IntermediateRelTable, LegacyTableSourceTable, TableSourceTable}
27-
import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, ExpressionFormat, JoinTypeUtil, LookupJoinUtil, RelExplainUtil}
27+
import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, ExpressionFormat, InputRefVisitor, JoinTypeUtil, LookupJoinUtil, RelExplainUtil}
2828
import org.apache.flink.table.planner.plan.utils.ExpressionFormat.ExpressionFormat
2929
import org.apache.flink.table.planner.plan.utils.LookupJoinUtil._
3030
import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
@@ -98,11 +98,11 @@ abstract class CommonPhysicalLookupJoin(
9898
// all potential index keys, mapping from field index in table source to LookupKey
9999
analyzeLookupKeys(cluster.getRexBuilder, joinKeyPairs, calcOnTemporalTable)
100100
}
101-
// remaining condition used to filter the joined records (left input record X lookup-ed records)
102-
val remainingCondition: Option[RexNode] = getRemainingJoinCondition(
101+
// split the join condition (except the lookup keys) into a pre-filter (used to filter the left
102+
// input before lookup) and a remaining part (used to filter the joined records)
103+
val (finalPreFilterCondition, finalRemainingCondition) = splitJoinCondition(
103104
cluster.getRexBuilder,
104105
inputRel.getRowType,
105-
calcOnTemporalTable,
106106
allLookupKeys.values.toList,
107107
joinInfo)
108108

@@ -195,12 +195,9 @@ abstract class CommonPhysicalLookupJoin(
195195
.itemIf("where", whereString, whereString.nonEmpty)
196196
.itemIf(
197197
"joinCondition",
198-
joinConditionToString(
199-
resultFieldNames,
200-
remainingCondition,
201-
preferExpressionFormat(pw),
202-
pw.getDetailLevel),
203-
remainingCondition.isDefined)
198+
joinConditionToString(resultFieldNames, preferExpressionFormat(pw), pw.getDetailLevel),
199+
finalRemainingCondition.isDefined || finalPreFilterCondition.isDefined
200+
)
204201
.item("select", selection)
205202
.itemIf("upsertMaterialize", "true", upsertMaterialize)
206203
.itemIf("async", asyncOptions.getOrElse(""), asyncOptions.isDefined)
@@ -217,13 +214,15 @@ abstract class CommonPhysicalLookupJoin(
217214
case _ => ChangelogMode.insertOnly()
218215
}
219216

220-
/** Gets the remaining join condition which is used */
221-
private def getRemainingJoinCondition(
217+
/**
218+
* Splits the remaining condition in joinInfo into the pre-filter (used to filter the left input
219+
* before lookup) and the remaining part (used to filter the joined records).
220+
*/
221+
private def splitJoinCondition(
222222
rexBuilder: RexBuilder,
223223
leftRelDataType: RelDataType,
224-
calcOnTemporalTable: Option[RexProgram],
225224
leftKeys: List[LookupKey],
226-
joinInfo: JoinInfo): Option[RexNode] = {
225+
joinInfo: JoinInfo): (Option[RexNode], Option[RexNode]) = {
227226
// indexes of left key fields
228227
val leftKeyIndexes =
229228
leftKeys
@@ -244,9 +243,31 @@ abstract class CommonPhysicalLookupJoin(
244243
val rightInputRef = new RexInputRef(rightIndex, rightFieldType)
245244
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftInputRef, rightInputRef)
246245
}
247-
val remainingAnds = remainingEquals ++ joinInfo.nonEquiConditions.asScala
248-
// build a new condition
249-
val condition = RexUtil.composeConjunction(rexBuilder, remainingAnds.toList.asJava)
246+
if (joinType.generatesNullsOnRight) {
247+
// only extract pre-filter for left & full outer joins (otherwise the pre-filter will always be pushed down)
248+
val (leftLocal, remaining) =
249+
joinInfo.nonEquiConditions.asScala.partition {
250+
r =>
251+
{
252+
val inputRefs = new InputRefVisitor()
253+
r.accept(inputRefs)
254+
// if all input refs belong to left
255+
inputRefs.getFields.forall(idx => idx < inputRel.getRowType.getFieldCount)
256+
}
257+
}
258+
val remainingAnds = remainingEquals ++ remaining
259+
// build final pre-filter and remaining conditions
260+
(
261+
composeCondition(rexBuilder, leftLocal.toList),
262+
composeCondition(rexBuilder, remainingAnds.toList))
263+
} else {
264+
val remainingAnds = remainingEquals ++ joinInfo.nonEquiConditions.asScala
265+
(None, composeCondition(rexBuilder, remainingAnds.toList))
266+
}
267+
}
268+
269+
private def composeCondition(rexBuilder: RexBuilder, rexNodes: List[RexNode]): Option[RexNode] = {
270+
val condition = RexUtil.composeConjunction(rexBuilder, rexNodes.asJava)
250271
if (condition.isAlwaysTrue) {
251272
None
252273
} else {
@@ -466,16 +487,30 @@ abstract class CommonPhysicalLookupJoin(
466487

467488
private def joinConditionToString(
468489
resultFieldNames: Array[String],
469-
joinCondition: Option[RexNode],
470490
expressionFormat: ExpressionFormat = ExpressionFormat.Prefix,
471-
sqlExplainLevel: SqlExplainLevel): String = joinCondition match {
472-
case Some(condition) =>
473-
getExpressionString(
474-
condition,
475-
resultFieldNames.toList,
476-
None,
477-
expressionFormat,
478-
sqlExplainLevel)
479-
case None => "N/A"
491+
sqlExplainLevel: SqlExplainLevel): String = {
492+
493+
def appendCondition(sb: StringBuilder, cond: Option[RexNode]): Unit = {
494+
cond match {
495+
case Some(condition) =>
496+
sb.append(
497+
getExpressionString(
498+
condition,
499+
resultFieldNames.toList,
500+
None,
501+
expressionFormat,
502+
sqlExplainLevel))
503+
case _ =>
504+
}
505+
}
506+
507+
if (finalPreFilterCondition.isEmpty && finalRemainingCondition.isEmpty) {
508+
"N/A"
509+
} else {
510+
val sb = new StringBuilder
511+
appendCondition(sb, finalPreFilterCondition)
512+
appendCondition(sb, finalRemainingCondition)
513+
sb.toString()
514+
}
480515
}
481516
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLookupJoin.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class StreamPhysicalLookupJoin(
100100
new StreamExecLookupJoin(
101101
tableConfig,
102102
JoinTypeUtil.getFlinkJoinType(joinType),
103-
remainingCondition.orNull,
103+
finalPreFilterCondition.orNull,
104+
finalRemainingCondition.orNull,
104105
new TemporalTableSourceSpec(temporalTable),
105106
allLookupKeys.map(item => (Int.box(item._1), item._2)).asJava,
106107
projectionOnTemporalTable,

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/LookupJoinJsonPlanTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,30 @@ public void testJoinTemporalTableWithAsyncRetryHint2() {
244244
+ "FROM MyTable AS T JOIN LookupTable "
245245
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id");
246246
}
247+
248+
@Test
249+
public void testLeftJoinTemporalTableWithPreFilter() {
250+
util.verifyJsonPlan(
251+
"INSERT INTO MySink1 SELECT * "
252+
+ "FROM MyTable AS T LEFT JOIN LookupTable "
253+
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b > 'abc'");
254+
}
255+
256+
@Test
257+
public void testLeftJoinTemporalTableWithPostFilter() {
258+
util.verifyJsonPlan(
259+
"INSERT INTO MySink1 SELECT * "
260+
+ "FROM MyTable AS T LEFT JOIN LookupTable "
261+
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id "
262+
+ "AND CHAR_LENGTH(D.name) > CHAR_LENGTH(T.b)");
263+
}
264+
265+
@Test
266+
public void testLeftJoinTemporalTableWithMultiJoinConditions() {
267+
util.verifyJsonPlan(
268+
"INSERT INTO MySink1 SELECT * "
269+
+ "FROM MyTable AS T LEFT JOIN LookupTable "
270+
+ "FOR SYSTEM_TIME AS OF T.proctime AS D "
271+
+ "ON T.a = D.id AND T.b > 'abc' AND T.b <> D.name AND T.c = 100");
272+
}
247273
}

0 commit comments

Comments
 (0)