Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COUNT and fix a bug in the GROUP BY HashMap optimization #1222

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 165 additions & 65 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ ResultTable GroupBy::computeResult() {
if (hashMapOptimizationParams.has_value()) {
computeGroupByForHashMapOptimization(
&idTable, hashMapOptimizationParams->aggregateAliases_,
hashMapOptimizationParams->numAggregates_, subresult->idTable(),
hashMapOptimizationParams->subtreeColumnIndex_, &localVocab);
subresult->idTable(), hashMapOptimizationParams->subtreeColumnIndex_,
&localVocab);

return {std::move(idTable), resultSortedOn(), std::move(localVocab)};
}
Expand Down Expand Up @@ -733,7 +733,7 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector<Aggregate>& aliases) {
if (!foundAggregates.has_value()) return std::nullopt;

for (auto& aggregate : foundAggregates.value()) {
aggregate.aggregateDataIndex = numAggregates++;
aggregate.aggregateDataIndex_ = numAggregates++;
}

aliasesWithAggregateInfo.emplace_back(alias._expression, alias._outCol,
Expand All @@ -744,8 +744,46 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector<Aggregate>& aliases) {
auto child = _subtree->getRootOperation()->getChildren().at(0);
auto columnIndex = child->getVariableColumn(groupByVariable);

return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo,
numAggregates};
return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo};
}

// _____________________________________________________________________________
GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable(
sparqlExpression::SparqlExpression* expr) {
AD_CONTRACT_CHECK(_groupByVariables.size() == 1);
GroupBy::GroupedByVariableSubstitutions substitutions{{}, false};
findGroupedVariableImpl(expr, std::nullopt, substitutions);
return substitutions;
}

// _____________________________________________________________________________
void GroupBy::findGroupedVariableImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
GroupBy::GroupedByVariableSubstitutions& substitutions) {
if (auto value = hasType<sparqlExpression::VariableExpression>(expr)) {
auto variable = value.value()->value();
for (const auto& groupedVariable : _groupByVariables) {
if (variable != groupedVariable) continue;

if (parentAndChildIndex.has_value()) {
substitutions.occurrences_.emplace_back(parentAndChildIndex.value());
} else {
substitutions.topLevel_ = true;
return;
}
}
}

auto children = expr->children();

// TODO<C++23> use views::enumerate
size_t childIndex = 0;
for (const auto& child : children) {
ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++};
findGroupedVariableImpl(child.get(), parentAndChildIndexForChild,
substitutions);
}
}

// _____________________________________________________________________________
Expand Down Expand Up @@ -775,41 +813,41 @@ bool GroupBy::hasAnyType(const auto& expr) {
}

