Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COUNT and fix a bug in the GROUP BY HashMap optimization #1222

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 94 additions & 88 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,14 +762,15 @@ void GroupBy::findGroupedVariableImpl(
std::optional<ParentAndChildIndex> parentAndChildIndex,
GroupBy::GroupedByVariableSubstitutions& substitutions) {
if (auto value = hasType<sparqlExpression::VariableExpression>(expr)) {
for (auto& groupedVariable : _groupByVariables) {
if (value.value()->value() == groupedVariable) {
if (parentAndChildIndex.has_value()) {
substitutions.occurrences.emplace_back(parentAndChildIndex.value());
} else {
substitutions.topLevel = true;
return;
}
auto variable = value.value()->value();
for (const auto& groupedVariable : _groupByVariables) {
if (variable != groupedVariable) continue;

if (parentAndChildIndex.has_value()) {
substitutions.occurrences_.emplace_back(parentAndChildIndex.value());
} else {
substitutions.topLevel_ = true;
return;
}
}
}
Expand Down Expand Up @@ -812,7 +813,7 @@ bool GroupBy::hasAnyType(const auto& expr) {
}

// _____________________________________________________________________________
std::optional<GroupBy::HashMapAggregateKind> GroupBy::isSupportedAggregate(
std::optional<GroupBy::HashMapAggregateType> GroupBy::isSupportedAggregate(
sparqlExpression::SparqlExpression* expr) {
using namespace sparqlExpression;

Expand All @@ -822,8 +823,8 @@ std::optional<GroupBy::HashMapAggregateKind> GroupBy::isSupportedAggregate(
// `expr` is not a nested aggregated
if (expr->children().front()->containsAggregate()) return std::nullopt;

if (hasType<AvgExpression>(expr)) return HashMapAggregateKind::AVG;
if (hasType<CountExpression>(expr)) return HashMapAggregateKind::COUNT;
if (hasType<AvgExpression>(expr)) return HashMapAggregateType::AVG;
if (hasType<CountExpression>(expr)) return HashMapAggregateType::COUNT;

// `expr` is an unsupported aggregate
return std::nullopt;
Expand All @@ -835,8 +836,8 @@ bool GroupBy::findAggregatesImpl(
std::optional<ParentAndChildIndex> parentAndChildIndex,
std::vector<GroupBy::HashMapAggregateInformation>& info) {
if (expr->isAggregate()) {
if (auto aggregateKind = isSupportedAggregate(expr)) {
info.emplace_back(expr, 0, aggregateKind.value(), parentAndChildIndex);
if (auto aggregateType = isSupportedAggregate(expr)) {
info.emplace_back(expr, 0, aggregateType.value(), parentAndChildIndex);
return true;
} else {
return false;
Expand Down Expand Up @@ -916,14 +917,14 @@ GroupBy::getHashMapAggregationResults(
// _____________________________________________________________________________
void GroupBy::substituteGroupVariable(
const std::vector<GroupBy::ParentAndChildIndex>& occurrences,
IdTable* resultTable) {
IdTable* resultTable) const {
decltype(auto) groupValues = resultTable->getColumn(0);

for (const auto& occurrence : occurrences) {
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::copy(groupValues.begin(), groupValues.end(), values.begin());
std::ranges::copy(groupValues.begin(), groupValues.end(), values.begin());

auto newExpression = std::make_unique<sparqlExpression::VectorIdExpression>(
std::move(values));
Expand Down Expand Up @@ -1021,8 +1022,32 @@ void GroupBy::createResultFromHashMap(
for (auto& alias : aggregateAliases) {
auto& info = alias.aggregateInfo_;

// Only one aggregate, and it is at the top of the alias.
if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (groupByVariableSubstitutions.topLevel_) {
// If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x)
// WHERE {...} GROUP BY ?a`, we can copy values directly from the column
// of the grouped variable
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues.begin(), groupValues.end(),
values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Only one aggregate, and it is at the top of the alias, e.g.
// `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside
// an aggregate, hence we don't need to substitute anything here
auto& aggregate = info.at(0);

// Get aggregate results
Expand All @@ -1041,49 +1066,27 @@ void GroupBy::createResultFromHashMap(
sparqlExpression::ExpressionResult{
std::move(aggregateResults)});
} else {
// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (groupByVariableSubstitutions.topLevel) {
// If it is at the top, we can directly copy the values of the grouped
// by column
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::copy(groupValues.begin(), groupValues.end(), values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else {
// Substitute in the values of the grouped variable
substituteGroupVariable(groupByVariableSubstitutions.occurrences,
result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
}
// Substitute in the values of the grouped variable
substituteGroupVariable(groupByVariableSubstitutions.occurrences_,
result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
}
}
}
Expand All @@ -1094,6 +1097,34 @@ template <typename A>
concept SupportedAggregates =
ad_utility::isTypeContainedIn<A, GroupBy::Aggregations>;

// _____________________________________________________________________________
// Visitor function to extract values from the result of an evaluation of
// the child expression of an aggregate, and subsequently processing the values
// by calling the `increment` function of the corresponding aggregate.
auto makeProcessGroupsVisitor =
[](size_t blockSize,
const sparqlExpression::EvaluationContext* evaluationContext,
const std::vector<size_t>& hashEntries) {
return [blockSize, evaluationContext,
&hashEntries]<sparqlExpression::SingleExpressionResult T,
SupportedAggregates A>(T&& singleResult,
A& aggregationDataVector) {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), blockSize, evaluationContext);

auto hashEntryIndex = 0;

for (const auto& val : generator) {
auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);

aggregateData.increment(val, evaluationContext);

++hashEntryIndex;
}
};
};

// _____________________________________________________________________________
void GroupBy::computeGroupByForHashMapOptimization(
IdTable* result, std::vector<HashMapAliasInformation>& aggregateAliases,
Expand Down Expand Up @@ -1137,34 +1168,9 @@ void GroupBy::computeGroupByForHashMapOptimization(
aggregationData.getAggregationDataVariant(
aggregate.aggregateDataIndex_);

auto visitor =
[&currentBlockSize, &evaluationContext,
&hashEntries]<sparqlExpression::SingleExpressionResult T,
SupportedAggregates A>(
T&& singleResult, A& aggregationDataVector) mutable {
auto generator = sparqlExpression::detail::makeGenerator(
std::forward<T>(singleResult), currentBlockSize,
&evaluationContext);

using NVG = sparqlExpression::detail::NumericValueGetter;

auto hashEntryIndex = 0;

for (const auto& val : generator) {
sparqlExpression::detail::NumericValue numVal =
NVG()(val, &evaluationContext);

auto vectorOffset = hashEntries[hashEntryIndex];
auto& aggregateData = aggregationDataVector.at(vectorOffset);

aggregateData.increment(numVal);

++hashEntryIndex;
}
};

std::visit(visitor, std::move(expressionResult),
aggregationDataVariant);
std::visit(makeProcessGroupsVisitor(currentBlockSize,
&evaluationContext, hashEntries),
std::move(expressionResult), aggregationDataVariant);
}
}
}
Expand Down
49 changes: 29 additions & 20 deletions src/engine/GroupBy.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,18 @@ class GroupBy : public Operation {

// Data to perform the AVG aggregation using the HashMap optimization.
struct AverageAggregationData {
using ValueGetter = sparqlExpression::detail::NumericValueGetter;
bool error_ = false;
double sum_ = 0;
int64_t count_ = 0;
void increment(sparqlExpression::detail::NumericValue intermediate) {
if (const int64_t* intval = std::get_if<int64_t>(&intermediate))
void increment(auto&& value,
const sparqlExpression::EvaluationContext* ctx) {
auto valueGetter = ValueGetter{};
auto val = valueGetter(AD_FWD(value), ctx);

if (const int64_t* intval = std::get_if<int64_t>(&val))
sum_ += (double)*intval;
else if (const double* dval = std::get_if<double>(&intermediate))
else if (const double* dval = std::get_if<double>(&val))
sum_ += *dval;
else
error_ = true;
Expand All @@ -180,10 +185,12 @@ class GroupBy : public Operation {

// Data to perform the COUNT aggregation using the HashMap optimization.
struct CountAggregationData {
using ValueGetter = sparqlExpression::detail::IsValidValueGetter;
int64_t count_ = 0;
void increment(sparqlExpression::detail::NumericValue intermediate) {
(void)intermediate;
count_++;
void increment(auto&& value,
const sparqlExpression::EvaluationContext* ctx) {
auto valueGetter = ValueGetter{};
if (valueGetter(AD_FWD(value), ctx)) count_++;
}
[[nodiscard]] ValueId calculateResult() const {
return ValueId::makeFromInt(count_);
Expand All @@ -209,12 +216,14 @@ class GroupBy : public Operation {
// Stores information required for substitution of grouped by variable in
// an expression tree.
struct GroupedByVariableSubstitutions {
std::vector<ParentAndChildIndex> occurrences;
bool topLevel;
std::vector<ParentAndChildIndex> occurrences_;
// Determines whether the grouped by variable appears at the top of an
// alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`.
bool topLevel_;
};

// Used to store the kind of aggregate.
enum class HashMapAggregateKind { AVG, COUNT };
enum class HashMapAggregateType { AVG, COUNT };

// Stores information required for evaluation of an aggregate as well
// as the alias containing it.
Expand All @@ -228,16 +237,16 @@ class GroupBy : public Operation {
// appears in the parents' children, so that it may be substituted away.
std::optional<ParentAndChildIndex> parentAndIndex_ = std::nullopt;
// Which kind of aggregate expression.
HashMapAggregateKind aggregateKind_;
HashMapAggregateType aggregateType_;

HashMapAggregateInformation(
sparqlExpression::SparqlExpression* expr, size_t aggregateDataIndex,
HashMapAggregateKind aggregateKind,
HashMapAggregateType aggregateType,
std::optional<ParentAndChildIndex> parentAndIndex = std::nullopt)
: expr_{expr},
aggregateDataIndex_{aggregateDataIndex},
parentAndIndex_{parentAndIndex},
aggregateKind_{aggregateKind} {
aggregateType_{aggregateType} {
AD_CONTRACT_CHECK(expr != nullptr);
}
};
Expand Down Expand Up @@ -275,18 +284,18 @@ class GroupBy : public Operation {
class HashMapAggregationData {
public:
HashMapAggregationData(
ad_utility::AllocatorWithLimit<Id> alloc,
std::vector<HashMapAliasInformation>& aggregateAliases)
const ad_utility::AllocatorWithLimit<Id>& alloc,
const std::vector<HashMapAliasInformation>& aggregateAliases)
: map_{alloc} {
size_t numAggregates = 0;
for (auto& alias : aggregateAliases) {
for (auto& aggregate : alias.aggregateInfo_) {
for (const auto& alias : aggregateAliases) {
for (const auto& aggregate : alias.aggregateInfo_) {
++numAggregates;

if (aggregate.aggregateKind_ == HashMapAggregateKind::AVG)
if (aggregate.aggregateType_ == HashMapAggregateType::AVG)
aggregationData_.emplace_back(
std::vector<AverageAggregationData>{});
if (aggregate.aggregateKind_ == HashMapAggregateKind::COUNT)
if (aggregate.aggregateType_ == HashMapAggregateType::COUNT)
aggregationData_.emplace_back(std::vector<CountAggregationData>{});
}
}
Expand Down Expand Up @@ -358,7 +367,7 @@ class GroupBy : public Operation {
// Substitute the group values for all occurrences of a group variable.
void substituteGroupVariable(
const std::vector<GroupBy::ParentAndChildIndex>& occurrences,
IdTable* resultTable);
IdTable* resultTable) const;

// Substitute the results for all aggregates in `info`. The values of the
// grouped variable should be at column 0 in `groupValues`.
Expand All @@ -376,7 +385,7 @@ class GroupBy : public Operation {
static bool hasAnyType(const auto& expr);

// Check if an expression is a currently supported aggregate.
static std::optional<GroupBy::HashMapAggregateKind> isSupportedAggregate(
static std::optional<GroupBy::HashMapAggregateType> isSupportedAggregate(
sparqlExpression::SparqlExpression* expr);

// Find all occurrences of grouped by variable for expression `expr`.
Expand Down
14 changes: 7 additions & 7 deletions test/GroupByTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,19 +528,19 @@ TEST_F(GroupByOptimizations, findGroupedVariable) {
GroupBy groupBy{ad_utility::testing::getQec(), {Variable{"?a"}}, {}, values};

auto variableAtTop = groupBy.findGroupedVariable(expr1.get());
ASSERT_TRUE(variableAtTop.topLevel);
ASSERT_EQ(variableAtTop.occurrences.size(), 0);
ASSERT_TRUE(variableAtTop.topLevel_);
ASSERT_EQ(variableAtTop.occurrences_.size(), 0);

auto variableInExpression = groupBy.findGroupedVariable(expr2.get());
ASSERT_FALSE(variableInExpression.topLevel);
ASSERT_EQ(variableInExpression.occurrences.size(), 1);
auto parentAndChildIndex = variableInExpression.occurrences.at(0);
ASSERT_FALSE(variableInExpression.topLevel_);
ASSERT_EQ(variableInExpression.occurrences_.size(), 1);
auto parentAndChildIndex = variableInExpression.occurrences_.at(0);
ASSERT_EQ(parentAndChildIndex.nThChild_, 0);
ASSERT_EQ(parentAndChildIndex.parent_, expr2.get());

auto variableNotFound = groupBy.findGroupedVariable(expr3.get());
ASSERT_FALSE(variableNotFound.topLevel);
ASSERT_EQ(variableNotFound.occurrences.size(), 0);
ASSERT_FALSE(variableNotFound.topLevel_);
ASSERT_EQ(variableNotFound.occurrences_.size(), 0);
}

// _____________________________________________________________________________
Expand Down