Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COUNT and fix a bug in the GROUP BY HashMap optimization #1222

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 167 additions & 73 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ ResultTable GroupBy::computeResult() {
if (hashMapOptimizationParams.has_value()) {
computeGroupByForHashMapOptimization(
&idTable, hashMapOptimizationParams->aggregateAliases_,
hashMapOptimizationParams->numAggregates_, subresult->idTable(),
hashMapOptimizationParams->subtreeColumnIndex_, &localVocab);
subresult->idTable(), hashMapOptimizationParams->subtreeColumnIndex_,
&localVocab);

return {std::move(idTable), resultSortedOn(), std::move(localVocab)};
}
Expand Down Expand Up @@ -733,7 +733,7 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector<Aggregate>& aliases) {
if (!foundAggregates.has_value()) return std::nullopt;

for (auto& aggregate : foundAggregates.value()) {
aggregate.aggregateDataIndex = numAggregates++;
aggregate.aggregateDataIndex_ = numAggregates++;
}

aliasesWithAggregateInfo.emplace_back(alias._expression, alias._outCol,
Expand All @@ -744,8 +744,45 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector<Aggregate>& aliases) {
auto child = _subtree->getRootOperation()->getChildren().at(0);
auto columnIndex = child->getVariableColumn(groupByVariable);

return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo,
numAggregates};
return HashMapOptimizationData{columnIndex, aliasesWithAggregateInfo};
}

// _____________________________________________________________________________
GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable(
sparqlExpression::SparqlExpression* expr) {
AD_CONTRACT_CHECK(_groupByVariables.size() == 1);
GroupBy::GroupedByVariableSubstitutions substitutions{{}, false};
findGroupedVariableImpl(expr, std::nullopt, substitutions);
return substitutions;
}

// _____________________________________________________________________________
void GroupBy::findGroupedVariableImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
GroupBy::GroupedByVariableSubstitutions& substitutions) {
if (auto value = hasType<sparqlExpression::VariableExpression>(expr)) {
for (auto& groupedVariable : _groupByVariables) {
if (value.value()->value() == groupedVariable) {
if (parentAndChildIndex.has_value()) {
substitutions.occurrences.emplace_back(parentAndChildIndex.value());
} else {
substitutions.topLevel = true;
return;
}
}
}
}

auto children = expr->children();

// TODO<C++23> use views::enumerate
size_t childIndex = 0;
for (const auto& child : children) {
ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++};
findGroupedVariableImpl(child.get(), parentAndChildIndexForChild,
substitutions);
}
}

// _____________________________________________________________________________
Expand Down Expand Up @@ -775,41 +812,41 @@ bool GroupBy::hasAnyType(const auto& expr) {
}

