Skip to content

Commit

Permalink
Implement feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
kcaliban committed Jan 12, 2024
1 parent f3daf68 commit ceeeb9e
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 100 deletions.
165 changes: 89 additions & 76 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,10 +748,11 @@ GroupBy::checkIfHashMapOptimizationPossible(std::vector<Aggregate>& aliases) {
}

// _____________________________________________________________________________
GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable(
sparqlExpression::SparqlExpression* expr) {
std::variant<std::vector<GroupBy::ParentAndChildIndex>,
GroupBy::OccurenceAsRoot>
GroupBy::findGroupedVariable(sparqlExpression::SparqlExpression* expr) {
AD_CONTRACT_CHECK(_groupByVariables.size() == 1);
GroupBy::GroupedByVariableSubstitutions substitutions{{}, false};
std::variant<std::vector<ParentAndChildIndex>, OccurenceAsRoot> substitutions;
findGroupedVariableImpl(expr, std::nullopt, substitutions);
return substitutions;
}
Expand All @@ -760,16 +761,20 @@ GroupBy::GroupedByVariableSubstitutions GroupBy::findGroupedVariable(
void GroupBy::findGroupedVariableImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
GroupBy::GroupedByVariableSubstitutions& substitutions) {
std::variant<std::vector<ParentAndChildIndex>, OccurenceAsRoot>&
substitutions) {
if (auto value = hasType<sparqlExpression::VariableExpression>(expr)) {
auto variable = value.value()->value();
for (const auto& groupedVariable : _groupByVariables) {
if (variable != groupedVariable) continue;

if (parentAndChildIndex.has_value()) {
substitutions.occurrences_.emplace_back(parentAndChildIndex.value());
auto vector =
std::get_if<std::vector<ParentAndChildIndex>>(&substitutions);
AD_CONTRACT_CHECK(vector != nullptr);
vector->emplace_back(parentAndChildIndex.value());
} else {
substitutions.topLevel_ = true;
substitutions = OccurenceAsRoot{};
return;
}
}
Expand Down Expand Up @@ -924,7 +929,7 @@ void GroupBy::substituteGroupVariable(
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues.begin(), groupValues.end(), values.begin());
std::ranges::copy(groupValues, values.begin());

auto newExpression = std::make_unique<sparqlExpression::VectorIdExpression>(
std::move(values));
Expand Down Expand Up @@ -988,6 +993,80 @@ GroupBy::HashMapAggregationData::getSortedGroupColumn() const {
return sortedKeys;
}

// _____________________________________________________________________________
void GroupBy::evaluateAlias(
HashMapAliasInformation& alias, IdTable* result,
sparqlExpression::EvaluationContext& evaluationContext,
const HashMapAggregationData& aggregationData, LocalVocab* localVocab) {
auto& info = alias.aggregateInfo_;

// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (std::get_if<OccurenceAsRoot>(&groupByVariableSubstitutions)) {
// If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x)
// WHERE {...} GROUP BY ?a`, we can copy values directly from the column
// of the grouped variable
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues, values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Only one aggregate, and it is at the top of the alias, e.g.
// `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside
// an aggregate, hence we don't need to substitute anything here
auto& aggregate = info.at(0);

// Get aggregate results
auto aggregateResults = getHashMapAggregationResults(
result, aggregationData, aggregate.aggregateDataIndex_,
evaluationContext._beginIndex, evaluationContext._endIndex);

// Copy to result table
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(aggregateResults,
outValues.begin() + evaluationContext._beginIndex);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(aggregateResults)});
} else {
auto occurrences =
get<std::vector<ParentAndChildIndex>>(groupByVariableSubstitutions);
// Substitute in the values of the grouped variable
substituteGroupVariable(occurrences, result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
}
}

// _____________________________________________________________________________
void GroupBy::createResultFromHashMap(
IdTable* result, const HashMapAggregationData& aggregationData,
Expand Down Expand Up @@ -1020,74 +1099,8 @@ void GroupBy::createResultFromHashMap(
evaluationContext._endIndex = std::min(i + blockSize, numberOfGroups);

for (auto& alias : aggregateAliases) {
auto& info = alias.aggregateInfo_;

// Check if the grouped variable occurs in this expression
auto groupByVariableSubstitutions =
findGroupedVariable(alias.expr_.getPimpl());

if (groupByVariableSubstitutions.topLevel_) {
// If the aggregate is at the top of the alias, e.g. `SELECT (?a as ?x)
// WHERE {...} GROUP BY ?a`, we can copy values directly from the column
// of the grouped variable
decltype(auto) groupValues = result->getColumn(0);
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(groupValues, outValues.begin());

// We also need to store it for possible future use
sparqlExpression::VectorWithMemoryLimit<ValueId> values(
getExecutionContext()->getAllocator());
values.resize(groupValues.size());
std::ranges::copy(groupValues.begin(), groupValues.end(),
values.begin());

evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{std::move(values)});
} else if (info.size() == 1 && !info.at(0).parentAndIndex_.has_value()) {
// Only one aggregate, and it is at the top of the alias, e.g.
// `(AVG(?x) as ?y)`. The grouped by variable cannot occur inside
// an aggregate, hence we don't need to substitute anything here
auto& aggregate = info.at(0);

// Get aggregate results
auto aggregateResults = getHashMapAggregationResults(
result, aggregationData, aggregate.aggregateDataIndex_,
evaluationContext._beginIndex, evaluationContext._endIndex);

// Copy to result table
decltype(auto) outValues = result->getColumn(alias.outCol_);
std::ranges::copy(aggregateResults,
outValues.begin() + evaluationContext._beginIndex);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(
sparqlExpression::ExpressionResult{
std::move(aggregateResults)});
} else {
// Substitute in the values of the grouped variable
substituteGroupVariable(groupByVariableSubstitutions.occurrences_,
result);

// Substitute in the results of all aggregates contained in the
// expression of the current alias, if `info` is non-empty.
substituteAllAggregates(info, evaluationContext._beginIndex,
evaluationContext._endIndex, aggregationData,
result);

// Evaluate top-level alias expression
sparqlExpression::ExpressionResult expressionResult =
alias.expr_.getPimpl()->evaluate(&evaluationContext);

// Copy the result so that future aliases may reuse it
evaluationContext._previousResultsFromSameGroup.at(alias.outCol_) =
sparqlExpression::copyExpressionResult(expressionResult);

// Extract values
extractValues(std::move(expressionResult), evaluationContext, result,
localVocab, alias.outCol_);
}
evaluateAlias(alias, result, evaluationContext, aggregationData,
localVocab);
}
}
}
Expand All @@ -1101,7 +1114,7 @@ concept SupportedAggregates =
// Visitor function to extract values from the result of an evaluation of
// the child expression of an aggregate, and subsequently processing the values
// by calling the `increment` function of the corresponding aggregate.
auto makeProcessGroupsVisitor =
static constexpr auto makeProcessGroupsVisitor =
[](size_t blockSize,
const sparqlExpression::EvaluationContext* evaluationContext,
const std::vector<size_t>& hashEntries) {
Expand Down
36 changes: 19 additions & 17 deletions src/engine/GroupBy.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,10 @@ class GroupBy : public Operation {
int64_t count_ = 0;
void increment(auto&& value,
const sparqlExpression::EvaluationContext* ctx) {
auto valueGetter = ValueGetter{};
auto val = valueGetter(AD_FWD(value), ctx);
auto val = ValueGetter{}(AD_FWD(value), ctx);

if (const int64_t* intval = std::get_if<int64_t>(&val))
sum_ += (double)*intval;
sum_ += static_cast<double>(*intval);
else if (const double* dval = std::get_if<double>(&val))
sum_ += *dval;
else
Expand All @@ -189,8 +188,7 @@ class GroupBy : public Operation {
int64_t count_ = 0;
void increment(auto&& value,
const sparqlExpression::EvaluationContext* ctx) {
auto valueGetter = ValueGetter{};
if (valueGetter(AD_FWD(value), ctx)) count_++;
if (ValueGetter{}(AD_FWD(value), ctx)) count_++;
}
[[nodiscard]] ValueId calculateResult() const {
return ValueId::makeFromInt(count_);
Expand All @@ -213,15 +211,6 @@ class GroupBy : public Operation {
}
};

// Stores information required for substitution of grouped by variable in
// an expression tree.
struct GroupedByVariableSubstitutions {
std::vector<ParentAndChildIndex> occurrences_;
// Determines whether the grouped by variable appears at the top of an
// alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`.
bool topLevel_;
};

// Used to store the kind of aggregate.
enum class HashMapAggregateType { AVG, COUNT };

Expand Down Expand Up @@ -341,6 +330,14 @@ class GroupBy : public Operation {
IdTable* resultTable, const HashMapAggregationData& aggregationData,
size_t dataIndex, size_t beginIndex, size_t endIndex);

// Substitute away any occurrences of the grouped variable and of aggregate
// results, if necessary, and subsequently evaluate the expression of an
// alias
void evaluateAlias(HashMapAliasInformation& alias, IdTable* result,
sparqlExpression::EvaluationContext& evaluationContext,
const HashMapAggregationData& aggregationData,
LocalVocab* localVocab);

// Sort the HashMap by key and create result table.
void createResultFromHashMap(
IdTable* result, const HashMapAggregationData& aggregationData,
Expand Down Expand Up @@ -388,15 +385,20 @@ class GroupBy : public Operation {
static std::optional<GroupBy::HashMapAggregateType> isSupportedAggregate(
sparqlExpression::SparqlExpression* expr);

// Determines whether the grouped by variable appears at the top of an
// alias, e.g. `SELECT (?a as ?x) WHERE {...} GROUP BY ?a`.
struct OccurenceAsRoot {};

// Find all occurrences of grouped by variable for expression `expr`.
GroupBy::GroupedByVariableSubstitutions findGroupedVariable(
sparqlExpression::SparqlExpression* expr);
std::variant<std::vector<ParentAndChildIndex>, OccurenceAsRoot>
findGroupedVariable(sparqlExpression::SparqlExpression* expr);

// The recursive implementation of `findGroupedVariable` (see above).
void findGroupedVariableImpl(
sparqlExpression::SparqlExpression* expr,
std::optional<ParentAndChildIndex> parentAndChildIndex,
GroupBy::GroupedByVariableSubstitutions& substitutions);
std::variant<std::vector<ParentAndChildIndex>, OccurenceAsRoot>&
substitutions);

// Find all aggregates for expression `expr`. Return `std::nullopt`
// if an unsupported aggregate is found.
Expand Down
17 changes: 10 additions & 7 deletions test/GroupByTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,19 +528,22 @@ TEST_F(GroupByOptimizations, findGroupedVariable) {
GroupBy groupBy{ad_utility::testing::getQec(), {Variable{"?a"}}, {}, values};

auto variableAtTop = groupBy.findGroupedVariable(expr1.get());
ASSERT_TRUE(variableAtTop.topLevel_);
ASSERT_EQ(variableAtTop.occurrences_.size(), 0);
ASSERT_TRUE(std::get_if<GroupBy::OccurenceAsRoot>(&variableAtTop));

auto variableInExpression = groupBy.findGroupedVariable(expr2.get());
ASSERT_FALSE(variableInExpression.topLevel_);
ASSERT_EQ(variableInExpression.occurrences_.size(), 1);
auto parentAndChildIndex = variableInExpression.occurrences_.at(0);
auto variableInExpressionOccurrences =
std::get_if<std::vector<GroupBy::ParentAndChildIndex>>(
&variableInExpression);
ASSERT_TRUE(variableInExpressionOccurrences);
ASSERT_EQ(variableInExpressionOccurrences->size(), 1);
auto parentAndChildIndex = variableInExpressionOccurrences->at(0);
ASSERT_EQ(parentAndChildIndex.nThChild_, 0);
ASSERT_EQ(parentAndChildIndex.parent_, expr2.get());

auto variableNotFound = groupBy.findGroupedVariable(expr3.get());
ASSERT_FALSE(variableNotFound.topLevel_);
ASSERT_EQ(variableNotFound.occurrences_.size(), 0);
auto variableNotFoundOccurrences =
std::get_if<std::vector<GroupBy::ParentAndChildIndex>>(&variableNotFound);
ASSERT_EQ(variableNotFoundOccurrences->size(), 0);
}

// _____________________________________________________________________________
Expand Down

0 comments on commit ceeeb9e

Please sign in to comment.