Skip to content

Commit d8baa7a

Browse files
authored
[FLINK-30687][table] Fix wrong result of agg with fiter which references first input column
This closes #26303.
1 parent 7d5609f commit d8baa7a

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ class AggsHandlerCodeGenerator(
330330
aggIndex: Int,
331331
aggName: String): Option[Expression] = {
332332

333-
if (filterArg > 0) {
333+
if (filterArg >= 0) {
334334
val filterType = inputFieldTypes(filterArg)
335335
if (!filterType.isInstanceOf[BooleanType]) {
336336
throw new TableException(

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/AggTestBase.scala

+9
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ abstract class AggTestBase(isBatchMode: Boolean) {
6767
val aggInfo1: AggregateInfo = {
6868
val aggInfo = mock(classOf[AggregateInfo])
6969
val call = mock(classOf[AggregateCall])
70+
updateFilter(call, -1)
7071
when(aggInfo.agg).thenReturn(call)
7172
when(call.getName).thenReturn("avg1")
7273
when(call.hasFilter).thenReturn(false)
@@ -82,6 +83,7 @@ abstract class AggTestBase(isBatchMode: Boolean) {
8283
val aggInfo2: AggregateInfo = {
8384
val aggInfo = mock(classOf[AggregateInfo])
8485
val call = mock(classOf[AggregateCall])
86+
updateFilter(call, -1)
8587
when(aggInfo.agg).thenReturn(call)
8688
when(call.getName).thenReturn("avg2")
8789
when(call.hasFilter).thenReturn(false)
@@ -98,6 +100,7 @@ abstract class AggTestBase(isBatchMode: Boolean) {
98100
val aggInfo3: AggregateInfo = {
99101
val aggInfo = mock(classOf[AggregateInfo])
100102
val call = mock(classOf[AggregateCall])
103+
updateFilter(call, -1)
101104
when(aggInfo.agg).thenReturn(call)
102105
when(call.getName).thenReturn("avg3")
103106
when(call.hasFilter).thenReturn(false)
@@ -118,4 +121,10 @@ abstract class AggTestBase(isBatchMode: Boolean) {
118121
val classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
119122
val context: ExecutionContext = mock(classOf[ExecutionContext])
120123
when(context.getRuntimeContext).thenReturn(mock(classOf[RuntimeContext]))
124+
125+
private def updateFilter(call: AggregateCall, v: Int): Unit = {
126+
val field = call.getClass.getField("filterArg")
127+
field.setAccessible(true)
128+
field.set(call, v)
129+
}
121130
}

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala

+5
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,11 @@ abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
12241224
checkResult(sql, Seq(row("11, 11"), row("12, 12"), row("null, null")))
12251225
}
12261226

1227+
@Test
1228+
def testAggFilterReferenceFirstColumn(): Unit = {
1229+
checkResult("select count(*) filter (where a < 10) from Table3", Seq(row(9)))
1230+
}
1231+
12271232
// TODO support csv
12281233
// @Test
12291234
// def testMultiGroupBys(): Unit = {

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala

+20
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,26 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
16441644
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
16451645
}
16461646

1647+
@TestTemplate
1648+
def testAggFilterReferenceFirstColumn(): Unit = {
1649+
val t = failingDataSource(TestData.tupleData3).toTable(tEnv).as("a", "b", "c")
1650+
tEnv.createTemporaryView("MyTable", t)
1651+
1652+
val sqlQuery =
1653+
s"""
1654+
|SELECT
1655+
| COUNT(*) filter (where a < 10)
1656+
|FROM MyTable
1657+
""".stripMargin
1658+
1659+
val sink = new TestingRetractSink
1660+
val result = tEnv.sqlQuery(sqlQuery).toRetractStream[Row]
1661+
result.addSink(sink).setParallelism(1)
1662+
env.execute()
1663+
val expected = List("9")
1664+
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
1665+
}
1666+
16471667
@TestTemplate
16481668
def testPruneUselessAggCall(): Unit = {
16491669
val data = new mutable.MutableList[(Int, Long, String)]

0 commit comments

Comments
 (0)