// _____________________________________________________________________________
bool GroupBy::isUnsupportedAggregate(sparqlExpression::SparqlExpression* expr) {
std::optional<GroupBy::HashMapAggregateKind> GroupBy::isSupportedAggregate(
sparqlExpression::SparqlExpression* expr) {
using namespace sparqlExpression;

// `expr` is not an aggregate, so it cannot be an unsupported aggregate
if (!expr->isAggregate()) return false;
// `expr` is not a distinct aggregate
if (expr->isDistinct()) return std::nullopt;

// `expr` is an unsupported aggregate
if (hasAnyType<SumExpression, MinExpression, MaxExpression, CountExpression,
GroupConcatExpression>(expr))
return true;
// `expr` is not a nested aggregated
if (expr->children().front()->containsAggregate()) return std::nullopt;

if (hasType<AvgExpression>(expr)) return HashMapAggregateKind::AVG;
if (hasType<CountExpression>(expr)) return HashMapAggregateKind::COUNT;

// `expr` is a distinct aggregate
return expr->isDistinct();
// `expr` is an unsupported aggregate
return std::nullopt;
}

// _____________________________________________________________________________
bool GroupBy::findAggregatesImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
std::vector<GroupBy::HashMapAggregateInformation>& info) {
// Unsupported aggregates
if (isUnsupportedAggregate(expr)) return false;

if (expr->isAggregate()) {
info.emplace_back(expr, 0, parentAndChildIndex);

// Make sure this is not a nested aggregate.
if (expr->children().front()->containsAggregate()) return false;

return true;
if (auto aggregateKind = isSupportedAggregate(expr)) {
info.emplace_back(expr, 0, aggregateKind.value(), parentAndChildIndex);
return true;
} else {
return false;
}
}

auto children = expr->children();

bool childrenContainOnlySupportedAggregates = true;
// TODO<C++23> use views::enumerate
size_t childIndex = 0;
for (const auto& child : children) {
ParentAndChildIndex parentAndChildIndexForChild{expr, childIndex++};
Expand Down Expand Up @@ -856,13 +893,17 @@ GroupBy::getHashMapAggregationResults(
aggregateResults.resize(endIndex - beginIndex);

decltype(auto) groupValues = resultTable->getColumn(0);
auto& aggregateDataVector =
aggregationData.getAggregationDataVector(dataIndex);
auto& aggregateDataVariant =
aggregationData.getAggregationDataVariant(dataIndex);

auto op = [&aggregationData, &aggregateDataVector](Id val) {
auto op = [&aggregationData, &aggregateDataVariant](Id val) {
auto index = aggregationData.getIndex(val);
auto& aggregateData = aggregateDataVector.at(index);
return aggregateData.calculateResult();

auto visitor = [&index](auto& aggregateDataVariant) {
return aggregateDataVariant.at(index).calculateResult();
};

return std::visit(visitor, aggregateDataVariant);
};

std::ranges::transform(groupValues.begin() + beginIndex,
Expand All @@ -872,6 +913,26 @@ GroupBy::getHashMapAggregationResults(
return aggregateResults;
}

// _____________________________________________________________________________
void GroupBy::substituteGroupVariable(
const std::vector<GroupBy::ParentAndChildIndex>& occurrences,
IdTable* resultTable) {
decltype(auto) groupValues = resultTable->getColumn(0);

for (const auto& occurrence : occurrences) {
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::copy(groupValues.begin(), groupValues.end(), values.begin());

auto newExpression = std::make_unique<sparqlExpression::VectorIdExpression>(
std::move(values));

occurrence.parent_->replaceChild(occurrence.nThChild_,
std::move(newExpression));
}
}

// _____________________________________________________________________________
void GroupBy::substituteAllAggregates(
std::vector<HashMapAggregateInformation>& info, size_t beginIndex,
Expand All @@ -880,7 +941,7 @@ void GroupBy::substituteAllAggregates(
// Substitute in the results of all aggregates of `info`.
for (auto& aggregate : info) {
auto aggregateResults = getHashMapAggregationResults(
resultTable, aggregationData, aggregate.aggregateDataIndex, beginIndex,
resultTable, aggregationData, aggregate.aggregateDataIndex_, beginIndex,
endIndex);

// Substitute the resulting vector as a literal
Expand All @@ -907,7 +968,8 @@ std::vector<size_t> GroupBy::HashMapAggregationData::getHashEntries(
}

for (auto& aggregation : aggregationData_)
aggregation.resize(getNumberOfGroups());
std::visit([this](auto& arg) { arg.resize(getNumberOfGroups()); },
aggregation);

return hashEntries;
}
Expand Down Expand Up @@ -965,7 +1027,7 @@ void GroupBy::createResultFromHashMap(

// Get aggregate results
auto aggregateResults = getHashMapAggregationResults(
result, aggregationData, aggregate.aggregateDataIndex,
result, aggregationData, aggregate.aggregateDataIndex_,
evaluationContext._beginIndex, evaluationContext._endIndex);

// Copy to result table
Expand All @@ -979,36 +1041,66 @@ void GroupBy::createResultFromHashMap(
sparqlExpression::ExpressionResult{
std::move(aggregateResults)});
} else {
// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (groupByVariableSubstitutions.topLevel) {
// If it is at the top, we can directly copy the values of the grouped
// by column
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::copy(groupValues.begin(), groupValues.end(), values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else {
// Substitute in the values of the grouped variable
substituteGroupVariable(groupByVariableSubstitutions.occurrences,
result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
}
}
}
}
}

// _____________________________________________________________________________
template <typename A>
concept SupportedAggregates =
ad_utility::isTypeContainedIn<A, GroupBy::Aggregations>;

// _____________________________________________________________________________
void GroupBy::computeGroupByForHashMapOptimization(
IdTable* result, std::vector<HashMapAliasInformation>& aggregateAliases,
size_t numAggregates, const IdTable& subresult, size_t columnIndex,
LocalVocab* localVocab) {
const IdTable& subresult, size_t columnIndex, LocalVocab* localVocab) {
// Initialize aggregation data
HashMapAggregationData aggregationData(getExecutionContext()->getAllocator(),
numAggregates);
aggregateAliases);

// Initialize evaluation context
sparqlExpression::EvaluationContext evaluationContext(
Expand All @@ -1034,43 +1126,45 @@ void GroupBy::computeGroupByForHashMapOptimization(
.subspan(evaluationContext._beginIndex, currentBlockSize);
auto hashEntries = aggregationData.getHashEntries(groupValues);

// TODO<C++23> use views::enumerate
for (auto& aggregateAlias : aggregateAliases) {
for (auto& aggregate : aggregateAlias.aggregateInfo_) {
// Evaluate child expression on block
auto exprChildren = aggregate.expr_->children();
sparqlExpression::ExpressionResult expressionResult =
exprChildren[0]->evaluate(&evaluationContext);

auto& aggregationDataVector = aggregationData.getAggregationDataVector(
aggregate.aggregateDataIndex);
auto& aggregationDataVariant =
aggregationData.getAggregationDataVariant(
aggregate.aggregateDataIndex_);

auto visitor = [&currentBlockSize, &evaluationContext, &hashEntries,
&aggregationDataVector]<
sparqlExpression::SingleExpressionResult T>(
T&& singleResult) mutable {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), currentBlockSize,
&evaluationContext);
auto visitor =
[&currentBlockSize, &evaluationContext,
&hashEntries]<sparqlExpression::SingleExpressionResult T,
SupportedAggregates A>(
T&& singleResult, A& aggregationDataVector) mutable {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), currentBlockSize,
&evaluationContext);

using NVG = sparqlExpression::detail::NumericValueGetter;
using NVG = sparqlExpression::detail::NumericValueGetter;

auto hashEntryIndex = 0;
auto hashEntryIndex = 0;

for (const auto& val : generator) {
sparqlExpression::detail::NumericValue numVal =
NVG()(val, &evaluationContext);
for (const auto& val : generator) {
sparqlExpression::detail::NumericValue numVal =
NVG()(val, &evaluationContext);

auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);
auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);

aggregateData.increment(numVal);
aggregateData.increment(numVal);

++hashEntryIndex;
}
};
++hashEntryIndex;
}
};

std::visit(visitor, std::move(expressionResult));
std::visit(visitor, std::move(expressionResult),
aggregationDataVariant);
}
}
}
Expand Down
Loading