Skip to content

Commit d2cb8dc

Browse files
committed
HamaWhiteGG#128: fix error when parsing multi-tier UDFs
1 parent 68ee6f3 commit d2cb8dc

File tree

3 files changed

+97
-35
lines changed

3 files changed

+97
-35
lines changed

lineage-flink1.14.x/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java

+10-29
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131
import java.util.*;
3232
import java.util.function.Function;
33-
import java.util.regex.Matcher;
34-
import java.util.regex.Pattern;
3533
import java.util.stream.Collectors;
3634

3735
import static com.hw.lineage.common.util.Constant.DELIMITER;
@@ -59,8 +57,6 @@ public class RelMdColumnOrigins implements MetadataHandler<BuiltInMetadata.Colum
5957

6058
private static final Logger LOG = LoggerFactory.getLogger(RelMdColumnOrigins.class);
6159

62-
private final Pattern pattern = Pattern.compile("\\$[\\d.]+");
63-
6460
public static final RelMetadataProvider SOURCE =
6561
ReflectiveRelMetadataProvider.reflectiveSource(
6662
BuiltInMethod.COLUMN_ORIGIN.method, new RelMdColumnOrigins());
@@ -92,7 +88,7 @@ public Set<RelColumnOrigin> getColumnOrigins(Aggregate rel, RelMetadataQuery mq,
9288

9389
if (rexNode instanceof RexLiteral) {
9490
RexLiteral literal = (RexLiteral) rexNode;
95-
transform = transform.replace("$" + iInput, literal.toString());
91+
transform = transform.replace("$" + iInput, literal.toString().replace("_UTF-16LE", ""));
9692
continue;
9793
}
9894

@@ -438,32 +434,12 @@ private Set<RelColumnOrigin> createDerivedColumnOrigins(Set<RelColumnOrigin> inp
438434
private String computeTransform(Set<RelColumnOrigin> inputSet, Object transform) {
439435
LOG.debug("origin transform: {}, class: {}", transform, transform.getClass());
440436
String finalTransform = transform.toString();
441-
442-
Matcher matcher = pattern.matcher(finalTransform);
443-
444-
Set<String> operandSet = new LinkedHashSet<>();
445-
while (matcher.find()) {
446-
operandSet.add(matcher.group());
447-
}
448-
449-
if (operandSet.isEmpty()) {
450-
return finalTransform;
451-
}
452-
/*if (inputSet.size() != operandSet.size()) {
453-
LOG.warn("The number [{}] of fields in the source tables are not equal to operands [{}]", inputSet.size(),
454-
operandSet.size());
455-
return null;
456-
}*/
457-
458437
Map<String, String> sourceColumnMap = buildSourceColumnMap(inputSet, transform);
459438

460-
matcher = pattern.matcher(finalTransform);
461-
String temp;
462-
while (matcher.find()) {
463-
temp = matcher.group();
464-
finalTransform = finalTransform.replace(temp, sourceColumnMap.get(temp));
439+
for (Map.Entry<String, String> entry : sourceColumnMap.entrySet()) {
440+
finalTransform = finalTransform.replace(entry.getKey(), entry.getValue());
465441
}
466-
// temporary special treatment
442+
467443
finalTransform = finalTransform.replace("_UTF-16LE", "");
468444
LOG.debug("final transform: {}", finalTransform);
469445
return finalTransform;
@@ -508,7 +484,12 @@ public Void visitFieldAccess(RexFieldAccess fieldAccess) {
508484
}
509485
Map<String, String> sourceColumnMap = new HashMap<>(INITIAL_CAPACITY);
510486
Iterator<String> iterator = optimizeSourceColumnSet(inputSet).iterator();
511-
traversalSet.forEach(index -> sourceColumnMap.put("$" + index, iterator.next()));
487+
traversalSet.forEach(
488+
index -> {
489+
if (iterator.hasNext()) {
490+
sourceColumnMap.put("$" + index, iterator.next());
491+
}
492+
});
512493
LOG.debug("sourceColumnMap: {}", sourceColumnMap);
513494
return sourceColumnMap;
514495
}

lineage-flink1.14.x/src/test/java/com/hw/lineage/flink/aggregatefunction/AggregateFunctionTest.java

+68-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
119
package com.hw.lineage.flink.aggregatefunction;
220

321
import com.hw.lineage.flink.basic.AbstractBasicTest;
22+
423
import org.junit.Before;
524
import org.junit.Test;
625

@@ -15,9 +34,14 @@ public void createTable() {
1534

1635
createTableOfOdsMysqlUsersDetail();
1736

18-
context.execute("create function test_aggregate as 'com.hw.lineage.flink.aggregatefunction.TestAggregateFunction'");
37+
createPrintTable();
38+
39+
createFunction();
1940
}
2041

42+
/**
43+
* Test case for issue #125.
44+
*/
2145
@Test
2246
public void testAggregateFunction() {
2347
String sql = "INSERT INTO dwd_hudi_users " +
@@ -34,7 +58,8 @@ public void testAggregateFunction() {
3458
String[][] expectedArray = {
3559
{"ods_mysql_users", "id", "dwd_hudi_users", "id"},
3660
{"ods_mysql_users", "name", "dwd_hudi_users", "name"},
37-
{"ods_mysql_users", "name", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, 'test'), name, _UTF-16LE'test')"},
61+
{"ods_mysql_users", "name", "dwd_hudi_users", "company_name",
62+
"test_aggregate(CONCAT_WS('_', name, 'test'), name, 'test')"},
3863
{"ods_mysql_users", "birthday", "dwd_hudi_users", "birthday"},
3964
{"ods_mysql_users", "ts", "dwd_hudi_users", "ts"},
4065
{"ods_mysql_users", "birthday", "dwd_hudi_users", "partition", "DATE_FORMAT(birthday, 'yyyyMMdd')"}
@@ -43,6 +68,28 @@ public void testAggregateFunction() {
4368
analyzeLineage(sql, expectedArray);
4469
}
4570

71+
/**
72+
* Test case for issue #128 (multi-tier UDF parsing).
73+
*/
74+
@Test
75+
public void testMultiTierUdf() {
76+
String sql = "INSERT INTO print_table " +
77+
"SELECT " +
78+
" round( COUNT(*) / COUNT( DISTINCT name ) , 2 )" +
79+
"FROM" +
80+
" ods_mysql_users group by ts ";
81+
82+
String[][] expectedArray = {
83+
{"ods_mysql_users", "name", "print_table", "num", "ROUND(/(COUNT(DISTINCT name), $2), 2)"},
84+
85+
};
86+
87+
analyzeLineage(sql, expectedArray);
88+
}
89+
90+
/**
91+
* Test case for issue #126.
92+
*/
4693
@Test
4794
public void testAggregateFunctionInputArgument() {
4895
String sql = "INSERT INTO dwd_hudi_users " +
@@ -59,17 +106,25 @@ public void testAggregateFunctionInputArgument() {
59106
String[][] expectedArray = {
60107
{"ods_mysql_user_detail", "id", "dwd_hudi_users", "id"},
61108
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "name"},
62-
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
63-
{"ods_mysql_user_detail", "email", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
64-
{"ods_mysql_user_detail", "address", "dwd_hudi_users", "company_name", "test_aggregate(CONCAT_WS('_', name, email), address, _UTF-16LE'test')"},
109+
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "company_name",
110+
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
111+
{"ods_mysql_user_detail", "email", "dwd_hudi_users", "company_name",
112+
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
113+
{"ods_mysql_user_detail", "address", "dwd_hudi_users", "company_name",
114+
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
65115
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "birthday"},
66116
{"ods_mysql_user_detail", "ts", "dwd_hudi_users", "ts"},
67-
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "partition", "DATE_FORMAT(birthday, 'yyyyMMdd')"}
117+
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "partition",
118+
"DATE_FORMAT(birthday, 'yyyyMMdd')"}
68119
};
69120

70121
analyzeLineage(sql, expectedArray);
71122
}
72123

124+
private void createPrintTable() {
125+
context.execute("drop table if exists print_table");
126+
context.execute("create table print_table (num double) with ('connector'='print')");
127+
}
73128
protected void createTableOfOdsMysqlUsersDetail() {
74129
context.execute("DROP TABLE IF EXISTS ods_mysql_user_detail ");
75130

@@ -92,4 +147,11 @@ protected void createTableOfOdsMysqlUsersDetail() {
92147
" 'table-name' = 'users' " +
93148
")");
94149
}
150+
151+
private void createFunction() {
152+
context.execute("drop function if exists test_aggregate");
153+
context.execute(
154+
"create function test_aggregate as 'com.hw.lineage.flink.aggregatefunction.TestAggregateFunction'");
155+
}
156+
95157
}

lineage-flink1.14.x/src/test/java/com/hw/lineage/flink/aggregatefunction/TestAggregateFunction.java

+19
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
119
package com.hw.lineage.flink.aggregatefunction;
220

321
import org.apache.flink.table.functions.AggregateFunction;
@@ -19,6 +37,7 @@ public TestAggregateAcc createAccumulator() {
1937
}
2038

2139
public static class TestAggregateAcc {
40+
2241
public String test;
2342
}
2443
}

0 commit comments

Comments
 (0)