diff --git a/src/include/mysql_filter_pushdown.hpp b/src/include/mysql_filter_pushdown.hpp index 3a3e584..fddcd1e 100644 --- a/src/include/mysql_filter_pushdown.hpp +++ b/src/include/mysql_filter_pushdown.hpp @@ -11,6 +11,7 @@ #include "duckdb/planner/table_filter.hpp" #include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/function/table_function.hpp" namespace duckdb { @@ -18,6 +19,8 @@ class MySQLFilterPushdown { public: static string TransformFilters(const vector &column_ids, optional_ptr filters, const vector &names); + static void ComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, + vector> &filters); private: static string TransformFilter(string &column_name, TableFilter &filter); diff --git a/src/include/mysql_scanner.hpp b/src/include/mysql_scanner.hpp index 5b114c3..a8d1212 100644 --- a/src/include/mysql_scanner.hpp +++ b/src/include/mysql_scanner.hpp @@ -24,6 +24,8 @@ struct MySQLBindData : public FunctionData { vector mysql_types; vector names; vector types; + //Filter pushdown to apply + vector> filters_to_apply; public: unique_ptr Copy() const override { diff --git a/src/include/storage/mysql_table_entry.hpp b/src/include/storage/mysql_table_entry.hpp index dbbed8a..502f33e 100644 --- a/src/include/storage/mysql_table_entry.hpp +++ b/src/include/storage/mysql_table_entry.hpp @@ -10,6 +10,7 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "mysql_filter_pushdown.hpp" #include "mysql_utils.hpp" namespace duckdb { diff --git a/src/mysql_filter_pushdown.cpp b/src/mysql_filter_pushdown.cpp index 75e1bbf..407e357 100644 --- a/src/mysql_filter_pushdown.cpp +++ b/src/mysql_filter_pushdown.cpp @@ -1,5 +1,7 @@ #include "mysql_filter_pushdown.hpp" +#include "mysql_scanner.hpp" #include "mysql_utils.hpp" +#include namespace duckdb { @@ -25,6 +27,8 @@ string MySQLFilterPushdown::TransformComparision(ExpressionType type) { return "<="; case ExpressionType::COMPARE_GREATERTHANOREQUALTO: return ">="; + case ExpressionType::COMPARE_IN: + return "IN"; default: throw NotImplementedException("Unsupported expression type"); } @@ -57,8 +61,11 @@ string MySQLFilterPushdown::TransformFilter(string &column_name, TableFilter &fi string MySQLFilterPushdown::TransformFilters(const vector &column_ids, optional_ptr filters, const vector &names) { + + std::cout << "TransformFilters" << std::endl; if (!filters || filters->filters.empty()) { // no filters + std::cout << "No filter" << std::endl; return string(); } string result; @@ -68,9 +75,38 @@ string MySQLFilterPushdown::TransformFilters(const vector &column_ids, } auto column_name = MySQLUtils::WriteIdentifier(names[column_ids[entry.first]]); auto &filter = *entry.second; + std::cout << "filter: " << filter.ToString(column_name) << std::endl; result += TransformFilter(column_name, filter); } return result; } +void MySQLFilterPushdown::ComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, + vector> &filters) { + + auto &data = bind_data_p->Cast(); + vector> filters_to_apply; + vector> unsupported_filters; + + for (idx_t j = 0; j < filters.size(); j++) { + auto &filter = filters[j]; + std::cout << "current filter : " << filter->ToString() << std::endl; + unique_ptr filter_copy = filter->Copy(); + if (filter->expression_class == ExpressionClass::BOUND_EXPRESSION || + filter->expression_class == ExpressionClass::BOUND_CONSTANT || + filter->expression_class == ExpressionClass::BOUND_CONJUNCTION || + filter->expression_class == ExpressionClass::BOUND_COMPARISON || + filter->expression_class == ExpressionClass::BOUND_OPERATOR) { + filters_to_apply.emplace_back(std::move(filter_copy)); + std::cout << "filters_to_apply : " << filter->ToString() << std::endl; + } else { + unsupported_filters.emplace_back(std::move(filter_copy)); + std::cout << "unsupported_filters : " << filter->ToString() << " with class: " << static_cast(filter->expression_class) << std::endl; + } + } + + data.filters_to_apply = std::move(filters_to_apply); + filters = std::move(unsupported_filters); +} + } // namespace duckdb diff --git a/src/mysql_scanner.cpp b/src/mysql_scanner.cpp index 1192776..2f72183 100644 --- a/src/mysql_scanner.cpp +++ b/src/mysql_scanner.cpp @@ -79,6 +79,7 @@ static unique_ptr MySQLInitLocalState(ExecutionContext static void MySQLScan(ClientContext &context, TableFunctionInput &data, DataChunk &output) { auto &gstate = data.global_state->Cast(); + data.bind_data->Cast(); idx_t r; for (r = 0; r < STANDARD_VECTOR_SIZE; r++) { if (!gstate.result->Next()) { diff --git a/src/storage/mysql_table_entry.cpp b/src/storage/mysql_table_entry.cpp index 796291b..d86c9c3 100644 --- a/src/storage/mysql_table_entry.cpp +++ b/src/storage/mysql_table_entry.cpp @@ -35,6 +35,7 @@ TableFunction MySQLTableEntry::GetScanFunction(ClientContext &context, unique_pt Value filter_pushdown; if (context.TryGetCurrentSetting("mysql_experimental_filter_pushdown", filter_pushdown)) { function.filter_pushdown = BooleanValue::Get(filter_pushdown); + function.pushdown_complex_filter = MySQLFilterPushdown::ComplexFilterPushdown; } return function; }