From f9e00ef0b4efe15e927301c5d34a45aefae68c7e Mon Sep 17 00:00:00 2001 From: Fabian Krause <29677855+kcaliban@users.noreply.github.com> Date: Fri, 12 Jan 2024 20:09:25 +0100 Subject: [PATCH 1/2] Add support for COUNT and fix a bug in the GROUP BY HashMap optimization (#1222) Add support for COUNT and prepare adding support for the other aggregate expressions in the hash map-based implementation of GROUP BY. Also fix a critical bug when grouped variables occur in expressions. --- src/engine/GroupBy.cpp | 317 ++++++++++++++++++++++++++++------------- src/engine/GroupBy.h | 112 +++++++++++---- test/GroupByTest.cpp | 156 +++++++++++++++++--- 3 files changed, 443 insertions(+), 142 deletions(-) diff --git a/src/engine/GroupBy.cpp b/src/engine/GroupBy.cpp index 170c334022..e34f3170b8 100644 --- a/src/engine/GroupBy.cpp +++ b/src/engine/GroupBy.cpp @@ -383,8 +383,8 @@ ResultTable GroupBy::computeResult() { if (hashMapOptimizationParams.has_value()) { computeGroupByForHashMapOptimization( &idTable, hashMapOptimizationParams->aggregateAliases_, - hashMapOptimizationParams->numAggregates_, subresult->idTable(), - hashMapOptimizationParams->subtreeColumnIndex_, &localVocab); + subresult->idTable(), hashMapOptimizationParams->subtreeColumnIndex_, + &localVocab); return {std::move(idTable), resultSortedOn(), std::move(localVocab)}; } @@ -733,7 +733,7 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector& aliases) { if (!foundAggregates.has_value()) return std::nullopt; for (auto& aggregate : foundAggregates.value()) { - aggregate.aggregateDataIndex = numAggregates++; + aggregate.aggregateDataIndex_ = numAggregates++; } aliasesWithAggregateInfo.emplace_back(alias._expression, alias._outCol, @@ -744,8 +744,51 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector& aliases) { auto child = _subtree->getRootOperation()->getChildren().at(0); auto columnIndex = child->getVariableColumn(groupByVariable); - return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo, - numAggregates}; + return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo}; +} + +// _____________________________________________________________________________ +std::variant, + GroupBy::OccurenceAsRoot> +GroupBy::findGroupedVariable(sparqlExpression::SparqlExpression* expr) { + AD_CONTRACT_CHECK(_groupByVariables.size() == 1); + std::variant, OccurenceAsRoot> substitutions; + findGroupedVariableImpl(expr, std::nullopt, substitutions); + return substitutions; +} + +// _____________________________________________________________________________ +void GroupBy::findGroupedVariableImpl( + sparqlExpression::SparqlExpression* expr, + std::optional parentAndChildIndex, + std::variant, OccurenceAsRoot>& + substitutions) { + if (auto value = hasType(expr)) { + auto variable = value.value()->value(); + for (const auto& groupedVariable : _groupByVariables) { + if (variable != groupedVariable) continue; + + if (parentAndChildIndex.has_value()) { + auto vector = + std::get_if>(&substitutions); + AD_CONTRACT_CHECK(vector != nullptr); + vector->emplace_back(parentAndChildIndex.value()); + } else { + substitutions = OccurenceAsRoot{}; + return; + } + } + } + + auto children = expr->children(); + + // TODO use views::enumerate + size_t childIndex = 0; + for (const auto& child : children) { + ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++}; + findGroupedVariableImpl(child.get(), parentAndChildIndexForChild, + substitutions); + } } // _____________________________________________________________________________ @@ -775,19 +818,21 @@ bool GroupBy::hasAnyType(const auto& expr) { } // _____________________________________________________________________________ -bool GroupBy::isUnsupportedAggregate(sparqlExpression::SparqlExpression* expr) { +std::optional GroupBy::isSupportedAggregate( + sparqlExpression::SparqlExpression* expr) { using namespace sparqlExpression; - // `expr` is not an aggregate, so it cannot be an unsupported aggregate - if (!expr->isAggregate()) return false; + // `expr` is not a distinct aggregate + if (expr->isDistinct()) return std::nullopt; - // `expr` is an unsupported aggregate - if (hasAnyType(expr)) - return true; + // `expr` is not a nested aggregated + if (expr->children().front()->containsAggregate()) return std::nullopt; - // `expr` is a distinct aggregate - return expr->isDistinct(); + if (hasType(expr)) return HashMapAggregateType::AVG; + if (hasType(expr)) return HashMapAggregateType::COUNT; + + // `expr` is an unsupported aggregate + return std::nullopt; } // _____________________________________________________________________________ @@ -795,21 +840,19 @@ bool GroupBy::findAggregatesImpl( sparqlExpression::SparqlExpression* expr, std::optional parentAndChildIndex, std::vector& info) { - // Unsupported aggregates - if (isUnsupportedAggregate(expr)) return false; - if (expr->isAggregate()) { - info.emplace_back(expr, 0, parentAndChildIndex); - - // Make sure this is not a nested aggregate. - if (expr->children().front()->containsAggregate()) return false; - - return true; + if (auto aggregateType = isSupportedAggregate(expr)) { + info.emplace_back(expr, 0, aggregateType.value(), parentAndChildIndex); + return true; + } else { + return false; + } } auto children = expr->children(); bool childrenContainOnlySupportedAggregates = true; + // TODO use views::enumerate size_t childIndex = 0; for (const auto& child : children) { ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++}; @@ -856,13 +899,17 @@ GroupBy::getHashMapAggregationResults( aggregateResults.resize(endIndex - beginIndex); decltype(auto) groupValues = resultTable->getColumn(0); - auto& aggregateDataVector = - aggregationData.getAggregationDataVector(dataIndex); + auto& aggregateDataVariant = + aggregationData.getAggregationDataVariant(dataIndex); - auto op = [&aggregationData, &aggregateDataVector](Id val) { + auto op = [&aggregationData, &aggregateDataVariant](Id val) { auto index = aggregationData.getIndex(val); - auto& aggregateData = aggregateDataVector.at(index); - return aggregateData.calculateResult(); + + auto visitor = [&index](auto& aggregateDataVariant) { + return aggregateDataVariant.at(index).calculateResult(); + }; + + return std::visit(visitor, aggregateDataVariant); }; std::ranges::transform(groupValues.begin() + beginIndex, @@ -872,6 +919,26 @@ GroupBy::getHashMapAggregationResults( return aggregateResults; } +// _____________________________________________________________________________ +void GroupBy::substituteGroupVariable( + const std::vector& occurrences, + IdTable* resultTable) const { + decltype(auto) groupValues = resultTable->getColumn(0); + + for (const auto& occurrence : occurrences) { + sparqlExpression::VectorWithMemoryLimit values( + getExecutionContext()->getAllocator()); + values.resize(groupValues.size()); + std::ranges::copy(groupValues, values.begin()); + + auto newExpression = std::make_unique( + std::move(values)); + + occurrence.parent_->replaceChild(occurrence.nThChild_, + std::move(newExpression)); + } +} + // _____________________________________________________________________________ void GroupBy::substituteAllAggregates( std::vector& info, size_t beginIndex, @@ -880,7 +947,7 @@ void GroupBy::substituteAllAggregates( // Substitute in the results of all aggregates of `info`. for (auto& aggregate : info) { auto aggregateResults = getHashMapAggregationResults( - resultTable, aggregationData, aggregate.aggregateDataIndex, beginIndex, + resultTable, aggregationData, aggregate.aggregateDataIndex_, beginIndex, endIndex); // Substitute the resulting vector as a literal @@ -907,7 +974,8 @@ std::vector GroupBy::HashMapAggregationData::getHashEntries( } for (auto& aggregation : aggregationData_) - aggregation.resize(getNumberOfGroups()); + std::visit([this](auto& arg) { arg.resize(getNumberOfGroups()); }, + aggregation); return hashEntries; } @@ -925,6 +993,80 @@ GroupBy::HashMapAggregationData::getSortedGroupColumn() const { return sortedKeys; } +// _____________________________________________________________________________ +void GroupBy::evaluateAlias( + HashMapAliasInformation& alias, IdTable* result, + sparqlExpression::EvaluationContext& evaluationContext, + const HashMapAggregationData& aggregationData, LocalVocab* localVocab) { + auto& info = alias.aggregateInfo_; + + // Check if the grouped variable occurs in this expression + auto groupByVariableSubstitutions = + findGroupedVariable(alias.expr_.getPimpl()); + + if (std::get_if(&groupByVariableSubstitutions)) { + // If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x) + // WHERE {...} GROUP BY ?a`, we can copy values directly from the column + // of the grouped variable + decltype(auto) groupValues = result->getColumn(0); + decltype(auto) outValues = result->getColumn(alias.outCol_); + std::ranges::copy(groupValues, outValues.begin()); + + // We also need to store it for possible future use + sparqlExpression::VectorWithMemoryLimit values( + getExecutionContext()->getAllocator()); + values.resize(groupValues.size()); + std::ranges::copy(groupValues, values.begin()); + + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult( + sparqlExpression::ExpressionResult{std::move(values)}); + } else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) { + // Only one aggregate, and it is at the top of the alias, e.g. + // `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside + // an aggregate, hence we don't need to substitute anything here + auto& aggregate = info.at(0); + + // Get aggregate results + auto aggregateResults = getHashMapAggregationResults( + result, aggregationData, aggregate.aggregateDataIndex_, + evaluationContext._beginIndex, evaluationContext._endIndex); + + // Copy to result table + decltype(auto) outValues = result->getColumn(alias.outCol_); + std::ranges::copy(aggregateResults, + outValues.begin() + evaluationContext._beginIndex); + + // Copy the result so that future aliases may reuse it + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult( + sparqlExpression::ExpressionResult{std::move(aggregateResults)}); + } else { + const auto& occurrences = + get>(groupByVariableSubstitutions); + // Substitute in the values of the grouped variable + substituteGroupVariable(occurrences, result); + + // Substitute in the results of all aggregates contained in the + // expression of the current alias, if `info` is non-empty. + substituteAllAggregates(info, evaluationContext._beginIndex, + evaluationContext._endIndex, aggregationData, + result); + + // Evaluate top-level alias expression + sparqlExpression::ExpressionResult expressionResult = + alias.expr_.getPimpl()->evaluate(&evaluationContext); + + // Copy the result so that future aliases may reuse it + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult(expressionResult); + + // Extract values + extractValues(std::move(expressionResult), evaluationContext, result, + localVocab, alias.outCol_); + } +} + // _____________________________________________________________________________ void GroupBy::createResultFromHashMap( IdTable* result, const HashMapAggregationData& aggregationData, @@ -957,58 +1099,52 @@ void GroupBy::createResultFromHashMap( evaluationContext._endIndex = std::min(i + blockSize, numberOfGroups); for (auto& alias : aggregateAliases) { - auto& info = alias.aggregateInfo_; - - // Only one aggregate, and it is at the top of the alias. - if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) { - auto& aggregate = info.at(0); - - // Get aggregate results - auto aggregateResults = getHashMapAggregationResults( - result, aggregationData, aggregate.aggregateDataIndex, - evaluationContext._beginIndex, evaluationContext._endIndex); - - // Copy to result table - decltype(auto) outValues = result->getColumn(alias.outCol_); - std::ranges::copy(aggregateResults, - outValues.begin() + evaluationContext._beginIndex); - - // Copy the result so that future aliases may reuse it - evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = - sparqlExpression::copyExpressionResult( - sparqlExpression::ExpressionResult{ - std::move(aggregateResults)}); - } else { - // Substitute in the results of all aggregates contained in the - // expression of the current alias, if `info` is non-empty. - substituteAllAggregates(info, evaluationContext._beginIndex, - evaluationContext._endIndex, aggregationData, - result); - - // Evaluate top-level alias expression - sparqlExpression::ExpressionResult expressionResult = - alias.expr_.getPimpl()->evaluate(&evaluationContext); - - // Copy the result so that future aliases may reuse it - evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = - sparqlExpression::copyExpressionResult(expressionResult); - - // Extract values - extractValues(std::move(expressionResult), evaluationContext, result, - localVocab, alias.outCol_); - } + evaluateAlias(alias, result, evaluationContext, aggregationData, + localVocab); } } } +// _____________________________________________________________________________ +template +concept SupportedAggregates = + ad_utility::isTypeContainedIn; + +// _____________________________________________________________________________ +// Visitor function to extract values from the result of an evaluation of +// the child expression of an aggregate, and subsequently processing the values +// by calling the `increment` function of the corresponding aggregate. +static constexpr auto makeProcessGroupsVisitor = + [](size_t blockSize, + const sparqlExpression::EvaluationContext* evaluationContext, + const std::vector& hashEntries) { + return [blockSize, evaluationContext, + &hashEntries](T&& singleResult, + A& aggregationDataVector) { + auto generator = sparqlExpression::detail::makeGenerator( + std::forward(singleResult), blockSize, evaluationContext); + + auto hashEntryIndex = 0; + + for (const auto& val : generator) { + auto vectorOffset = hashEntries[hashEntryIndex]; + auto& aggregateData = aggregationDataVector.at(vectorOffset); + + aggregateData.increment(val, evaluationContext); + + ++hashEntryIndex; + } + }; + }; + // _____________________________________________________________________________ void GroupBy::computeGroupByForHashMapOptimization( IdTable* result, std::vector& aggregateAliases, - size_t numAggregates, const IdTable& subresult, size_t columnIndex, - LocalVocab* localVocab) { + const IdTable& subresult, size_t columnIndex, LocalVocab* localVocab) { // Initialize aggregation data HashMapAggregationData aggregationData(getExecutionContext()->getAllocator(), - numAggregates); + aggregateAliases); // Initialize evaluation context sparqlExpression::EvaluationContext evaluationContext( @@ -1034,7 +1170,6 @@ void GroupBy::computeGroupByForHashMapOptimization( .subspan(evaluationContext._beginIndex, currentBlockSize); auto hashEntries = aggregationData.getHashEntries(groupValues); - // TODO use views::enumerate for (auto& aggregateAlias : aggregateAliases) { for (auto& aggregate : aggregateAlias.aggregateInfo_) { // Evaluate child expression on block @@ -1042,35 +1177,13 @@ void GroupBy::computeGroupByForHashMapOptimization( sparqlExpression::ExpressionResult expressionResult = exprChildren[0]->evaluate(&evaluationContext); - auto& aggregationDataVector = aggregationData.getAggregationDataVector( - aggregate.aggregateDataIndex); - - auto visitor = [¤tBlockSize, &evaluationContext, &hashEntries, - &aggregationDataVector]< - sparqlExpression::SingleExpressionResult T>( - T&& singleResult) mutable { - auto generator = sparqlExpression::detail::makeGenerator( - std::forward(singleResult), currentBlockSize, - &evaluationContext); - - using NVG = sparqlExpression::detail::NumericValueGetter; - - auto hashEntryIndex = 0; - - for (const auto& val : generator) { - sparqlExpression::detail::NumericValue numVal = - NVG()(val, &evaluationContext); - - auto vectorOffset = hashEntries[hashEntryIndex]; - auto& aggregateData = aggregationDataVector.at(vectorOffset); - - aggregateData.increment(numVal); - - ++hashEntryIndex; - } - }; + auto& aggregationDataVariant = + aggregationData.getAggregationDataVariant( + aggregate.aggregateDataIndex_); - std::visit(visitor, std::move(expressionResult)); + std::visit(makeProcessGroupsVisitor(currentBlockSize, + &evaluationContext, hashEntries), + std::move(expressionResult), aggregationDataVariant); } } } diff --git a/src/engine/GroupBy.h b/src/engine/GroupBy.h index db6b7df10b..1a9a4f6fcd 100644 --- a/src/engine/GroupBy.h +++ b/src/engine/GroupBy.h @@ -158,13 +158,17 @@ class GroupBy : public Operation { // Data to perform the AVG aggregation using the HashMap optimization. struct AverageAggregationData { + using ValueGetter = sparqlExpression::detail::NumericValueGetter; bool error_ = false; double sum_ = 0; int64_t count_ = 0; - void increment(sparqlExpression::detail::NumericValue intermediate) { - if (const int64_t* intval = std::get_if(&intermediate)) - sum_ += (double)*intval; - else if (const double* dval = std::get_if(&intermediate)) + void increment(auto&& value, + const sparqlExpression::EvaluationContext* ctx) { + auto val = ValueGetter{}(AD_FWD(value), ctx); + + if (const int64_t* intval = std::get_if(&val)) + sum_ += static_cast(*intval); + else if (const double* dval = std::get_if(&val)) sum_ += *dval; else error_ = true; @@ -178,6 +182,19 @@ class GroupBy : public Operation { } }; + // Data to perform the COUNT aggregation using the HashMap optimization. + struct CountAggregationData { + using ValueGetter = sparqlExpression::detail::IsValidValueGetter; + int64_t count_ = 0; + void increment(auto&& value, + const sparqlExpression::EvaluationContext* ctx) { + if (ValueGetter{}(AD_FWD(value), ctx)) count_++; + } + [[nodiscard]] ValueId calculateResult() const { + return ValueId::makeFromInt(count_); + } + }; + using KeyType = ValueId; using ValueType = size_t; @@ -194,24 +211,31 @@ class GroupBy : public Operation { } }; + // Used to store the kind of aggregate. + enum class HashMapAggregateType { AVG, COUNT }; + // Stores information required for evaluation of an aggregate as well // as the alias containing it. struct HashMapAggregateInformation { // The expression of this aggregate. sparqlExpression::SparqlExpression* expr_; - // The index in the `std::array` of the Hash Map where results of this + // The index in the vector of `HashMapAggregationData` where results of this // aggregate are stored. - size_t aggregateDataIndex; + size_t aggregateDataIndex_; // The parent expression of this aggregate, and the index this expression // appears in the parents' children, so that it may be substituted away. std::optional parentAndIndex_ = std::nullopt; + // Which kind of aggregate expression. + HashMapAggregateType aggregateType_; HashMapAggregateInformation( sparqlExpression::SparqlExpression* expr, size_t aggregateDataIndex, + HashMapAggregateType aggregateType, std::optional parentAndIndex = std::nullopt) : expr_{expr}, - aggregateDataIndex{aggregateDataIndex}, - parentAndIndex_{parentAndIndex} { + aggregateDataIndex_{aggregateDataIndex}, + parentAndIndex_{parentAndIndex}, + aggregateType_{aggregateType} { AD_CONTRACT_CHECK(expr != nullptr); } }; @@ -233,24 +257,37 @@ class GroupBy : public Operation { size_t subtreeColumnIndex_; // All aliases and the aggregates they contain. std::vector aggregateAliases_; - // The total number of aggregates. - size_t numAggregates_; }; // Create result IdTable by using a HashMap mapping groups to aggregation data // and subsequently calling `createResultFromHashMap`. void computeGroupByForHashMapOptimization( IdTable* result, std::vector& aggregateAliases, - size_t numAggregates, const IdTable& subresult, size_t columnIndex, - LocalVocab* localVocab); + const IdTable& subresult, size_t columnIndex, LocalVocab* localVocab); + + using Aggregations = std::variant, + std::vector>; // Stores the map which associates Ids with vector offsets and // the vectors containing the aggregation data. class HashMapAggregationData { public: - HashMapAggregationData(ad_utility::AllocatorWithLimit alloc, - size_t numAggregates) - : map_{alloc}, aggregationData_{numAggregates} { + HashMapAggregationData( + const ad_utility::AllocatorWithLimit& alloc, + const std::vector& aggregateAliases) + : map_{alloc} { + size_t numAggregates = 0; + for (const auto& alias : aggregateAliases) { + for (const auto& aggregate : alias.aggregateInfo_) { + ++numAggregates; + + if (aggregate.aggregateType_ == HashMapAggregateType::AVG) + aggregationData_.emplace_back( + std::vector{}); + if (aggregate.aggregateType_ == HashMapAggregateType::COUNT) + aggregationData_.emplace_back(std::vector{}); + } + } AD_CONTRACT_CHECK(numAggregates > 0); } @@ -262,15 +299,14 @@ class GroupBy : public Operation { [[nodiscard]] size_t getIndex(Id id) const { return map_.at(id); } // Get vector containing the aggregation data at `aggregationDataIndex`. - std::vector& getAggregationDataVector( - size_t aggregationDataIndex) { + Aggregations& getAggregationDataVariant(size_t aggregationDataIndex) { return aggregationData_.at(aggregationDataIndex); } // Get vector containing the aggregation data at `aggregationDataIndex`, // but const. - [[nodiscard]] const std::vector& - getAggregationDataVector(size_t aggregationDataIndex) const { + [[nodiscard]] const Aggregations& getAggregationDataVariant( + size_t aggregationDataIndex) const { return aggregationData_.at(aggregationDataIndex); } @@ -284,7 +320,7 @@ class GroupBy : public Operation { // Maps `Id` to vector offsets. ad_utility::HashMapWithMemoryLimit map_; // Stores the actual aggregation data. - std::vector> aggregationData_; + std::vector aggregationData_; }; // Returns the aggregation results between `beginIndex` and `endIndex` @@ -294,6 +330,14 @@ class GroupBy : public Operation { IdTable* resultTable, const HashMapAggregationData& aggregationData, size_t dataIndex, size_t beginIndex, size_t endIndex); + // Substitute away any occurrences of the grouped variable and of aggregate + // results, if necessary, and subsequently evaluate the expression of an + // alias + void evaluateAlias(HashMapAliasInformation& alias, IdTable* result, + sparqlExpression::EvaluationContext& evaluationContext, + const HashMapAggregationData& aggregationData, + LocalVocab* localVocab); + // Sort the HashMap by key and create result table. void createResultFromHashMap( IdTable* result, const HashMapAggregationData& aggregationData, @@ -317,6 +361,11 @@ class GroupBy : public Operation { sparqlExpression::EvaluationContext& evaluationContext, IdTable* resultTable, LocalVocab* localVocab, size_t outCol); + // Substitute the group values for all occurrences of a group variable. + void substituteGroupVariable( + const std::vector& occurrences, + IdTable* resultTable) const; + // Substitute the results for all aggregates in `info`. The values of the // grouped variable should be at column 0 in `groupValues`. void substituteAllAggregates(std::vector& info, @@ -333,13 +382,28 @@ class GroupBy : public Operation { static bool hasAnyType(const auto& expr); // Check if an expression is a currently supported aggregate. - // TODO As soon as all aggregates are supported, implement and use a - // `isAggregate` function in SparqlExpressions instead. - static bool isUnsupportedAggregate(sparqlExpression::SparqlExpression* expr); + static std::optional isSupportedAggregate( + sparqlExpression::SparqlExpression* expr); + + // Determines whether the grouped by variable appears at the top of an + // alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`. + struct OccurenceAsRoot {}; + + // Find all occurrences of grouped by variable for expression `expr`. + std::variant, OccurenceAsRoot> + findGroupedVariable(sparqlExpression::SparqlExpression* expr); + + // The recursive implementation of `findGroupedVariable` (see above). + void findGroupedVariableImpl( + sparqlExpression::SparqlExpression* expr, + std::optional parentAndChildIndex, + std::variant, OccurenceAsRoot>& + substitutions); // Find all aggregates for expression `expr`. Return `std::nullopt` // if an unsupported aggregate is found. - // TODO Remove std::optional as soon as all aggregates are supported + // TODO Remove std::optional as soon as all aggregates are + // supported static std::optional> findAggregates( sparqlExpression::SparqlExpression* expr); diff --git a/test/GroupByTest.cpp b/test/GroupByTest.cpp index 08cc34ba6f..f1bf2a2372 100644 --- a/test/GroupByTest.cpp +++ b/test/GroupByTest.cpp @@ -497,6 +497,55 @@ TEST_F(GroupByOptimizations, findAggregates) { ASSERT_FALSE(unsupportedAggregates.has_value()); } +// _____________________________________________________________________________ +TEST_F(GroupByOptimizations, findGroupedVariable) { + Variable varA = Variable{"?a"}; + Variable varX = Variable{"?x"}; + Variable varB = Variable{"?b"}; + + using namespace sparqlExpression; + using TC = TripleComponent; + + // `(?a as ?x)`. + auto expr1 = makeVariableExpression(varA); + + // `(?a + COUNT(?b) AS ?y)`. + auto expr2 = makeAddExpression( + makeVariableExpression(varA), + std::make_unique(false, makeVariableExpression(varB))); + + // `(?x + AVG(?b) as ?z)`. + auto expr3 = makeAddExpression( + makeVariableExpression(varX), + std::make_unique(false, makeVariableExpression(varB))); + + // Set up the Group By object. + parsedQuery::SparqlValues input; + input._variables = std::vector{varA, varB}; + input._values.push_back(std::vector{TC(1.0), TC(3.0)}); + auto values = ad_utility::makeExecutionTree( + ad_utility::testing::getQec(), input); + GroupBy groupBy{ad_utility::testing::getQec(), {Variable{"?a"}}, {}, values}; + + auto variableAtTop = groupBy.findGroupedVariable(expr1.get()); + ASSERT_TRUE(std::get_if(&variableAtTop)); + + auto variableInExpression = groupBy.findGroupedVariable(expr2.get()); + auto variableInExpressionOccurrences = + std::get_if>( + &variableInExpression); + ASSERT_TRUE(variableInExpressionOccurrences); + ASSERT_EQ(variableInExpressionOccurrences->size(), 1); + auto parentAndChildIndex = variableInExpressionOccurrences->at(0); + ASSERT_EQ(parentAndChildIndex.nThChild_, 0); + ASSERT_EQ(parentAndChildIndex.parent_, expr2.get()); + + auto variableNotFound = groupBy.findGroupedVariable(expr3.get()); + auto variableNotFoundOccurrences = + std::get_if>(&variableNotFound); + ASSERT_EQ(variableNotFoundOccurrences->size(), 0); +} + // _____________________________________________________________________________ TEST_F(GroupByOptimizations, checkIfHashMapOptimizationPossible) { auto testFailure = [this](const auto& groupByVariables, const auto& aliases, @@ -558,15 +607,17 @@ TEST_F(GroupByOptimizations, checkIfHashMapOptimizationPossible) { groupBy.checkIfHashMapOptimizationPossible(avgAggregate); ASSERT_TRUE(optimizedAggregateData.has_value()); ASSERT_EQ(optimizedAggregateData->subtreeColumnIndex_, 0); - ASSERT_EQ(optimizedAggregateData->numAggregates_, 1); // Check aggregate alias is correct auto aggregateAlias = optimizedAggregateData->aggregateAliases_[0]; ASSERT_EQ(aggregateAlias.expr_.getPimpl(), avgXPimpl.getPimpl()); // Check aggregate info is correct auto aggregateInfo = aggregateAlias.aggregateInfo_[0]; - ASSERT_EQ(aggregateInfo.aggregateDataIndex, 0); + ASSERT_EQ(aggregateInfo.aggregateDataIndex_, 0); ASSERT_FALSE(aggregateInfo.parentAndIndex_.has_value()); ASSERT_EQ(aggregateInfo.expr_, avgXPimpl.getPimpl()); + + // Disable optimization for following tests + RuntimeParameters().set<"use-group-by-hash-map-optimization">(false); } // _____________________________________________________________________________ @@ -608,10 +659,80 @@ TEST_F(GroupByOptimizations, correctResultForHashMapOptimization) { } // _____________________________________________________________________________ -TEST_F(GroupByOptimizations, correctResultForHashMapOptimizationNonTrivial) { +TEST_F(GroupByOptimizations, hashMapOptimizationGroupedVariable) { + // Make sure we are calculating the correct result when a grouped variable + // occurs in an expression. + RuntimeParameters().set<"use-group-by-hash-map-optimization">(true); + + parsedQuery::SparqlValues input; + using TC = TripleComponent; + + // SELECT (?a AS ?x) (?a + COUNT(?b) AS ?y) (?x + AVG(?b) as ?z) WHERE { + // VALUES (?a ?b) { (1.0 3.0) (1.0 7.0) (5.0 4.0)} + // } GROUP BY ?a + Variable varA = Variable{"?a"}; + Variable varX = Variable{"?x"}; + Variable varB = Variable{"?b"}; + + input._variables = std::vector{varA, varB}; + input._values.push_back(std::vector{TC(1.0), TC(3.0)}); + input._values.push_back(std::vector{TC(1.0), TC(7.0)}); + input._values.push_back(std::vector{TC(5.0), TC(4.0)}); + auto values = ad_utility::makeExecutionTree( + ad_utility::testing::getQec(), input); + + using namespace sparqlExpression; + + // Create `Alias` object for `(?a as ?x)`. + auto expr1 = makeVariableExpression(varA); + auto alias1 = + Alias{SparqlExpressionPimpl{std::move(expr1), "?a"}, Variable{"?x"}}; + + // Create `Alias` object for `(?a + COUNT(?b) AS ?y)`. + auto expr2 = makeAddExpression( + makeVariableExpression(varA), + std::make_unique(false, makeVariableExpression(varB))); + auto alias2 = Alias{SparqlExpressionPimpl{std::move(expr2), "?a + COUNT(?b)"}, + Variable{"?y"}}; + + // Create `Alias` object for `(?x + AVG(?b) as ?z)`. + auto expr3 = makeAddExpression( + makeVariableExpression(varX), + std::make_unique(false, makeVariableExpression(varB))); + auto alias3 = Alias{SparqlExpressionPimpl{std::move(expr3), "?x + AVG(?b)"}, + Variable{"?z"}}; + + // Set up and evaluate the GROUP BY clause. + GroupBy groupBy{ad_utility::testing::getQec(), + {Variable{"?a"}}, + {std::move(alias1), std::move(alias2), std::move(alias3)}, + std::move(values)}; + auto result = groupBy.getResult(); + const auto& table = result->idTable(); + + // Check the result. + auto d = DoubleId; + using enum ColumnIndexAndTypeInfo::UndefStatus; + VariableToColumnMap expectedVariables{ + {Variable{"?a"}, {0, AlwaysDefined}}, + {Variable{"?x"}, {1, PossiblyUndefined}}, + {Variable{"?y"}, {2, PossiblyUndefined}}, + {Variable{"?z"}, {3, PossiblyUndefined}}}; + EXPECT_THAT(groupBy.getExternallyVisibleVariableColumns(), + ::testing::UnorderedElementsAreArray(expectedVariables)); + auto expected = makeIdTableFromVector( + {{d(1), d(1), d(3), d(6)}, {d(5), d(5), d(6), d(9)}}); + EXPECT_EQ(table, expected); + + // Disable optimization for following tests + RuntimeParameters().set<"use-group-by-hash-map-optimization">(false); +} + +// _____________________________________________________________________________ +TEST_F(GroupByOptimizations, hashMapOptimizationNonTrivial) { /* Setup query: SELECT ?x (AVG(?y) as ?avg) - (?avg + ((2 * AVG(?y)) * AVG(4 * ?y)) as ?complexAvg) + (?avg + ((2 * COUNT(?y)) * AVG(4 * ?y)) as ?complexAvg) (5.0 as ?const) (42.0 as ?const2) (13.37 as ?const3) (?const + ?const2 + ?const3 + AVG(?y) + AVG(?y) + AVG(?y) as ?sth) WHERE { @@ -634,22 +755,22 @@ TEST_F(GroupByOptimizations, correctResultForHashMapOptimizationNonTrivial) { Variable varAvg{"?avg"}; SparqlExpressionPimpl avgYPimpl = makeAvgPimpl(varY); - // (?avg + ((2 * AVG(?y)) * AVG(4 * ?y)) as ?complexAvg) + // (?avg + ((2 * COUNT(?y)) * AVG(4 * ?y)) as ?complexAvg) auto fourTimesYExpr = makeMultiplyExpression(makeLiteralDoubleExpr(4.0), makeVariableExpression(varY)); auto avgFourTimesYExpr = std::make_unique(false, std::move(fourTimesYExpr)); - auto avgYExpr = - std::make_unique(false, makeVariableExpression(varY)); - auto twoTimesAvgYExpr = - makeMultiplyExpression(makeLiteralDoubleExpr(2.0), std::move(avgYExpr)); - auto twoTimesAvgY_times_avgFourTimesYExpr = makeMultiplyExpression( - std::move(twoTimesAvgYExpr), std::move(avgFourTimesYExpr)); - auto avgY_plus_twoTimesAvgY_times_avgFourTimesYExpr = + auto countYExpr = + std::make_unique(false, makeVariableExpression(varY)); + auto twoTimesCountYExpr = + makeMultiplyExpression(makeLiteralDoubleExpr(2.0), std::move(countYExpr)); + auto twoTimesCountY_times_avgFourTimesYExpr = makeMultiplyExpression( + std::move(twoTimesCountYExpr), std::move(avgFourTimesYExpr)); + auto avgY_plus_twoTimesCountY_times_avgFourTimesYExpr = makeAddExpression(makeVariableExpression(varAvg), - std::move(twoTimesAvgY_times_avgFourTimesYExpr)); - SparqlExpressionPimpl avgY_plus_twoTimesAvgY_times_avgFourTimesYPimpl( - std::move(avgY_plus_twoTimesAvgY_times_avgFourTimesYExpr), + std::move(twoTimesCountY_times_avgFourTimesYExpr)); + SparqlExpressionPimpl avgY_plus_twoTimesCountY_times_avgFourTimesYPimpl( + std::move(avgY_plus_twoTimesCountY_times_avgFourTimesYExpr), "(?avg + ((2 * AVG(?y)) * AVG(4 * ?y)) as ?complexAvg)"); // (5.0 as ?const) (42.0 as ?const2) (13.37 as ?const3) @@ -683,7 +804,7 @@ TEST_F(GroupByOptimizations, correctResultForHashMapOptimizationNonTrivial) { std::vector aliasesAvgY{ Alias{avgYPimpl, varAvg}, - Alias{avgY_plus_twoTimesAvgY_times_avgFourTimesYPimpl, + Alias{avgY_plus_twoTimesCountY_times_avgFourTimesYPimpl, Variable{"?complexAvg"}}, Alias{constantFive, varConst}, Alias{constantFortyTwo, varConst2}, @@ -706,6 +827,9 @@ TEST_F(GroupByOptimizations, correctResultForHashMapOptimizationNonTrivial) { // Compare results, using debugString as the result only contains 2 rows ASSERT_EQ(resultWithOptimization->asDebugString(), resultWithoutOptimization->asDebugString()); + + // Disable optimization for following tests + RuntimeParameters().set<"use-group-by-hash-map-optimization">(false); } // _____________________________________________________________________________ From 34af3a7aa7ee10f1961eba856eeeb82be324dad9 Mon Sep 17 00:00:00 2001 From: Fabian Krause <29677855+kcaliban@users.noreply.github.com> Date: Sat, 13 Jan 2024 12:49:38 +0100 Subject: [PATCH 2/2] Fix build error introduced by previous commit (#1224) Note that the `check-index-version` workflow will still fail for this commit because it checks out the previous version of the master, which because of the build error we are fixing here does not compile. Everything should be fine again starting with the next commit after this one. --- src/engine/GroupBy.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/engine/GroupBy.cpp b/src/engine/GroupBy.cpp index e34f3170b8..a2b7ba720f 100644 --- a/src/engine/GroupBy.cpp +++ b/src/engine/GroupBy.cpp @@ -1108,7 +1108,7 @@ void GroupBy::createResultFromHashMap( // _____________________________________________________________________________ template concept SupportedAggregates = - ad_utility::isTypeContainedIn; + ad_utility::SameAsAnyTypeIn; // _____________________________________________________________________________ // Visitor function to extract values from the result of an evaluation of