Skip to content

Commit 63a3b96

Browse files
committed
HamaWhiteGG#125 HamaWhiteGG#126 support parsing UDTFs where the number of inputs and arguments is equal
1 parent 70c0a2c commit 63a3b96

File tree

4 files changed

+154
-4
lines changed

4 files changed

+154
-4
lines changed

lineage-flink1.14.x/src/main/java/org/apache/calcite/rel/metadata/RelColumnOrigin.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ public String getTransform() {
104104
return transform;
105105
}
106106

107+
public void setTransform(String transform) {
108+
this.transform = transform;
109+
}
110+
107111
@Override
108112
public boolean equals(Object obj) {
109113
if (!(obj instanceof RelColumnOrigin)) {

lineage-flink1.14.x/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,37 @@ public Set<RelColumnOrigin> getColumnOrigins(Aggregate rel, RelMetadataQuery mq,
8585
// Aggregate columns are derived from input columns
8686
AggregateCall call = rel.getAggCallList().get(iOutputColumn - rel.getGroupCount());
8787
final Set<RelColumnOrigin> set = new LinkedHashSet<>();
88+
String transform = call.toString();
8889
for (Integer iInput : call.getArgList()) {
89-
set.addAll(mq.getColumnOrigins(rel.getInput(), iInput));
90+
91+
RexNode rexNode = ((Project) rel.getInput()).getProjects().get(iInput);
92+
93+
if (rexNode instanceof RexLiteral) {
94+
RexLiteral literal = (RexLiteral) rexNode;
95+
transform = transform.replace("$" + iInput, literal.toString());
96+
continue;
97+
}
98+
99+
Set<org.apache.calcite.rel.metadata.RelColumnOrigin> subSet =
100+
mq.getColumnOrigins(rel.getInput(), iInput);
101+
102+
if (!(rexNode instanceof RexCall)) {
103+
subSet = createDerivedColumnOrigins(subSet, rexNode);
104+
}
105+
106+
for (org.apache.calcite.rel.metadata.RelColumnOrigin relColumnOrigin : subSet) {
107+
if (relColumnOrigin.getTransform() != null) {
108+
transform = transform.replace("$" + iInput, relColumnOrigin.getTransform());
109+
}
110+
break;
111+
}
112+
set.addAll(subSet);
90113
}
91-
return createDerivedColumnOrigins(set, call);
114+
115+
// Replace every column origin's transform with the fully substituted expression
116+
final String finalTransform = transform;
117+
set.forEach(s -> s.setTransform(finalTransform));
118+
return set;
92119
}
93120

94121
public Set<RelColumnOrigin> getColumnOrigins(Join rel, RelMetadataQuery mq, int iOutputColumn) {
@@ -422,11 +449,11 @@ private String computeTransform(Set<RelColumnOrigin> inputSet, Object transform)
422449
if (operandSet.isEmpty()) {
423450
return finalTransform;
424451
}
425-
if (inputSet.size() != operandSet.size()) {
452+
/*if (inputSet.size() != operandSet.size()) {
426453
LOG.warn("The number [{}] of fields in the source tables are not equal to operands [{}]", inputSet.size(),
427454
operandSet.size());
428455
return null;
429-
}
456+
}*/
430457

431458
Map<String, String> sourceColumnMap = buildSourceColumnMap(inputSet, transform);
432459

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package com.hw.lineage.flink.aggregatefunction;
2+
3+
import com.hw.lineage.flink.basic.AbstractBasicTest;
4+
import org.junit.Before;
5+
import org.junit.Test;
6+
7+
public class AggregateFunctionTest extends AbstractBasicTest {
8+
9+
@Before
10+
public void createTable() {
11+
12+
createTableOfOdsMysqlUsers();
13+
14+
createTableOfDwdHudiUsers();
15+
16+
createTableOfOdsMysqlUsersDetail();
17+
18+
context.execute("create function test_aggregate as 'com.hw.lineage.flink.aggregatefunction.TestAggregateFunction'");
19+
}
20+
21+
@Test
22+
public void testAggregateFunction() {
23+
String sql = "INSERT INTO dwd_hudi_users " +
24+
"SELECT " +
25+
" id ," +
26+
" name ," +
27+
" test_aggregate(concat_ws('_', name, 'test'), name, 'test')," +
28+
" birthday ," +
29+
" ts ," +
30+
" DATE_FORMAT(birthday, 'yyyyMMdd') " +
31+
"FROM" +
32+
" ods_mysql_users group by id, name, birthday, ts ";
33+
34+
String[][] expectedArray = {
35+
{"ods_mysql_users", "id", "dwd_hudi_users", "id"},
36+
{"ods_mysql_users", "name", "dwd_hudi_users", "name"},
37+
{"ods_mysql_users", "name", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, 'test'), name, _UTF-16LE'test')"},
38+
{"ods_mysql_users", "birthday", "dwd_hudi_users", "birthday"},
39+
{"ods_mysql_users", "ts", "dwd_hudi_users", "ts"},
40+
{"ods_mysql_users", "birthday", "dwd_hudi_users", "partition", "DATE_FORMAT(birthday, 'yyyyMMdd')"}
41+
};
42+
43+
analyzeLineage(sql, expectedArray);
44+
}
45+
46+
@Test
47+
public void testAggregateFunctionInputArgument() {
48+
String sql = "INSERT INTO dwd_hudi_users " +
49+
"SELECT " +
50+
" id ," +
51+
" name ," +
52+
" test_aggregate(concat_ws('_', name, email), address, 'test')," +
53+
" birthday ," +
54+
" ts ," +
55+
" DATE_FORMAT(birthday, 'yyyyMMdd') " +
56+
"FROM" +
57+
" ods_mysql_user_detail group by id, name, birthday, ts ";
58+
59+
String[][] expectedArray = {
60+
{"ods_mysql_user_detail", "id", "dwd_hudi_users", "id"},
61+
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "name"},
62+
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
63+
{"ods_mysql_user_detail", "email", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
64+
{"ods_mysql_user_detail", "address", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
65+
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "birthday"},
66+
{"ods_mysql_user_detail", "ts", "dwd_hudi_users", "ts"},
67+
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "partition", "DATE_FORMAT(birthday, 'yyyyMMdd')"}
68+
};
69+
70+
analyzeLineage(sql, expectedArray);
71+
}
72+
73+
protected void createTableOfOdsMysqlUsersDetail() {
74+
context.execute("DROP TABLE IF EXISTS ods_mysql_user_detail ");
75+
76+
context.execute("CREATE TABLE IF NOT EXISTS ods_mysql_user_detail (" +
77+
" id BIGINT PRIMARY KEY NOT ENFORCED ," +
78+
" name STRING ," +
79+
" birthday TIMESTAMP(3) ," +
80+
" ts TIMESTAMP(3) ," +
81+
" email STRING ," +
82+
" address STRING ," +
83+
" proc_time as proctime() " +
84+
") WITH ( " +
85+
" 'connector' = 'mysql-cdc' ," +
86+
" 'hostname' = '127.0.0.1' ," +
87+
" 'port' = '3306' ," +
88+
" 'username' = 'root' ," +
89+
" 'password' = 'xxx' ," +
90+
" 'server-time-zone' = 'Asia/Shanghai' ," +
91+
" 'database-name' = 'demo' ," +
92+
" 'table-name' = 'users' " +
93+
")");
94+
}
95+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.hw.lineage.flink.aggregatefunction;
2+
3+
import org.apache.flink.table.functions.AggregateFunction;
4+
5+
public class TestAggregateFunction extends AggregateFunction<String, TestAggregateFunction.TestAggregateAcc> {
6+
7+
public void accumulate(TestAggregateAcc acc, String param1, String param2, String param3) {
8+
acc.test = param1 + param2 + param3;
9+
}
10+
11+
@Override
12+
public String getValue(TestAggregateAcc accumulator) {
13+
return accumulator.test;
14+
}
15+
16+
@Override
17+
public TestAggregateAcc createAccumulator() {
18+
return new TestAggregateAcc();
19+
}
20+
21+
public static class TestAggregateAcc {
22+
public String test;
23+
}
24+
}

0 commit comments

Comments
 (0)