diff --git a/xls/passes/next_value_optimization_pass.cc b/xls/passes/next_value_optimization_pass.cc index 8cfd578184..554f89c896 100644 --- a/xls/passes/next_value_optimization_pass.cc +++ b/xls/passes/next_value_optimization_pass.cc @@ -15,6 +15,7 @@ #include "xls/passes/next_value_optimization_pass.h" #include +#include #include #include #include @@ -70,261 +71,238 @@ absl::Status RemoveNextValue(Proc* proc, Next* next) { return proc->RemoveNode(next); } -absl::StatusOr RemoveLiteralPredicates(Proc* proc) { - bool changed = false; - - std::vector next_values(proc->next_values().begin(), - proc->next_values().end()); - for (Next* next : next_values) { - if (!next->predicate().has_value()) { - continue; - } - Node* predicate = *next->predicate(); - if (!predicate->Is()) { - continue; - } - - changed = true; - - Literal* literal_predicate = predicate->As(); - if (literal_predicate->value().IsAllZeros()) { - XLS_VLOG(2) << "Identified node as dead due to zero predicate; removing: " - << *next; - XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); - } else { - XLS_VLOG(2) << "Identified node as always live; removing predicate: " - << *next; - XLS_ASSIGN_OR_RETURN(Next * new_next, next->ReplaceUsesWithNew( - /*param=*/next->param(), - /*value=*/next->value(), - /*predicate=*/std::nullopt)); - new_next->SetLoc(next->loc()); - if (next->HasAssignedName()) { - new_next->SetName(next->GetName()); - } - XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); - } +absl::StatusOr>> RemoveLiteralPredicate( + Proc* proc, Next* next) { + if (!next->predicate().has_value()) { + return std::nullopt; + } + Node* predicate = *next->predicate(); + if (!predicate->Is()) { + return std::nullopt; } - return changed; + Literal* literal_predicate = predicate->As(); + if (literal_predicate->value().IsAllZeros()) { + XLS_VLOG(2) << "Identified node as dead due to zero predicate; removing: " + << *next; + XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + return std::vector(); + } + XLS_VLOG(2) << "Identified node as always live; removing predicate: " + << *next; + XLS_ASSIGN_OR_RETURN(Next * new_next, next->ReplaceUsesWithNew( + /*param=*/next->param(), + /*value=*/next->value(), + /*predicate=*/std::nullopt)); + new_next->SetLoc(next->loc()); + if (next->HasAssignedName()) { + new_next->SetName(next->GetName()); + } + XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + return std::vector({new_next}); } -absl::StatusOr SplitSmallSelects(Proc* proc, - const OptimizationPassOptions& options) { +absl::StatusOr>> SplitSmallSelect( + Proc* proc, Next* next, const OptimizationPassOptions& options) { if (!options.split_next_value_selects.has_value()) { - return false; + return std::nullopt; } - bool changed = false; + if (!next->value()->Is()) { - continue; - } + Select* selected_value = next->value()->As(); - if (selected_value->cases().size() > *options.split_next_value_selects) { - continue; + std::vector new_next_values; + for (int64_t i = 0; i < selected_value->cases().size(); ++i) { + XLS_ASSIGN_OR_RETURN( + Literal * index, + proc->MakeNode( + SourceInfo(), + Value(UBits(i, selected_value->selector()->BitCountOrDie())))); + XLS_ASSIGN_OR_RETURN( + Node * predicate, + proc->MakeNode(SourceInfo(), selected_value->selector(), + index, Op::kEq)); + if (next->predicate().has_value()) { + XLS_ASSIGN_OR_RETURN( + predicate, + proc->MakeNode( + SourceInfo(), std::vector{*next->predicate(), predicate}, + Op::kAnd)); } - changed = true; - - for (int64_t i = 0; i < selected_value->cases().size(); ++i) { - XLS_ASSIGN_OR_RETURN( - Literal * index, - proc->MakeNode( - SourceInfo(), - Value(UBits(i, selected_value->selector()->BitCountOrDie())))); - XLS_ASSIGN_OR_RETURN( - Node * predicate, - proc->MakeNode(SourceInfo(), selected_value->selector(), - index, Op::kEq)); - if (next->predicate().has_value()) { - XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNode( - SourceInfo(), std::vector{*next->predicate(), predicate}, - Op::kAnd)); - } - - std::string name; - if (next->HasAssignedName()) { - name = absl::StrCat(next->GetName(), "_case_", i); - } - XLS_RETURN_IF_ERROR( - proc->MakeNodeWithName(next->loc(), - /*param=*/next->param(), - /*value=*/selected_value->cases()[i], - predicate, name) - .status()); + std::string name; + if (next->HasAssignedName()) { + name = absl::StrCat(next->GetName(), "_case_", i); } + XLS_ASSIGN_OR_RETURN( + Next * new_next, + proc->MakeNodeWithName(next->loc(), + /*param=*/next->param(), + /*value=*/selected_value->cases()[i], + predicate, name)); + new_next_values.push_back(new_next); + } - if (selected_value->default_value().has_value()) { - XLS_ASSIGN_OR_RETURN( - Literal * max_index, - proc->MakeNode( - SourceInfo(), - Value(UBits(selected_value->cases().size() - 1, - selected_value->selector()->BitCountOrDie())))); + if (selected_value->default_value().has_value()) { + XLS_ASSIGN_OR_RETURN( + Literal * max_index, + proc->MakeNode( + SourceInfo(), + Value(UBits(selected_value->cases().size() - 1, + selected_value->selector()->BitCountOrDie())))); + XLS_ASSIGN_OR_RETURN( + Node * predicate, + proc->MakeNode(SourceInfo(), selected_value->selector(), + max_index, Op::kUGt)); + if (next->predicate().has_value()) { XLS_ASSIGN_OR_RETURN( - Node * predicate, - proc->MakeNode(SourceInfo(), selected_value->selector(), - max_index, Op::kUGt)); - if (next->predicate().has_value()) { - XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNode( - SourceInfo(), std::vector{*next->predicate(), predicate}, - Op::kAnd)); - } - - std::string name; - if (next->HasAssignedName()) { - name = absl::StrCat(next->GetName(), "_default_case"); - } - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - next->loc(), - /*param=*/next->param(), - /*value=*/*selected_value->default_value(), - predicate, name) - .status()); + predicate, + proc->MakeNode( + SourceInfo(), std::vector{*next->predicate(), predicate}, + Op::kAnd)); } - XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + std::string name; + if (next->HasAssignedName()) { + name = absl::StrCat(next->GetName(), "_default_case"); + } + XLS_ASSIGN_OR_RETURN( + Next * new_next, + proc->MakeNodeWithName(next->loc(), + /*param=*/next->param(), + /*value=*/*selected_value->default_value(), + predicate, name)); + new_next_values.push_back(new_next); } - return changed; + XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + return new_next_values; } -absl::StatusOr SplitPrioritySelects(Proc* proc) { - bool changed = false; - - std::vector next_values(proc->next_values().begin(), - proc->next_values().end()); - for (Next* next : next_values) { - if (!next->value()->Is()) { - continue; - } - PrioritySelect* selected_value = next->value()->As(); - - changed = true; - for (int64_t i = 0; i < selected_value->cases().size(); ++i) { - absl::InlinedVector all_clauses; - XLS_ASSIGN_OR_RETURN( - Node * case_active, - proc->MakeNode(SourceInfo(), selected_value->selector(), - /*start=*/i, /*width=*/1)); - all_clauses.push_back(case_active); - if (next->predicate().has_value()) { - all_clauses.push_back(*next->predicate()); - } - if (i > 0) { - XLS_ASSIGN_OR_RETURN(Node * higher_priority_cases_inactive, - NorReduceTrailing(selected_value->selector(), i)); - all_clauses.push_back(higher_priority_cases_inactive); - } - XLS_ASSIGN_OR_RETURN(Node * case_predicate, - NaryAndIfNeeded(proc, all_clauses)); - - std::string name; - if (next->HasAssignedName()) { - name = absl::StrCat(next->GetName(), "_case_", i); - } - XLS_RETURN_IF_ERROR( - proc->MakeNodeWithName(next->loc(), - /*param=*/next->param(), - /*value=*/selected_value->get_case(i), - /*predicate=*/case_predicate, name) - .status()); - } +absl::StatusOr>> SplitPrioritySelect( + Proc* proc, Next* next) { + if (!next->value()->Is()) { + return std::nullopt; + } + PrioritySelect* selected_value = next->value()->As(); - // Default case; if all bits of the input are zero, `priority_sel` returns - // zero. - absl::InlinedVector all_default_clauses; + std::vector new_next_values; + for (int64_t i = 0; i < selected_value->cases().size(); ++i) { + absl::InlinedVector all_clauses; XLS_ASSIGN_OR_RETURN( - Literal * zero_selector, - proc->MakeNode( - SourceInfo(), ZeroOfType(selected_value->selector()->GetType()))); - XLS_ASSIGN_OR_RETURN( - Node * all_cases_inactive, - proc->MakeNode(SourceInfo(), selected_value->selector(), - zero_selector, Op::kEq)); - all_default_clauses.push_back(all_cases_inactive); + Node * case_active, + proc->MakeNode(SourceInfo(), selected_value->selector(), + /*start=*/i, /*width=*/1)); + all_clauses.push_back(case_active); if (next->predicate().has_value()) { - all_default_clauses.push_back(*next->predicate()); + all_clauses.push_back(*next->predicate()); } - XLS_ASSIGN_OR_RETURN(Node * default_predicate, - NaryAndIfNeeded(proc, all_default_clauses)); - XLS_ASSIGN_OR_RETURN( - Node * default_value, - proc->MakeNode(SourceInfo(), - ZeroOfType(selected_value->GetType()))); + if (i > 0) { + XLS_ASSIGN_OR_RETURN(Node * higher_priority_cases_inactive, + NorReduceTrailing(selected_value->selector(), i)); + all_clauses.push_back(higher_priority_cases_inactive); + } + XLS_ASSIGN_OR_RETURN(Node * case_predicate, + NaryAndIfNeeded(proc, all_clauses)); std::string name; if (next->HasAssignedName()) { - name = absl::StrCat(next->GetName(), "_case_default"); + name = absl::StrCat(next->GetName(), "_case_", i); } - XLS_RETURN_IF_ERROR( + XLS_ASSIGN_OR_RETURN( + Next * new_next, proc->MakeNodeWithName(next->loc(), /*param=*/next->param(), - /*value=*/default_value, - /*predicate=*/default_predicate, name) - .status()); - - XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + /*value=*/selected_value->get_case(i), + /*predicate=*/case_predicate, name)); + new_next_values.push_back(new_next); } - return changed; + // Default case; if all bits of the input are zero, `priority_sel` returns + // zero. + absl::InlinedVector all_default_clauses; + XLS_ASSIGN_OR_RETURN( + Literal * zero_selector, + proc->MakeNode( + SourceInfo(), ZeroOfType(selected_value->selector()->GetType()))); + XLS_ASSIGN_OR_RETURN( + Node * all_cases_inactive, + proc->MakeNode(SourceInfo(), selected_value->selector(), + zero_selector, Op::kEq)); + all_default_clauses.push_back(all_cases_inactive); + if (next->predicate().has_value()) { + all_default_clauses.push_back(*next->predicate()); + } + XLS_ASSIGN_OR_RETURN(Node * default_predicate, + NaryAndIfNeeded(proc, all_default_clauses)); + XLS_ASSIGN_OR_RETURN( + Node * default_value, + proc->MakeNode(SourceInfo(), + ZeroOfType(selected_value->GetType()))); + + std::string name; + if (next->HasAssignedName()) { + name = absl::StrCat(next->GetName(), "_case_default"); + } + XLS_ASSIGN_OR_RETURN( + Next * new_next, + proc->MakeNodeWithName(next->loc(), + /*param=*/next->param(), + /*value=*/default_value, + /*predicate=*/default_predicate, name)); + new_next_values.push_back(new_next); + + XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + return new_next_values; } -absl::StatusOr SplitSafeOneHotSelects(Proc* proc) { - bool changed = false; +absl::StatusOr>> SplitSafeOneHotSelect( + Proc* proc, Next* next) { + if (!next->value()->Is()) { + return std::nullopt; + } + OneHotSelect* selected_value = next->value()->As(); + if (!selected_value->selector()->Is()) { + // Not safe to use for `next_value`; actual value could be the OR of + // multiple cases. + return std::nullopt; + } - std::vector next_values(proc->next_values().begin(), - proc->next_values().end()); - for (Next* next : next_values) { - if (!next->value()->Is()) { - continue; - } - OneHotSelect* selected_value = next->value()->As(); - if (!selected_value->selector()->Is()) { - // Not safe to use for `next_value`; actual value could be the OR of - // multiple cases. - continue; + std::vector new_next_values; + for (int64_t i = 0; i < selected_value->cases().size(); ++i) { + XLS_ASSIGN_OR_RETURN( + Node * case_predicate, + proc->MakeNode(SourceInfo(), selected_value->selector(), + /*start=*/i, /*width=*/1)); + if (next->predicate().has_value()) { + XLS_ASSIGN_OR_RETURN( + case_predicate, + proc->MakeNode( + SourceInfo(), + std::vector{*next->predicate(), case_predicate}, + Op::kAnd)); } - changed = true; - for (int64_t i = 0; i < selected_value->cases().size(); ++i) { - XLS_ASSIGN_OR_RETURN( - Node * case_predicate, - proc->MakeNode(SourceInfo(), selected_value->selector(), - /*start=*/i, /*width=*/1)); - if (next->predicate().has_value()) { - XLS_ASSIGN_OR_RETURN( - case_predicate, - proc->MakeNode( - SourceInfo(), - std::vector{*next->predicate(), case_predicate}, - Op::kAnd)); - } - - std::string name; - if (next->HasAssignedName()) { - name = absl::StrCat(next->GetName(), "_case_", i); - } - XLS_RETURN_IF_ERROR( - proc->MakeNodeWithName(next->loc(), - /*param=*/next->param(), - /*value=*/selected_value->get_case(i), - /*predicate=*/case_predicate, name) - .status()); + std::string name; + if (next->HasAssignedName()) { + name = absl::StrCat(next->GetName(), "_case_", i); } - XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + XLS_ASSIGN_OR_RETURN( + Next * new_next, + proc->MakeNodeWithName(next->loc(), + /*param=*/next->param(), + /*value=*/selected_value->get_case(i), + /*predicate=*/case_predicate, name)); + new_next_values.push_back(new_next); } - - return changed; + XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next)); + return new_next_values; } } // namespace @@ -339,21 +317,53 @@ absl::StatusOr NextValueOptimizationPass::RunOnProcInternal( changed = changed || modernize_changed; } - XLS_ASSIGN_OR_RETURN(bool literal_predicate_changed, - RemoveLiteralPredicates(proc)); - changed = changed || literal_predicate_changed; + std::deque worklist(proc->next_values().begin(), + proc->next_values().end()); + while (!worklist.empty()) { + Next* next = worklist.front(); + worklist.pop_front(); + + XLS_ASSIGN_OR_RETURN( + std::optional> literal_predicate_next_values, + RemoveLiteralPredicate(proc, next)); + if (literal_predicate_next_values.has_value()) { + changed = true; + worklist.insert(worklist.end(), literal_predicate_next_values->begin(), + literal_predicate_next_values->end()); + continue; + } - XLS_ASSIGN_OR_RETURN(bool split_select_changed, - SplitSmallSelects(proc, options)); - changed = changed || split_select_changed; + XLS_ASSIGN_OR_RETURN( + std::optional> split_select_next_values, + SplitSmallSelect(proc, next, options)); + if (split_select_next_values.has_value()) { + changed = true; + worklist.insert(worklist.end(), split_select_next_values->begin(), + split_select_next_values->end()); + continue; + } - XLS_ASSIGN_OR_RETURN(bool split_priority_select_changed, - SplitPrioritySelects(proc)); - changed = changed || split_priority_select_changed; + XLS_ASSIGN_OR_RETURN( + std::optional> split_priority_select_next_values, + SplitPrioritySelect(proc, next)); + if (split_priority_select_next_values.has_value()) { + changed = true; + worklist.insert(worklist.end(), + split_priority_select_next_values->begin(), + split_priority_select_next_values->end()); + continue; + } - XLS_ASSIGN_OR_RETURN(bool split_one_hot_select_changed, - SplitSafeOneHotSelects(proc)); - changed = changed || split_one_hot_select_changed; + XLS_ASSIGN_OR_RETURN( + std::optional> split_one_hot_select_next_values, + SplitSafeOneHotSelect(proc, next)); + if (split_one_hot_select_next_values.has_value()) { + changed = true; + worklist.insert(worklist.end(), split_one_hot_select_next_values->begin(), + split_one_hot_select_next_values->end()); + continue; + } + } return changed; } diff --git a/xls/passes/next_value_optimization_pass_test.cc b/xls/passes/next_value_optimization_pass_test.cc index 4011d847de..081616251e 100644 --- a/xls/passes/next_value_optimization_pass_test.cc +++ b/xls/passes/next_value_optimization_pass_test.cc @@ -238,5 +238,36 @@ TEST_F(NextValueOptimizationPassTest, BigSelectNextValue) { IsOkAndHolds(false)); } +TEST_F(NextValueOptimizationPassTest, CascadingSmallSelectsNextValue) { + auto p = CreatePackage(); + ProcBuilder pb("p", "tkn", p.get()); + BValue x = pb.StateElement("x", Value(UBits(0, 2))); + BValue a = pb.StateElement("a", Value(UBits(0, 1))); + BValue b = pb.StateElement("b", Value(UBits(0, 1))); + BValue select_b_1 = pb.Select( + b, std::vector{pb.Literal(UBits(2, 2)), pb.Literal(UBits(1, 2))}); + BValue select_b_2 = pb.Select( + b, std::vector{pb.Literal(UBits(2, 2)), pb.Literal(UBits(3, 2))}); + BValue select_a = pb.Select(a, std::vector{select_b_1, select_b_2}); + pb.Next(/*param=*/x, /*value=*/select_a); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build(pb.GetTokenParam())); + + EXPECT_THAT(Run(p.get(), /*split_next_value_selects=*/2), IsOkAndHolds(true)); + EXPECT_THAT(proc->next_values(), + UnorderedElementsAre( + m::Next(m::Param(), m::Literal(2), + m::And(m::Eq(m::Param("a"), m::Literal(0)), + m::Eq(m::Param("b"), m::Literal(0)))), + m::Next(m::Param(), m::Literal(1), + m::And(m::Eq(m::Param("a"), m::Literal(0)), + m::Eq(m::Param("b"), m::Literal(1)))), + m::Next(m::Param(), m::Literal(2), + m::And(m::Eq(m::Param("a"), m::Literal(1)), + m::Eq(m::Param("b"), m::Literal(0)))), + m::Next(m::Param(), m::Literal(3), + m::And(m::Eq(m::Param("a"), m::Literal(1)), + m::Eq(m::Param("b"), m::Literal(1)))))); +} + } // namespace } // namespace xls