Skip to content

Commit 75a553e

Browse files
authored
[FLINK-34878][cdc][transform] Flink CDC pipeline transform supports CASE WHEN (#3228)
1 parent 0108d0e commit 75a553e

File tree

5 files changed

+112
-43
lines changed

5 files changed

+112
-43
lines changed

flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/SystemFunctionUtils.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,4 +501,13 @@ public static String uuid(byte[] b) {
501501
public static boolean valueEquals(Object object1, Object object2) {
502502
return (object1 != null && object2 != null) && object1.equals(object2);
503503
}
504+
505+
public static Object coalesce(Object... objects) {
506+
for (Object item : objects) {
507+
if (item != null) {
508+
return item;
509+
}
510+
}
511+
return null;
512+
}
504513
}

flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.calcite.sql.SqlLiteral;
2828
import org.apache.calcite.sql.SqlNode;
2929
import org.apache.calcite.sql.SqlNodeList;
30+
import org.apache.calcite.sql.fun.SqlCase;
3031
import org.apache.calcite.sql.type.SqlTypeName;
3132
import org.codehaus.commons.compiler.CompileException;
3233
import org.codehaus.commons.compiler.Location;
@@ -82,17 +83,51 @@ public static ExpressionEvaluator compileExpression(
8283
}
8384

8485
public static String translateSqlNodeToJaninoExpression(SqlNode transform) {
85-
if (transform instanceof SqlIdentifier) {
86-
SqlIdentifier sqlIdentifier = (SqlIdentifier) transform;
87-
return sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
88-
} else if (transform instanceof SqlBasicCall) {
89-
Java.Rvalue rvalue = translateJaninoAST((SqlBasicCall) transform);
86+
Java.Rvalue rvalue = translateSqlNodeToJaninoRvalue(transform);
87+
if (rvalue != null) {
9088
return rvalue.toString();
9189
}
9290
return "";
9391
}
9492

95-
private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) {
93+
public static Java.Rvalue translateSqlNodeToJaninoRvalue(SqlNode transform) {
94+
if (transform instanceof SqlIdentifier) {
95+
return translateSqlIdentifier((SqlIdentifier) transform);
96+
} else if (transform instanceof SqlBasicCall) {
97+
return translateSqlBasicCall((SqlBasicCall) transform);
98+
} else if (transform instanceof SqlCase) {
99+
return translateSqlCase((SqlCase) transform);
100+
} else if (transform instanceof SqlLiteral) {
101+
return translateSqlSqlLiteral((SqlLiteral) transform);
102+
}
103+
return null;
104+
}
105+
106+
private static Java.Rvalue translateSqlIdentifier(SqlIdentifier sqlIdentifier) {
107+
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
108+
if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) {
109+
return generateNoOperandTimestampFunctionOperation(columnName);
110+
} else {
111+
return new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName});
112+
}
113+
}
114+
115+
private static Java.Rvalue translateSqlSqlLiteral(SqlLiteral sqlLiteral) {
116+
if (sqlLiteral.getValue() == null) {
117+
return new Java.NullLiteral(Location.NOWHERE);
118+
}
119+
String value = sqlLiteral.getValue().toString();
120+
if (sqlLiteral instanceof SqlCharStringLiteral) {
121+
// Double quotation marks represent strings in Janino.
122+
value = "\"" + value.substring(1, value.length() - 1) + "\"";
123+
}
124+
if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
125+
value = "\"" + value + "\"";
126+
}
127+
return new Java.AmbiguousName(Location.NOWHERE, new String[] {value});
128+
}
129+
130+
private static Java.Rvalue translateSqlBasicCall(SqlBasicCall sqlBasicCall) {
96131
List<SqlNode> operandList = sqlBasicCall.getOperandList();
97132
List<Java.Rvalue> atoms = new ArrayList<>();
98133
for (SqlNode sqlNode : operandList) {
@@ -105,32 +140,44 @@ private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) {
105140
return sqlBasicCallToJaninoRvalue(sqlBasicCall, atoms.toArray(new Java.Rvalue[0]));
106141
}
107142

143+
private static Java.Rvalue translateSqlCase(SqlCase sqlCase) {
144+
SqlNodeList whenOperands = sqlCase.getWhenOperands();
145+
SqlNodeList thenOperands = sqlCase.getThenOperands();
146+
SqlNode elseOperand = sqlCase.getElseOperand();
147+
List<Java.Rvalue> whenAtoms = new ArrayList<>();
148+
for (SqlNode sqlNode : whenOperands) {
149+
translateSqlNodeToAtoms(sqlNode, whenAtoms);
150+
}
151+
List<Java.Rvalue> thenAtoms = new ArrayList<>();
152+
for (SqlNode sqlNode : thenOperands) {
153+
translateSqlNodeToAtoms(sqlNode, thenAtoms);
154+
}
155+
Java.Rvalue elseAtoms = translateSqlNodeToJaninoRvalue(elseOperand);
156+
Java.Rvalue sqlCaseRvalueTemp = elseAtoms;
157+
for (int i = whenAtoms.size() - 1; i >= 0; i--) {
158+
sqlCaseRvalueTemp =
159+
new Java.ConditionalExpression(
160+
Location.NOWHERE,
161+
whenAtoms.get(i),
162+
thenAtoms.get(i),
163+
sqlCaseRvalueTemp);
164+
}
165+
return new Java.ParenthesizedExpression(Location.NOWHERE, sqlCaseRvalueTemp);
166+
}
167+
108168
private static void translateSqlNodeToAtoms(SqlNode sqlNode, List<Java.Rvalue> atoms) {
109169
if (sqlNode instanceof SqlIdentifier) {
110-
SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode;
111-
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
112-
if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) {
113-
atoms.add(generateNoOperandTimestampFunctionOperation(columnName));
114-
} else {
115-
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName}));
116-
}
170+
atoms.add(translateSqlIdentifier((SqlIdentifier) sqlNode));
117171
} else if (sqlNode instanceof SqlLiteral) {
118-
SqlLiteral sqlLiteral = (SqlLiteral) sqlNode;
119-
String value = sqlLiteral.getValue().toString();
120-
if (sqlLiteral instanceof SqlCharStringLiteral) {
121-
// Double quotation marks represent strings in Janino.
122-
value = "\"" + value.substring(1, value.length() - 1) + "\"";
123-
}
124-
if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
125-
value = "\"" + value + "\"";
126-
}
127-
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {value}));
172+
atoms.add(translateSqlSqlLiteral((SqlLiteral) sqlNode));
128173
} else if (sqlNode instanceof SqlBasicCall) {
129-
atoms.add(translateJaninoAST((SqlBasicCall) sqlNode));
174+
atoms.add(translateSqlBasicCall((SqlBasicCall) sqlNode));
130175
} else if (sqlNode instanceof SqlNodeList) {
131176
for (SqlNode node : (SqlNodeList) sqlNode) {
132177
translateSqlNodeToAtoms(node, atoms);
133178
}
179+
} else if (sqlNode instanceof SqlCase) {
180+
atoms.add(translateSqlCase((SqlCase) sqlNode));
134181
}
135182
}
136183

flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
import org.apache.calcite.sql.SqlIdentifier;
4545
import org.apache.calcite.sql.SqlKind;
4646
import org.apache.calcite.sql.SqlNode;
47+
import org.apache.calcite.sql.SqlNodeList;
4748
import org.apache.calcite.sql.SqlSelect;
49+
import org.apache.calcite.sql.fun.SqlCase;
4850
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
4951
import org.apache.calcite.sql.parser.SqlParseException;
5052
import org.apache.calcite.sql.parser.SqlParser;
@@ -250,10 +252,7 @@ public static String translateFilterExpressionToJaninoExpression(String filterEx
250252
return "";
251253
}
252254
SqlNode where = sqlSelect.getWhere();
253-
if (!(where instanceof SqlBasicCall)) {
254-
throw new ParseException("Unrecognized where: " + where.toString());
255-
}
256-
return JaninoCompiler.translateSqlNodeToJaninoExpression((SqlBasicCall) where);
255+
return JaninoCompiler.translateSqlNodeToJaninoExpression(where);
257256
}
258257

259258
public static List<String> parseComputedColumnNames(String projection) {
@@ -307,11 +306,7 @@ public static List<String> parseFilterColumnNameList(String filterExpression) {
307306
return new ArrayList<>();
308307
}
309308
SqlNode where = sqlSelect.getWhere();
310-
if (!(where instanceof SqlBasicCall)) {
311-
throw new ParseException("Unrecognized where: " + where.toString());
312-
}
313-
SqlBasicCall sqlBasicCall = (SqlBasicCall) where;
314-
return parseColumnNameList(sqlBasicCall);
309+
return parseColumnNameList(where);
315310
}
316311

