diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 926968aa4f..e49ac35946 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1107,6 +1107,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "//xls/common:module_initializer", "//xls/common/status:status_macros", "//xls/ir", @@ -2421,6 +2422,7 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/solvers:z3_ir_equivalence_testutils", "@com_google_googletest//:gtest", ], ) diff --git a/xls/passes/bdd_simplification_pass.cc b/xls/passes/bdd_simplification_pass.cc index c2e2fbbcdc..3be600b793 100644 --- a/xls/passes/bdd_simplification_pass.cc +++ b/xls/passes/bdd_simplification_pass.cc @@ -28,6 +28,7 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xls/common/module_initializer.h" #include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" @@ -333,6 +334,35 @@ absl::StatusOr SimplifyNode(Node* node, const QueryEngine& query_engine, return true; } + // Simplify kPrioritySelect operations where the selector is known to have at + // least one set bit. + if (NarrowingEnabled(opt_level) && node->Is() && + query_engine.AtLeastOneBitTrue(node->As()->selector())) { + PrioritySelect* sel = node->As(); + + int64_t last_bit = 0; + std::vector trailing_bits; + for (; last_bit < sel->selector()->BitCountOrDie() - 1; ++last_bit) { + trailing_bits.push_back(TreeBitLocation(sel->selector(), last_bit)); + if (query_engine.AtLeastOneTrue(trailing_bits)) { + break; + } + } + DCHECK(last_bit < sel->selector()->BitCountOrDie() - 1 || + query_engine.AtLeastOneTrue(trailing_bits)); + + XLS_ASSIGN_OR_RETURN(Node * new_selector, + node->function_base()->MakeNode( + node->loc(), sel->selector(), + /*start=*/0, /*width=*/last_bit)); + absl::Span new_cases = sel->cases().subspan(0, last_bit); + Node* new_default = sel->get_case(last_bit); + XLS_RETURN_IF_ERROR(node->ReplaceUsesWithNew( + new_selector, new_cases, new_default) + .status()); + return true; + } + return false; } diff --git a/xls/passes/bdd_simplification_pass_test.cc b/xls/passes/bdd_simplification_pass_test.cc index 1f2cfb212c..6fdccff64d 100644 --- a/xls/passes/bdd_simplification_pass_test.cc +++ b/xls/passes/bdd_simplification_pass_test.cc @@ -31,6 +31,7 @@ #include "xls/ir/package.h" #include "xls/passes/optimization_pass.h" #include "xls/passes/pass_base.h" +#include "xls/solvers/z3_ir_equivalence_testutils.h" namespace m = ::xls::op_matchers; @@ -122,6 +123,46 @@ TEST_F(BddSimplificationPassTest, RemoveRedundantOneHot) { EXPECT_THAT(f->return_value(), m::Concat(m::Eq(), m::Concat())); } +TEST_F(BddSimplificationPassTest, RemoveRedundantPrioritySelectCases) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue a = fb.Param("a", p->GetBitsType(24)); + BValue b = fb.Param("b", p->GetBitsType(24)); + BValue c = fb.Param("c", p->GetBitsType(24)); + BValue d = fb.Param("d", p->GetBitsType(24)); + BValue x_ge_5 = fb.UGe(x, fb.Literal(UBits(5, 8))); + BValue x_le_42 = fb.ULe(x, fb.Literal(UBits(42, 8))); + BValue x_eq_8 = fb.Eq(x, fb.Literal(UBits(8, 8))); + fb.PrioritySelect(fb.Concat({x_eq_8, x_ge_5, x_le_42}), /*cases=*/{a, b, c}, + /*default_value=*/d); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::PrioritySelect(m::BitSlice(m::Concat()), {m::Param("a")}, + m::Param("b"))); +} + +TEST_F(BddSimplificationPassTest, PreserveNonRedundantPrioritySelectCases) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue a = fb.Param("a", p->GetBitsType(24)); + BValue b = fb.Param("b", p->GetBitsType(24)); + BValue c = fb.Param("c", p->GetBitsType(24)); + BValue d = fb.Param("d", p->GetBitsType(24)); + BValue x_ge_5 = fb.UGe(x, fb.Literal(UBits(5, 8))); + BValue x_lt_3 = fb.ULt(x, fb.Literal(UBits(3, 8))); + BValue x_eq_3 = fb.Eq(x, fb.Literal(UBits(3, 8))); + fb.PrioritySelect(fb.Concat({x_eq_3, x_ge_5, x_lt_3}), /*cases=*/{a, b, c}, + /*default_value=*/d); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + EXPECT_THAT(Run(f), IsOkAndHolds(false)); +} + TEST_F(BddSimplificationPassTest, ConvertTwoWayOneHotSelect) { auto p = CreatePackage(); XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(