diff --git a/src/engine/GroupBy.cpp b/src/engine/GroupBy.cpp index 53878d6eda..1eec96f8c5 100644 --- a/src/engine/GroupBy.cpp +++ b/src/engine/GroupBy.cpp @@ -748,10 +748,11 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector& aliases) { } // _____________________________________________________________________________ -GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable( - sparqlExpression::SparqlExpression* expr) { +std::variant, + GroupBy::OccurenceAsRoot> +GroupBy::findGroupedVariable(sparqlExpression::SparqlExpression* expr) { AD_CONTRACT_CHECK(_groupByVariables.size() == 1); - GroupBy::GroupedByVariableSubstitutions substitutions{{}, false}; + std::variant, OccurenceAsRoot> substitutions; findGroupedVariableImpl(expr, std::nullopt, substitutions); return substitutions; } @@ -760,16 +761,20 @@ GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable( void GroupBy::findGroupedVariableImpl( sparqlExpression::SparqlExpression* expr, std::optional parentAndChildIndex, - GroupBy::GroupedByVariableSubstitutions& substitutions) { + std::variant, OccurenceAsRoot>& + substitutions) { if (auto value = hasType(expr)) { auto variable = value.value()->value(); for (const auto& groupedVariable : _groupByVariables) { if (variable != groupedVariable) continue; if (parentAndChildIndex.has_value()) { - substitutions.occurrences_.emplace_back(parentAndChildIndex.value()); + auto vector = + std::get_if>(&substitutions); + AD_CONTRACT_CHECK(vector != nullptr); + vector->emplace_back(parentAndChildIndex.value()); } else { - substitutions.topLevel_ = true; + substitutions = OccurenceAsRoot{}; return; } } @@ -924,7 +929,7 @@ void GroupBy::substituteGroupVariable( sparqlExpression::VectorWithMemoryLimit values( getExecutionContext()->getAllocator()); values.resize(groupValues.size()); - std::ranges::copy(groupValues.begin(), groupValues.end(), values.begin()); + std::ranges::copy(groupValues, values.begin()); auto newExpression = std::make_unique( std::move(values)); @@ -988,6 +993,80 @@ GroupBy::HashMapAggregationData::getSortedGroupColumn() const { return sortedKeys; } +// _____________________________________________________________________________ +void GroupBy::evaluateAlias( + HashMapAliasInformation& alias, IdTable* result, + sparqlExpression::EvaluationContext& evaluationContext, + const HashMapAggregationData& aggregationData, LocalVocab* localVocab) { + auto& info = alias.aggregateInfo_; + + // Check if the grouped variable occurs in this expression + auto groupByVariableSubstitutions = + findGroupedVariable(alias.expr_.getPimpl()); + + if (std::get_if(&groupByVariableSubstitutions)) { + // If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x) + // WHERE {...} GROUP BY ?a`, we can copy values directly from the column + // of the grouped variable + decltype(auto) groupValues = result->getColumn(0); + decltype(auto) outValues = result->getColumn(alias.outCol_); + std::ranges::copy(groupValues, outValues.begin()); + + // We also need to store it for possible future use + sparqlExpression::VectorWithMemoryLimit values( + getExecutionContext()->getAllocator()); + values.resize(groupValues.size()); + std::ranges::copy(groupValues, values.begin()); + + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult( + sparqlExpression::ExpressionResult{std::move(values)}); + } else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) { + // Only one aggregate, and it is at the top of the alias, e.g. + // `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside + // an aggregate, hence we don't need to substitute anything here + auto& aggregate = info.at(0); + + // Get aggregate results + auto aggregateResults = getHashMapAggregationResults( + result, aggregationData, aggregate.aggregateDataIndex_, + evaluationContext._beginIndex, evaluationContext._endIndex); + + // Copy to result table + decltype(auto) outValues = result->getColumn(alias.outCol_); + std::ranges::copy(aggregateResults, + outValues.begin() + evaluationContext._beginIndex); + + // Copy the result so that future aliases may reuse it + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult( + sparqlExpression::ExpressionResult{std::move(aggregateResults)}); + } else { + auto occurrences = + get>(groupByVariableSubstitutions); + // Substitute in the values of the grouped variable + substituteGroupVariable(occurrences, result); + + // Substitute in the results of all aggregates contained in the + // expression of the current alias, if `info` is non-empty. + substituteAllAggregates(info, evaluationContext._beginIndex, + evaluationContext._endIndex, aggregationData, + result); + + // Evaluate top-level alias expression + sparqlExpression::ExpressionResult expressionResult = + alias.expr_.getPimpl()->evaluate(&evaluationContext); + + // Copy the result so that future aliases may reuse it + evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = + sparqlExpression::copyExpressionResult(expressionResult); + + // Extract values + extractValues(std::move(expressionResult), evaluationContext, result, + localVocab, alias.outCol_); + } +} + // _____________________________________________________________________________ void GroupBy::createResultFromHashMap( IdTable* result, const HashMapAggregationData& aggregationData, @@ -1020,74 +1099,8 @@ void GroupBy::createResultFromHashMap( evaluationContext._endIndex = std::min(i + blockSize, numberOfGroups); for (auto& alias : aggregateAliases) { - auto& info = alias.aggregateInfo_; - - // Check if the grouped variable occurs in this expression - auto groupByVariableSubstitutions = - findGroupedVariable(alias.expr_.getPimpl()); - - if (groupByVariableSubstitutions.topLevel_) { - // If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x) - // WHERE {...} GROUP BY ?a`, we can copy values directly from the column - // of the grouped variable - decltype(auto) groupValues = result->getColumn(0); - decltype(auto) outValues = result->getColumn(alias.outCol_); - std::ranges::copy(groupValues, outValues.begin()); - - // We also need to store it for possible future use - sparqlExpression::VectorWithMemoryLimit values( - getExecutionContext()->getAllocator()); - values.resize(groupValues.size()); - std::ranges::copy(groupValues.begin(), groupValues.end(), - values.begin()); - - evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = - sparqlExpression::copyExpressionResult( - sparqlExpression::ExpressionResult{std::move(values)}); - } else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) { - // Only one aggregate, and it is at the top of the alias, e.g. - // `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside - // an aggregate, hence we don't need to substitute anything here - auto& aggregate = info.at(0); - - // Get aggregate results - auto aggregateResults = getHashMapAggregationResults( - result, aggregationData, aggregate.aggregateDataIndex_, - evaluationContext._beginIndex, evaluationContext._endIndex); - - // Copy to result table - decltype(auto) outValues = result->getColumn(alias.outCol_); - std::ranges::copy(aggregateResults, - outValues.begin() + evaluationContext._beginIndex); - - // Copy the result so that future aliases may reuse it - evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = - sparqlExpression::copyExpressionResult( - sparqlExpression::ExpressionResult{ - std::move(aggregateResults)}); - } else { - // Substitute in the values of the grouped variable - substituteGroupVariable(groupByVariableSubstitutions.occurrences_, - result); - - // Substitute in the results of all aggregates contained in the - // expression of the current alias, if `info` is non-empty. - substituteAllAggregates(info, evaluationContext._beginIndex, - evaluationContext._endIndex, aggregationData, - result); - - // Evaluate top-level alias expression - sparqlExpression::ExpressionResult expressionResult = - alias.expr_.getPimpl()->evaluate(&evaluationContext); - - // Copy the result so that future aliases may reuse it - evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) = - sparqlExpression::copyExpressionResult(expressionResult); - - // Extract values - extractValues(std::move(expressionResult), evaluationContext, result, - localVocab, alias.outCol_); - } + evaluateAlias(alias, result, evaluationContext, aggregationData, + localVocab); } } } @@ -1101,7 +1114,7 @@ concept SupportedAggregates = // Visitor function to extract values from the result of an evaluation of // the child expression of an aggregate, and subsequently processing the values // by calling the `increment` function of the corresponding aggregate. -auto makeProcessGroupsVisitor = +static constexpr auto makeProcessGroupsVisitor = [](size_t blockSize, const sparqlExpression::EvaluationContext* evaluationContext, const std::vector& hashEntries) { diff --git a/src/engine/GroupBy.h b/src/engine/GroupBy.h index 69503d0a34..1a9a4f6fcd 100644 --- a/src/engine/GroupBy.h +++ b/src/engine/GroupBy.h @@ -164,11 +164,10 @@ class GroupBy : public Operation { int64_t count_ = 0; void increment(auto&& value, const sparqlExpression::EvaluationContext* ctx) { - auto valueGetter = ValueGetter{}; - auto val = valueGetter(AD_FWD(value), ctx); + auto val = ValueGetter{}(AD_FWD(value), ctx); if (const int64_t* intval = std::get_if(&val)) - sum_ += (double)*intval; + sum_ += static_cast(*intval); else if (const double* dval = std::get_if(&val)) sum_ += *dval; else @@ -189,8 +188,7 @@ class GroupBy : public Operation { int64_t count_ = 0; void increment(auto&& value, const sparqlExpression::EvaluationContext* ctx) { - auto valueGetter = ValueGetter{}; - if (valueGetter(AD_FWD(value), ctx)) count_++; + if (ValueGetter{}(AD_FWD(value), ctx)) count_++; } [[nodiscard]] ValueId calculateResult() const { return ValueId::makeFromInt(count_); @@ -213,15 +211,6 @@ class GroupBy : public Operation { } }; - // Stores information required for substitution of grouped by variable in - // an expression tree. - struct GroupedByVariableSubstitutions { - std::vector occurrences_; - // Determines whether the grouped by variable appears at the top of an - // alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`. - bool topLevel_; - }; - // Used to store the kind of aggregate. enum class HashMapAggregateType { AVG, COUNT }; @@ -341,6 +330,14 @@ class GroupBy : public Operation { IdTable* resultTable, const HashMapAggregationData& aggregationData, size_t dataIndex, size_t beginIndex, size_t endIndex); + // Substitute away any occurrences of the grouped variable and of aggregate + // results, if necessary, and subsequently evaluate the expression of an + // alias + void evaluateAlias(HashMapAliasInformation& alias, IdTable* result, + sparqlExpression::EvaluationContext& evaluationContext, + const HashMapAggregationData& aggregationData, + LocalVocab* localVocab); + // Sort the HashMap by key and create result table. void createResultFromHashMap( IdTable* result, const HashMapAggregationData& aggregationData, @@ -388,15 +385,20 @@ class GroupBy : public Operation { static std::optional isSupportedAggregate( sparqlExpression::SparqlExpression* expr); + // Determines whether the grouped by variable appears at the top of an + // alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`. + struct OccurenceAsRoot {}; + // Find all occurrences of grouped by variable for expression `expr`. - GroupBy::GroupedByVariableSubstitutions findGroupedVariable( - sparqlExpression::SparqlExpression* expr); + std::variant, OccurenceAsRoot> + findGroupedVariable(sparqlExpression::SparqlExpression* expr); // The recursive implementation of `findGroupedVariable` (see above). void findGroupedVariableImpl( sparqlExpression::SparqlExpression* expr, std::optional parentAndChildIndex, - GroupBy::GroupedByVariableSubstitutions& substitutions); + std::variant, OccurenceAsRoot>& + substitutions); // Find all aggregates for expression `expr`. Return `std::nullopt` // if an unsupported aggregate is found. diff --git a/test/GroupByTest.cpp b/test/GroupByTest.cpp index b07750cff7..843a1f8f77 100644 --- a/test/GroupByTest.cpp +++ b/test/GroupByTest.cpp @@ -528,19 +528,22 @@ TEST_F(GroupByOptimizations, findGroupedVariable) { GroupBy groupBy{ad_utility::testing::getQec(), {Variable{"?a"}}, {}, values}; auto variableAtTop = groupBy.findGroupedVariable(expr1.get()); - ASSERT_TRUE(variableAtTop.topLevel_); - ASSERT_EQ(variableAtTop.occurrences_.size(), 0); + ASSERT_TRUE(std::get_if(&variableAtTop)); auto variableInExpression = groupBy.findGroupedVariable(expr2.get()); - ASSERT_FALSE(variableInExpression.topLevel_); - ASSERT_EQ(variableInExpression.occurrences_.size(), 1); - auto parentAndChildIndex = variableInExpression.occurrences_.at(0); + auto variableInExpressionOccurrences = + std::get_if>( + &variableInExpression); + ASSERT_TRUE(variableInExpressionOccurrences); + ASSERT_EQ(variableInExpressionOccurrences->size(), 1); + auto parentAndChildIndex = variableInExpressionOccurrences->at(0); ASSERT_EQ(parentAndChildIndex.nThChild_, 0); ASSERT_EQ(parentAndChildIndex.parent_, expr2.get()); auto variableNotFound = groupBy.findGroupedVariable(expr3.get()); - ASSERT_FALSE(variableNotFound.topLevel_); - ASSERT_EQ(variableNotFound.occurrences_.size(), 0); + auto variableNotFoundOccurrences = + std::get_if>(&variableNotFound); + ASSERT_EQ(variableNotFoundOccurrences->size(), 0); } // _____________________________________________________________________________