317312
private static List<String> parseColumnNameList(SqlNode sqlNode) {
@@ -323,6 +318,9 @@ private static List<String> parseColumnNameList(SqlNode sqlNode) {
323318
} else if (sqlNode instanceof SqlBasicCall) {
324319
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
325320
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
321+
} else if (sqlNode instanceof SqlCase) {
322+
SqlCase sqlCase = (SqlCase) sqlNode;
323+
findSqlIdentifier(sqlCase.getWhenOperands().getList(), columnNameList);
326324
}
327325
return columnNameList;
328326
}
@@ -336,6 +334,10 @@ private static void findSqlIdentifier(List<SqlNode> sqlNodes, List<String> colum
336334
} else if (sqlNode instanceof SqlBasicCall) {
337335
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
338336
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
337+
} else if (sqlNode instanceof SqlCase) {
338+
SqlCase sqlCase = (SqlCase) sqlNode;
339+
SqlNodeList whenOperands = sqlCase.getWhenOperands();
340+
findSqlIdentifier(whenOperands.getList(), columnNameList);
339341
}
340342
}
341343
}

flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/TransformDataOperatorTest.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,23 @@ void testBuildInFunctionTransform() throws Exception {
578578
testExpressionConditionTransform("ceil(2.4) = 3.0");
579579
testExpressionConditionTransform("floor(2.5) = 2.0");
580580
testExpressionConditionTransform("round(3.1415926,2) = 3.14");
581+
testExpressionConditionTransform("IF(2>0,1,0) = 1");
582+
testExpressionConditionTransform("COALESCE(null,1,2) = 1");
583+
testExpressionConditionTransform("1 + 1 = 2");
584+
testExpressionConditionTransform("1 - 1 = 0");
585+
testExpressionConditionTransform("1 * 1 = 1");
586+
testExpressionConditionTransform("3 % 2 = 1");
587+
testExpressionConditionTransform("1 < 2");
588+
testExpressionConditionTransform("1 <= 1");
589+
testExpressionConditionTransform("1 > 0");
590+
testExpressionConditionTransform("1 >= 1");
591+
testExpressionConditionTransform(
592+
"case 1 when 1 then 'a' when 2 then 'b' else 'c' end = 'a'");
593+
testExpressionConditionTransform("case col1 when '1' then true else false end");
594+
testExpressionConditionTransform("case when col1 = '1' then true else false end");
581595
}
582596

583-
void testExpressionConditionTransform(String expression) throws Exception {
597+
private void testExpressionConditionTransform(String expression) throws Exception {
584598
TransformDataOperator transform =
585599
TransformDataOperator.newBuilder()
586600
.addTransform(

flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/TransformParserTest.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import org.apache.flink.cdc.common.types.DataTypes;
2222
import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory;
2323
import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable;
24-
import org.apache.flink.table.api.ApiExpression;
25-
import org.apache.flink.table.api.Expressions;
2624

2725
import org.apache.calcite.config.CalciteConnectionConfigImpl;
2826
import org.apache.calcite.jdbc.CalciteSchema;
@@ -260,13 +258,12 @@ public void testTranslateFilterToJaninoExpression() {
260258
testFilterExpression("upper(lower(id))", "upper(lower(id))");
261259
testFilterExpression(
262260
"abs(uniq_id) > 10 and id is not null", "abs(uniq_id) > 10 && null != id");
263-
}
264-
265-
@Test
266-
public void testSqlCall() {
267-
ApiExpression apiExpression = Expressions.concat("1", "2");
268-
ApiExpression substring = apiExpression.substring(1);
269-
System.out.println(substring);
261+
testFilterExpression(
262+
"case id when 1 then 'a' when 2 then 'b' else 'c' end",
263+
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
264+
testFilterExpression(
265+
"case when id = 1 then 'a' when id = 2 then 'b' else 'c' end",
266+
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
270267
}
271268

272269
private void testFilterExpression(String expression, String expressionExpect) {

0 commit comments

Comments
 (0)