Skip to content

Commit

Permalink
Fix SqlCase in DDL parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxiqian committed May 11, 2024
1 parent ef2c654 commit 9b7c313
Showing 1 changed file with 20 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
Expand Down Expand Up @@ -413,12 +412,12 @@ private static List<String> parseColumnNameList(SqlNode sqlNode) {
SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode;
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
columnNameList.add(columnName);
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
findSqlIdentifier(sqlCase.getWhenOperands().getList(), columnNameList);
} else if (sqlNode instanceof SqlCall) {
SqlCall sqlCall = (SqlCall) sqlNode;
findSqlIdentifier(sqlCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlNodeList) {
SqlNodeList sqlNodeList = (SqlNodeList) sqlNode;
findSqlIdentifier(sqlNodeList.getList(), columnNameList);
}
return columnNameList;
}
Expand All @@ -429,13 +428,12 @@ private static void findSqlIdentifier(List<SqlNode> sqlNodes, List<String> colum
SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode;
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
columnNameList.add(columnName);
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
SqlNodeList whenOperands = sqlCase.getWhenOperands();
findSqlIdentifier(whenOperands.getList(), columnNameList);
} else if (sqlNode instanceof SqlCall) {
SqlCall sqlCall = (SqlCall) sqlNode;
findSqlIdentifier(sqlCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlNodeList) {
SqlNodeList sqlNodeList = (SqlNodeList) sqlNode;
findSqlIdentifier(sqlNodeList.getList(), columnNameList);
}
}
}
Expand Down Expand Up @@ -509,6 +507,14 @@ public static SqlNode rewriteExpression(SqlNode sqlNode, Map<String, SqlNode> re
}
}
return sqlIdentifier;
} else if (sqlNode instanceof SqlNodeList) {
SqlNodeList sqlNodeList = (SqlNodeList) sqlNode;
IntStream.range(0, sqlNodeList.size())
.forEach(
i ->
sqlNodeList.set(
i, rewriteExpression(sqlNodeList.get(i), replaceMap)));
return sqlNodeList;
} else {
return sqlNode;
}
Expand Down

0 comments on commit 9b7c313

Please sign in to comment.