From caf39d95b4c3cf7ac7b09ebee9d4057b24ed44af Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Wed, 10 May 2023 09:39:29 +0800 Subject: [PATCH] Add back not node (#228) * Add back not node. * Fix NOT handling in metadata filter. --------- Co-authored-by: Jimmy Lu --- velox/dwio/common/MetadataFilter.cpp | 81 +++++++------ velox/dwio/common/MetadataFilter.h | 1 - velox/dwio/common/tests/E2EFilterTestBase.cpp | 49 ++++++-- velox/expression/ExprToSubfieldFilter.cpp | 112 ++++++++++++------ velox/expression/ExprToSubfieldFilter.h | 3 +- .../tests/ExprToSubfieldFilterTest.cpp | 3 +- 6 files changed, 159 insertions(+), 90 deletions(-) diff --git a/velox/dwio/common/MetadataFilter.cpp b/velox/dwio/common/MetadataFilter.cpp index d89717f0f852..c9e492da3745 100644 --- a/velox/dwio/common/MetadataFilter.cpp +++ b/velox/dwio/common/MetadataFilter.cpp @@ -28,9 +28,8 @@ using LeafResults = } struct MetadataFilter::Node { - static std::unique_ptr fromExpression( - ScanSpec&, - const core::ITypedExpr&); + static std::unique_ptr + fromExpression(ScanSpec&, const core::ITypedExpr&, bool negated); virtual ~Node() = default; virtual uint64_t* eval(LeafResults&, int size) const = 0; }; @@ -59,6 +58,18 @@ class MetadataFilter::LeafNode : public Node { }; struct MetadataFilter::AndNode : Node { + static std::unique_ptr create( + std::unique_ptr lhs, + std::unique_ptr rhs) { + if (!lhs) { + return rhs; + } + if (!rhs) { + return lhs; + } + return std::make_unique(std::move(lhs), std::move(rhs)); + } + AndNode(std::unique_ptr lhs, std::unique_ptr rhs) : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} @@ -81,6 +92,15 @@ struct MetadataFilter::AndNode : Node { }; struct MetadataFilter::OrNode : Node { + static std::unique_ptr create( + std::unique_ptr lhs, + std::unique_ptr rhs) { + if (!lhs || !rhs) { + return nullptr; + } + return std::make_unique(std::move(lhs), std::move(rhs)); + } + OrNode(std::unique_ptr lhs, std::unique_ptr rhs) : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} @@ -99,23 
+119,6 @@ struct MetadataFilter::OrNode : Node { std::unique_ptr rhs_; }; -struct MetadataFilter::NotNode : Node { - explicit NotNode(std::unique_ptr negated) - : negated_(std::move(negated)) {} - - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* bits = negated_->eval(leafResults, size); - if (!bits) { - return nullptr; - } - bits::negate(reinterpret_cast(bits), size); - return bits; - } - - private: - std::unique_ptr negated_; -}; - namespace { const core::FieldAccessTypedExpr* asField( @@ -133,40 +136,36 @@ const core::CallTypedExpr* asCall(const core::ITypedExpr* expr) { std::unique_ptr MetadataFilter::Node::fromExpression( ScanSpec& scanSpec, - const core::ITypedExpr& expr) { + const core::ITypedExpr& expr, + bool negated) { auto* call = asCall(&expr); if (!call) { return nullptr; } if (call->name() == "and") { - auto lhs = fromExpression(scanSpec, *call->inputs()[0]); - auto rhs = fromExpression(scanSpec, *call->inputs()[1]); - if (!lhs) { - return rhs; - } - if (!rhs) { - return lhs; - } - return std::make_unique(std::move(lhs), std::move(rhs)); + auto lhs = fromExpression(scanSpec, *call->inputs()[0], negated); + auto rhs = fromExpression(scanSpec, *call->inputs()[1], negated); + return negated ? OrNode::create(std::move(lhs), std::move(rhs)) + : AndNode::create(std::move(lhs), std::move(rhs)); } if (call->name() == "or") { - auto lhs = fromExpression(scanSpec, *call->inputs()[0]); - auto rhs = fromExpression(scanSpec, *call->inputs()[1]); - if (!lhs || !rhs) { - return nullptr; - } - return std::make_unique(std::move(lhs), std::move(rhs)); + auto lhs = fromExpression(scanSpec, *call->inputs()[0], negated); + auto rhs = fromExpression(scanSpec, *call->inputs()[1], negated); + return negated ? 
AndNode::create(std::move(lhs), std::move(rhs)) + : OrNode::create(std::move(lhs), std::move(rhs)); + } + if (call->name() == "not") { + return fromExpression(scanSpec, *call->inputs()[0], !negated); } if (call->name() == "endswith" || call->name() == "contains" || call->name() == "like" || call->name() == "startswith" || - call->name() == "in" || call->name() == "rlike" || - call->name() == "isnotnull" || call->name() == "coalesce" || - call->name() == "might_contain") { + call->name() == "rlike" || call->name() == "isnotnull" || + call->name() == "coalesce" || call->name() == "might_contain") { return nullptr; } try { Subfield subfield; - auto filter = exec::leafCallToSubfieldFilter(*call, subfield); + auto filter = exec::leafCallToSubfieldFilter(*call, subfield, negated); if (!filter) { return nullptr; } @@ -180,7 +179,7 @@ std::unique_ptr MetadataFilter::Node::fromExpression( } MetadataFilter::MetadataFilter(ScanSpec& scanSpec, const core::ITypedExpr& expr) - : root_(Node::fromExpression(scanSpec, expr)) {} + : root_(Node::fromExpression(scanSpec, expr, false)) {} void MetadataFilter::eval( std::vector>>& leafNodeResults, diff --git a/velox/dwio/common/MetadataFilter.h b/velox/dwio/common/MetadataFilter.h index 5eaa1597c4a1..02c33f8ab791 100644 --- a/velox/dwio/common/MetadataFilter.h +++ b/velox/dwio/common/MetadataFilter.h @@ -45,7 +45,6 @@ class MetadataFilter { class Node; class AndNode; class OrNode; - class NotNode; std::shared_ptr root_; }; diff --git a/velox/dwio/common/tests/E2EFilterTestBase.cpp b/velox/dwio/common/tests/E2EFilterTestBase.cpp index 79a2336f6d75..25331ed4e5ee 100644 --- a/velox/dwio/common/tests/E2EFilterTestBase.cpp +++ b/velox/dwio/common/tests/E2EFilterTestBase.cpp @@ -429,17 +429,17 @@ void E2EFilterTestBase::testMetadataFilterImpl( int64_t originalIndex = 0; auto nextExpectedIndex = [&]() -> int64_t { for (;;) { - if (originalIndex >= batches.size() * kRowsInGroup) { + if (originalIndex >= batches.size() * batchSize_) { return -1; 
} - auto& batch = batches[originalIndex / kRowsInGroup]; + auto& batch = batches[originalIndex / batchSize_]; auto vecA = batch->as()->childAt(0)->asFlatVector(); auto vecC = batch->as() ->childAt(1) ->as() ->childAt(0) ->asFlatVector(); - auto j = originalIndex++ % kRowsInGroup; + auto j = originalIndex++ % batchSize_; auto a = vecA->valueAt(j); auto c = vecC->valueAt(j); if (validationFilter(a, c)) { @@ -451,8 +451,8 @@ void E2EFilterTestBase::testMetadataFilterImpl( for (int i = 0; i < result->size(); ++i) { auto totalIndex = nextExpectedIndex(); ASSERT_GE(totalIndex, 0); - auto& expected = batches[totalIndex / kRowsInGroup]; - vector_size_t j = totalIndex % kRowsInGroup; + auto& expected = batches[totalIndex / batchSize_]; + vector_size_t j = totalIndex % batchSize_; ASSERT_TRUE(result->equalValueAt(expected.get(), i, j)) << result->toString(i) << " vs " << expected->toString(j); } @@ -461,14 +461,20 @@ void E2EFilterTestBase::testMetadataFilterImpl( } void E2EFilterTestBase::testMetadataFilter() { + flushEveryNBatches_ = 1; + batchSize_ = 10; + test::VectorMaker vectorMaker(leafPool_.get()); + functions::prestosql::registerAllScalarFunctions(); + parse::registerTypeResolver(); + // a: bigint, b: struct std::vector batches; for (int i = 0; i < 10; ++i) { auto a = BaseVector::create>( - BIGINT(), kRowsInGroup, leafPool_.get()); + BIGINT(), batchSize_, leafPool_.get()); auto c = BaseVector::create>( - BIGINT(), kRowsInGroup, leafPool_.get()); - for (int j = 0; j < kRowsInGroup; ++j) { + BIGINT(), batchSize_, leafPool_.get()); + for (int j = 0; j < batchSize_; ++j) { a->set(j, i); c->set(j, i); } @@ -485,10 +491,8 @@ void E2EFilterTestBase::testMetadataFilter() { a->size(), std::vector({a, b}))); } - writeToMemory(batches[0]->type(), batches, true); + writeToMemory(batches[0]->type(), batches, false); - functions::prestosql::registerAllScalarFunctions(); - parse::registerTypeResolver(); testMetadataFilterImpl( batches, common::Subfield("a"), @@ -509,6 +513,29 @@ 
void E2EFilterTestBase::testMetadataFilter() { nullptr, "a in (1, 3, 8) or a >= 9", [](int64_t a, int64_t) { return a == 1 || a == 3 || a == 8 || a >= 9; }); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + "not (a not in (2, 3, 5, 7))", + [](int64_t a, int64_t) { + return !!(a == 2 || a == 3 || a == 5 || a == 7); + }); + + { + SCOPED_TRACE("Values not unique in row group"); + auto a = vectorMaker.flatVector(batchSize_, folly::identity); + auto c = vectorMaker.flatVector(batchSize_, folly::identity); + auto b = vectorMaker.rowVector({"c"}, {c}); + batches = {vectorMaker.rowVector({"a", "b"}, {a, b})}; + writeToMemory(batches[0]->type(), batches, false); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + "not (a = 1 and b.c = 2)", + [](int64_t a, int64_t c) { return !(a == 1 && c == 2); }); + } } void E2EFilterTestBase::testSubfieldsPruning() { diff --git a/velox/expression/ExprToSubfieldFilter.cpp b/velox/expression/ExprToSubfieldFilter.cpp index 4c0939101b38..1ea492744fc3 100644 --- a/velox/expression/ExprToSubfieldFilter.cpp +++ b/velox/expression/ExprToSubfieldFilter.cpp @@ -351,7 +351,9 @@ toInt64List(const VectorPtr& vector, vector_size_t start, vector_size_t size) { return values; } -std::unique_ptr makeInFilter(const core::TypedExprPtr& expr) { +std::unique_ptr makeInFilter( + const core::TypedExprPtr& expr, + bool negated) { auto queryCtx = std::make_shared(); auto vector = toConstant(expr, queryCtx); if (!(vector && vector->type()->isArray())) { @@ -366,20 +368,31 @@ std::unique_ptr makeInFilter(const core::TypedExprPtr& expr) { auto elementType = arrayVector->type()->asArray().elementType(); switch (elementType->kind()) { - case TypeKind::TINYINT: - return in(toInt64List(elements, offset, size)); - case TypeKind::SMALLINT: - return in(toInt64List(elements, offset, size)); - case TypeKind::INTEGER: - return in(toInt64List(elements, offset, size)); - case TypeKind::BIGINT: - return in(toInt64List(elements, 
offset, size)); + case TypeKind::TINYINT: { + auto values = toInt64List(elements, offset, size); + return negated ? notIn(values) : in(values); + } + case TypeKind::SMALLINT: { + auto values = toInt64List(elements, offset, size); + return negated ? notIn(values) : in(values); + } + case TypeKind::INTEGER: { + auto values = toInt64List(elements, offset, size); + return negated ? notIn(values) : in(values); + } + case TypeKind::BIGINT: { + auto values = toInt64List(elements, offset, size); + return negated ? notIn(values) : in(values); + } case TypeKind::VARCHAR: { auto stringElements = elements->as>(); std::vector values; for (auto i = 0; i < size; i++) { values.push_back(stringElements->valueAt(offset + i).str()); } + if (negated) { + return notIn(values); + } return in(values); } default: @@ -389,7 +402,8 @@ std::unique_ptr makeInFilter(const core::TypedExprPtr& expr) { std::unique_ptr makeBetweenFilter( const core::TypedExprPtr& lowerExpr, - const core::TypedExprPtr& upperExpr) { + const core::TypedExprPtr& upperExpr, + bool negated) { auto queryCtx = std::make_shared(); auto lower = toConstant(lowerExpr, queryCtx); if (!lower) { @@ -401,19 +415,40 @@ std::unique_ptr makeBetweenFilter( } switch (lower->typeKind()) { case TypeKind::BIGINT: + if (negated) { + return notBetween( + singleValue(lower), singleValue(upper)); + } return between(singleValue(lower), singleValue(upper)); case TypeKind::DOUBLE: - return betweenDouble( - singleValue(lower), singleValue(upper)); + return negated + ? nullptr + : betweenDouble( + singleValue(lower), singleValue(upper)); case TypeKind::REAL: - return betweenFloat(singleValue(lower), singleValue(upper)); + return negated + ? 
nullptr + : betweenFloat(singleValue(lower), singleValue(upper)); case TypeKind::DATE: + if (negated) { + return notBetween( + singleValue(lower).days(), singleValue(upper).days()); + } return between( singleValue(lower).days(), singleValue(upper).days()); case TypeKind::VARCHAR: + if (negated) { + return notBetween( + singleValue(lower), singleValue(upper)); + } return between( singleValue(lower), singleValue(upper)); case TypeKind::SHORT_DECIMAL: + if (negated) { + return notBetween( + singleValue(lower).unscaledValue(), + singleValue(upper).unscaledValue()); + } return between( singleValue(lower).unscaledValue(), singleValue(upper).unscaledValue()); @@ -421,73 +456,74 @@ std::unique_ptr makeBetweenFilter( return nullptr; } } + } // namespace std::unique_ptr leafCallToSubfieldFilter( const core::CallTypedExpr& call, - common::Subfield& subfield) { + common::Subfield& subfield, + bool negated) { if (call.name() == "eq" || call.name() == "equalto") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeEqualFilter(call.inputs()[1]); + return negated ? makeNotEqualFilter(call.inputs()[1]) + : makeEqualFilter(call.inputs()[1]); } } - } else if (call.name() == "neq") { + } else if (call.name() == "neq" || call.name() == "notequalto") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeNotEqualFilter(call.inputs()[1]); + return negated ? makeEqualFilter(call.inputs()[1]) + : makeNotEqualFilter(call.inputs()[1]); } } } else if (call.name() == "lte" || call.name() == "lessthanorequal") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeLessThanOrEqualFilter(call.inputs()[1]); + return negated ? 
makeGreaterThanFilter(call.inputs()[1]) + : makeLessThanOrEqualFilter(call.inputs()[1]); } } } else if (call.name() == "lt" || call.name() == "lessthan") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeLessThanFilter(call.inputs()[1]); + return negated ? makeGreaterThanOrEqualFilter(call.inputs()[1]) + : makeLessThanFilter(call.inputs()[1]); } } } else if (call.name() == "gte" || call.name() == "greaterthanorequal") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeGreaterThanOrEqualFilter(call.inputs()[1]); + return negated ? makeLessThanFilter(call.inputs()[1]) + : makeGreaterThanOrEqualFilter(call.inputs()[1]); } } } else if (call.name() == "gt" || call.name() == "greaterthan") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeGreaterThanFilter(call.inputs()[1]); + return negated ? makeLessThanOrEqualFilter(call.inputs()[1]) + : makeGreaterThanFilter(call.inputs()[1]); } } } else if (call.name() == "between") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeBetweenFilter(call.inputs()[1], call.inputs()[2]); + return makeBetweenFilter(call.inputs()[1], call.inputs()[2], negated); } } } else if (call.name() == "in") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return makeInFilter(call.inputs()[1]); + return makeInFilter(call.inputs()[1], negated); } } } else if (call.name() == "is_null" || call.name() == "isnull") { if (auto field = asField(&call, 0)) { if (toSubfield(field, subfield)) { - return isNull(); - } - } - } else if (call.name() == "not") { - if (auto nestedCall = asCall(call.inputs()[0].get())) { - if (nestedCall->name() == "is_null") { - if (auto field = asField(nestedCall, 0)) { - if (toSubfield(field, subfield)) { - return isNotNull(); - } + if (negated) { + return isNotNull(); } + return isNull(); } } } @@ -506,7 +542,15 @@ std::pair> toSubfieldFilter( 
makeOrFilter(std::move(left.second), std::move(right.second))}; } common::Subfield subfield; - if (auto filter = leafCallToSubfieldFilter(*call, subfield)) { + std::unique_ptr filter; + if (call->name() == "not") { + if (auto* inner = asCall(call->inputs()[0].get())) { + filter = leafCallToSubfieldFilter(*inner, subfield, true); + } + } else { + filter = leafCallToSubfieldFilter(*call, subfield, false); + } + if (filter) { return std::make_pair(std::move(subfield), std::move(filter)); } } diff --git a/velox/expression/ExprToSubfieldFilter.h b/velox/expression/ExprToSubfieldFilter.h index 66030a105936..effe45a5c963 100644 --- a/velox/expression/ExprToSubfieldFilter.h +++ b/velox/expression/ExprToSubfieldFilter.h @@ -380,6 +380,7 @@ std::pair> toSubfieldFilter( /// execution. std::unique_ptr leafCallToSubfieldFilter( const core::CallTypedExpr&, - common::Subfield&); + common::Subfield&, + bool negated = false); } // namespace facebook::velox::exec diff --git a/velox/expression/tests/ExprToSubfieldFilterTest.cpp b/velox/expression/tests/ExprToSubfieldFilterTest.cpp index f2fe7322748d..b92546766bc4 100644 --- a/velox/expression/tests/ExprToSubfieldFilterTest.cpp +++ b/velox/expression/tests/ExprToSubfieldFilterTest.cpp @@ -193,8 +193,7 @@ TEST_F(ExprToSubfieldFilterTest, isNull) { TEST_F(ExprToSubfieldFilterTest, isNotNull) { auto call = parseCallExpr("a is not null", ROW({{"a", BIGINT()}})); - Subfield subfield; - auto filter = leafCallToSubfieldFilter(*call, subfield); + auto [subfield, filter] = toSubfieldFilter(call); ASSERT_TRUE(filter); validateSubfield(subfield, {"a"}); ASSERT_TRUE(filter->testInt64(0));