// _____________________________________________________________________________
bool GroupBy::isUnsupportedAggregate(sparqlExpression::SparqlExpression* expr) {
std::optional<GroupBy::HashMapAggregateType> GroupBy::isSupportedAggregate(
sparqlExpression::SparqlExpression* expr) {
using namespace sparqlExpression;

// `expr` is not an aggregate, so it cannot be an unsupported aggregate
if (!expr->isAggregate()) return false;
// `expr` is not a distinct aggregate
if (expr->isDistinct()) return std::nullopt;

// `expr` is an unsupported aggregate
if (hasAnyType<SumExpression, MinExpression, MaxExpression, CountExpression,
GroupConcatExpression>(expr))
return true;
// `expr` is not a nested aggregated
if (expr->children().front()->containsAggregate()) return std::nullopt;

if (hasType<AvgExpression>(expr)) return HashMapAggregateType::AVG;
if (hasType<CountExpression>(expr)) return HashMapAggregateType::COUNT;

// `expr` is a distinct aggregate
return expr->isDistinct();
// `expr` is an unsupported aggregate
return std::nullopt;
}

// _____________________________________________________________________________
bool GroupBy::findAggregatesImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
std::vector<GroupBy::HashMapAggregateInformation>& info) {
// Unsupported aggregates
if (isUnsupportedAggregate(expr)) return false;

if (expr->isAggregate()) {
info.emplace_back(expr, 0, parentAndChildIndex);

// Make sure this is not a nested aggregate.
if (expr->children().front()->containsAggregate()) return false;

return true;
if (auto aggregateType = isSupportedAggregate(expr)) {
info.emplace_back(expr, 0, aggregateType.value(), parentAndChildIndex);
return true;
} else {
return false;
}
}

auto children = expr->children();

bool childrenContainOnlySupportedAggregates = true;
// TODO<C++23> use views::enumerate
size_t childIndex = 0;
for (const auto& child : children) {
ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++};
Expand Down Expand Up @@ -856,13 +894,17 @@ GroupBy::getHashMapAggregationResults(
aggregateResults.resize(endIndex - beginIndex);

decltype(auto) groupValues = resultTable->getColumn(0);
auto& aggregateDataVector =
aggregationData.getAggregationDataVector(dataIndex);
auto& aggregateDataVariant =
aggregationData.getAggregationDataVariant(dataIndex);

auto op = [&aggregationData, &aggregateDataVector](Id val) {
auto op = [&aggregationData, &aggregateDataVariant](Id val) {
auto index = aggregationData.getIndex(val);
auto& aggregateData = aggregateDataVector.at(index);
return aggregateData.calculateResult();

auto visitor = [&index](auto& aggregateDataVariant) {
return aggregateDataVariant.at(index).calculateResult();
};

return std::visit(visitor, aggregateDataVariant);
};

std::ranges::transform(groupValues.begin() + beginIndex,
Expand All @@ -872,6 +914,26 @@ GroupBy::getHashMapAggregationResults(
return aggregateResults;
}

// _____________________________________________________________________________
void GroupBy::substituteGroupVariable(
const std::vector<GroupBy::ParentAndChildIndex>& occurrences,
IdTable* resultTable) const {
decltype(auto) groupValues = resultTable->getColumn(0);

for (const auto& occurrence : occurrences) {
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues.begin(), groupValues.end(), values.begin());

auto newExpression = std::make_unique<sparqlExpression::VectorIdExpression>(
std::move(values));

occurrence.parent_->replaceChild(occurrence.nThChild_,
std::move(newExpression));
}
}

// _____________________________________________________________________________
void GroupBy::substituteAllAggregates(
std::vector<HashMapAggregateInformation>& info, size_t beginIndex,
Expand All @@ -880,7 +942,7 @@ void GroupBy::substituteAllAggregates(
// Substitute in the results of all aggregates of `info`.
for (auto& aggregate : info) {
auto aggregateResults = getHashMapAggregationResults(
resultTable, aggregationData, aggregate.aggregateDataIndex, beginIndex,
resultTable, aggregationData, aggregate.aggregateDataIndex_, beginIndex,
endIndex);

// Substitute the resulting vector as a literal
Expand All @@ -907,7 +969,8 @@ std::vector<size_t> GroupBy::HashMapAggregationData::getHashEntries(
}

for (auto& aggregation : aggregationData_)
aggregation.resize(getNumberOfGroups());
std::visit([this](auto& arg) { arg.resize(getNumberOfGroups()); },
aggregation);

return hashEntries;
}
Expand Down Expand Up @@ -959,13 +1022,37 @@ void GroupBy::createResultFromHashMap(
for (auto& alias : aggregateAliases) {
auto& info = alias.aggregateInfo_;

// Only one aggregate, and it is at the top of the alias.
if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (groupByVariableSubstitutions.topLevel_) {
// If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x)
// WHERE {...} GROUP BY ?a`, we can copy values directly from the column
// of the grouped variable
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues.begin(), groupValues.end(),
values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Only one aggregate, and it is at the top of the alias, e.g.
// `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside
// an aggregate, hence we don't need to substitute anything here
auto& aggregate = info.at(0);

// Get aggregate results
auto aggregateResults = getHashMapAggregationResults(
result, aggregationData, aggregate.aggregateDataIndex,
result, aggregationData, aggregate.aggregateDataIndex_,
evaluationContext._beginIndex, evaluationContext._endIndex);

// Copy to result table
Expand All @@ -979,6 +1066,10 @@ void GroupBy::createResultFromHashMap(
sparqlExpression::ExpressionResult{
std::move(aggregateResults)});
} else {
// Substitute in the values of the grouped variable
substituteGroupVariable(groupByVariableSubstitutions.occurrences_,
result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
Expand All @@ -1001,14 +1092,46 @@ void GroupBy::createResultFromHashMap(
}
}

// _____________________________________________________________________________
template <typename A>
concept SupportedAggregates =
ad_utility::isTypeContainedIn<A, GroupBy::Aggregations>;

// _____________________________________________________________________________
// Visitor function to extract values from the result of an evaluation of
// the child expression of an aggregate, and subsequently processing the values
// by calling the `increment` function of the corresponding aggregate.
auto makeProcessGroupsVisitor =
[](size_t blockSize,
const sparqlExpression::EvaluationContext* evaluationContext,
const std::vector<size_t>& hashEntries) {
return [blockSize, evaluationContext,
&hashEntries]<sparqlExpression::SingleExpressionResult T,
SupportedAggregates A>(T&& singleResult,
A& aggregationDataVector) {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), blockSize, evaluationContext);

auto hashEntryIndex = 0;

for (const auto& val : generator) {
auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);

aggregateData.increment(val, evaluationContext);

++hashEntryIndex;
}
};
};

// _____________________________________________________________________________
void GroupBy::computeGroupByForHashMapOptimization(
IdTable* result, std::vector<HashMapAliasInformation>& aggregateAliases,
size_t numAggregates, const IdTable& subresult, size_t columnIndex,
LocalVocab* localVocab) {
const IdTable& subresult, size_t columnIndex, LocalVocab* localVocab) {
// Initialize aggregation data
HashMapAggregationData aggregationData(getExecutionContext()->getAllocator(),
numAggregates);
aggregateAliases);

// Initialize evaluation context
sparqlExpression::EvaluationContext evaluationContext(
Expand All @@ -1034,43 +1157,20 @@ void GroupBy::computeGroupByForHashMapOptimization(
.subspan(evaluationContext._beginIndex, currentBlockSize);
auto hashEntries = aggregationData.getHashEntries(groupValues);

// TODO<C++23> use views::enumerate
for (auto& aggregateAlias : aggregateAliases) {
for (auto& aggregate : aggregateAlias.aggregateInfo_) {
// Evaluate child expression on block
auto exprChildren = aggregate.expr_->children();
sparqlExpression::ExpressionResult expressionResult =
exprChildren[0]->evaluate(&evaluationContext);

auto& aggregationDataVector = aggregationData.getAggregationDataVector(
aggregate.aggregateDataIndex);

auto visitor = [&currentBlockSize, &evaluationContext, &hashEntries,
&aggregationDataVector]<
sparqlExpression::SingleExpressionResult T>(
T&& singleResult) mutable {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), currentBlockSize,
&evaluationContext);

using NVG = sparqlExpression::detail::NumericValueGetter;

auto hashEntryIndex = 0;

for (const auto& val : generator) {
sparqlExpression::detail::NumericValue numVal =
NVG()(val, &evaluationContext);

auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);

aggregateData.increment(numVal);

++hashEntryIndex;
}
};
auto& aggregationDataVariant =
aggregationData.getAggregationDataVariant(
aggregate.aggregateDataIndex_);

std::visit(visitor, std::move(expressionResult));
std::visit(makeProcessGroupsVisitor(currentBlockSize,
&evaluationContext, hashEntries),
std::move(expressionResult), aggregationDataVariant);
}
}
}
Expand Down
Loading