From 8a9f59894d61f59f9bc7652080aa2ef94ee270d1 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Mon, 1 Jul 2024 15:50:38 -0700 Subject: [PATCH] Extend conditional specialization pass to forward through OneHotSelects & logical ops too We already supported forwarding cases through case-at-a-time selects where the context guaranteed which case would be selected. We now extend this logic to also forward inputs through OneHotSelects & ANDs/ORs/XORs where the context guarantees that one input passes identically through the operation. Applied to AND/OR/XOR, this should give us a chance to see through various forms of operations that amount to selects or gates in context. PiperOrigin-RevId: 648509568 --- xls/passes/BUILD | 2 + xls/passes/conditional_specialization_pass.cc | 135 ++++++++++++- .../conditional_specialization_pass_test.cc | 184 ++++++++++++++++-- 3 files changed, 299 insertions(+), 22 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index e49ac35946..5a9151b63e 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1545,6 +1545,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//xls/common:module_initializer", + "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/ir", "//xls/ir:bits", @@ -1552,6 +1553,7 @@ cc_library( "//xls/ir:op", "//xls/ir:ternary", "//xls/ir:value", + "//xls/ir:value_utils", ], ) diff --git a/xls/passes/conditional_specialization_pass.cc b/xls/passes/conditional_specialization_pass.cc index 4dc76cf066..31885b1cb7 100644 --- a/xls/passes/conditional_specialization_pass.cc +++ b/xls/passes/conditional_specialization_pass.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "absl/container/btree_set.h" @@ -34,6 +35,7 @@ #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xls/common/module_initializer.h" +#include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" #include "xls/ir/bits_ops.h" @@ -44,6 +46,7 @@ #include "xls/ir/ternary.h" #include "xls/ir/topo_sort.h" #include "xls/ir/value.h" +#include "xls/ir/value_utils.h" #include "xls/passes/bdd_function.h" #include "xls/passes/bdd_query_engine.h" #include "xls/passes/optimization_pass.h" @@ -215,6 +218,18 @@ class ConditionMap { edge_conditions_.insert({key, std::move(condition_set)}); } + // Returns the conditions which can be assumed along the edge to `node` from + // its operand index `operand_no`. + const ConditionSet& GetEdgeConditionSet(Node* node, int64_t operand_no) { + std::pair key = {node, operand_no}; + if (!edge_conditions_.contains(key)) { + // There are no special conditions for this edge. Return the conditions on + // the target of the edge which necessarily hold on the edge as well. + return node_conditions_.at(node); + } + return edge_conditions_.at(key); + } + // Returns the conditions which can be assumed along the edge(s) from node to // user. This interface is asymmetric to SetEdgeCondition (which takes a node // and operand number) to make it easier to use because at a particular node @@ -238,13 +253,7 @@ class ConditionMap { } CHECK(operand_index.has_value()) << absl::StreamFormat( "%s is not a user of %s", user->GetName(), node->GetName()); - std::pair key = {user, operand_index.value()}; - if (!edge_conditions_.contains(key)) { - // There are no special conditions for this edge. Return the conditions on - // the target of the edge which necessarily hold on the edge as well. - return node_conditions_.at(user); - } - return edge_conditions_.at(key); + return GetEdgeConditionSet(user, *operand_index); } std::string ToString() const { @@ -368,6 +377,27 @@ std::optional GetSelectedCase(PrioritySelect* select, return select->default_value(); } +struct ZeroValue : std::monostate {}; +std::optional> GetSelectedCase( + OneHotSelect* ohs, const TernaryVector& selector_value) { + if (!ternary_ops::IsFullyKnown(selector_value)) { + // We can't be sure which case is selected. + return std::nullopt; + } + Bits selector_bits = ternary_ops::ToKnownBitsValues(selector_value); + if (selector_bits.PopCount() > 1) { + // We aren't selecting just one state. + return std::nullopt; + } + for (int64_t i = 0; i < selector_value.size(); ++i) { + if (selector_bits.Get(i)) { + return ohs->get_case(i); + } + } + // All bits of the selector are zero. + return ZeroValue{}; +} + } // namespace absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( @@ -643,10 +673,15 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( // It may be possible to bypass multiple selects so walk the edge up the // graph as far as possible. For example, in the diagram above `b` may // also be a select with a selector whose value is implied by `s`. - if (operand->Is() || src->Is()) { + while (src->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel, Op::kAnd, + Op::kOr, Op::kXor})) { if (src->Is(); if (select->selector()->Is()) { @@ -690,6 +725,88 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( xls::ToString(*implied_selector)); src = *implied_case; replacement = src; + } else if (src->Is()) { + XLS_RET_CHECK(src->Is()); + OneHotSelect* ohs = src->As(); + if (ohs->selector()->Is()) { + break; + } + std::optional implied_selector = + ImpliedNodeTernary(edge_set, ohs->selector(), query_engine); + if (!implied_selector.has_value()) { + break; + } + for (int64_t case_no = 0; case_no < ohs->cases().size(); + ++case_no) { + if (implied_selector.value()[case_no] == + TernaryValue::kKnownZero) { + continue; + } + + // This case could be selected - but if it's definitely zero when + // selected, then we can ignore it. + std::optional implied_case = + ImpliedNodeValue(condition_map.GetEdgeConditionSet( + ohs, /*operand_no=*/case_no + 1), + ohs->cases()[case_no], query_engine); + if (implied_case.has_value() && implied_case->IsZero()) { + implied_selector.value()[case_no] = TernaryValue::kKnownZero; + } + } + std::optional> implied_case = + GetSelectedCase(ohs, *implied_selector); + if (!implied_case.has_value()) { + break; + } + VLOG(3) << absl::StreamFormat( + "Conditions for edge (%s, %s) imply selector %s of select %s " + "has value %s", + operand->GetName(), node->GetName(), ohs->selector()->GetName(), + ohs->GetName(), xls::ToString(*implied_selector)); + if (std::holds_alternative(*implied_case)) { + src = std::get(*implied_case); + } else { + XLS_RET_CHECK(std::holds_alternative(*implied_case)); + XLS_ASSIGN_OR_RETURN( + src, + f->MakeNode(src->loc(), ZeroOfType(src->GetType()))); + } + replacement = src; + } else { + XLS_RET_CHECK(src->OpIn({Op::kAnd, Op::kOr, Op::kXor})); + auto is_identity = [&](const Bits& b) { + if (operand->op() == Op::kAnd) { + return b.IsAllOnes(); + } + return b.IsZero(); + }; + NaryOp* bitwise_op = src->As(); + std::optional nonidentity_operand = std::nullopt; + for (Node* potential_src : bitwise_op->operands()) { + XLS_RET_CHECK(potential_src->GetType()->IsBits()); + std::optional implied_src = + ImpliedNodeValue(edge_set, potential_src, query_engine); + if (implied_src.has_value() && is_identity(*implied_src)) { + continue; + } + if (nonidentity_operand.has_value()) { + // There's more than one potentially-non-zero operand; we're + // done, there's nothing to do. + nonidentity_operand = std::nullopt; + break; + } + nonidentity_operand = potential_src; + } + if (!nonidentity_operand.has_value()) { + break; + } + VLOG(3) << absl::StreamFormat( + "Conditions for edge (%s, %s) imply that bitwise operation " + "%s has only one non-identity operand: %s", + operand->GetName(), node->GetName(), bitwise_op->GetName(), + nonidentity_operand.value()->GetName()); + src = *nonidentity_operand; + replacement = src; } } if (replacement.has_value()) { diff --git a/xls/passes/conditional_specialization_pass_test.cc b/xls/passes/conditional_specialization_pass_test.cc index b2ecc048ff..6a3c7b00c5 100644 --- a/xls/passes/conditional_specialization_pass_test.cc +++ b/xls/passes/conditional_specialization_pass_test.cc @@ -528,11 +528,11 @@ TEST_F(ConditionalSpecializationPassTest, ImpliedSelectorValueUsingOr) { } TEST_F(ConditionalSpecializationPassTest, NotImpliedSelectorValueUsingAnd) { - // No transformation because r does not imply r&s is true or false. + // No transformation because r does not imply r&s&t is true or false. // // a b // \ / - // sel1 ---- r&s + // sel1 ---- r&s&t // | // c | // \ | @@ -547,25 +547,26 @@ TEST_F(ConditionalSpecializationPassTest, NotImpliedSelectorValueUsingAnd) { BValue c = fb.Param("c", u32); BValue r = fb.Param("r", p->GetBitsType(1)); BValue s = fb.Param("s", p->GetBitsType(1)); + BValue t = fb.Param("t", p->GetBitsType(1)); - BValue r_and_s = fb.And(r, s); - BValue sel1 = fb.Select(r_and_s, {a, b}); + BValue r_and_s_and_t = fb.And({r, s, t}); + BValue sel1 = fb.Select(r_and_s_and_t, {a, b}); BValue sel0 = fb.Select(r, {c, sel1}); // Keep r_and_s alive to the return value to avoid replacing r in the // expression with one (this is not what we're testing). - XLS_ASSERT_OK_AND_ASSIGN(Function * f, - fb.BuildWithReturnValue(fb.Concat({sel0, r_and_s}))); + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Concat({sel0, r_and_s_and_t}))); EXPECT_THAT(Run(f), IsOkAndHolds(false)); } TEST_F(ConditionalSpecializationPassTest, NotImpliedSelectorValueUsingOr) { - // No transformation because !r does not imply r|s is true or false. + // No transformation because !r does not imply r|s|t is true or false. // // a b // \ / - // sel1 ---- r|s + // sel1 ---- r|s|t // | // | c // | / @@ -580,15 +581,16 @@ TEST_F(ConditionalSpecializationPassTest, NotImpliedSelectorValueUsingOr) { BValue c = fb.Param("c", u32); BValue r = fb.Param("r", p->GetBitsType(1)); BValue s = fb.Param("s", p->GetBitsType(1)); + BValue t = fb.Param("t", p->GetBitsType(1)); - BValue r_or_s = fb.Or(r, s); - BValue sel1 = fb.Select(r_or_s, {a, b}); + BValue r_or_s_or_t = fb.Or({r, s, t}); + BValue sel1 = fb.Select(r_or_s_or_t, {a, b}); BValue sel0 = fb.Select(r, {sel1, c}); - // Keep r_or_s alive to the return value to avoid replacing r in the + // Keep r_or_s_or_t alive to the return value to avoid replacing r in the // expression with zero (this is not what we're testing). - XLS_ASSERT_OK_AND_ASSIGN(Function * f, - fb.BuildWithReturnValue(fb.Concat({sel0, r_or_s}))); + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Concat({sel0, r_or_s_or_t}))); EXPECT_THAT(Run(f), IsOkAndHolds(false)); } @@ -760,6 +762,162 @@ TEST_F(ConditionalSpecializationPassTest, /*default_value=*/m::Literal(85))); } +TEST_F(ConditionalSpecializationPassTest, + ImpliedOneHotSelectorValueWithOtherUses) { + // r implies r|s. + // + // a a ------------ + // / / \ + // sel1 ---- r|s sel1 ---- r|s \ + // | \ | | + // | ... => ... | + // | | + // sel0 ---- r sel0 ---- r + // | | + // + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32 = p->GetBitsType(32); + BValue a = fb.Param("a", u32); + BValue r = fb.Param("r", p->GetBitsType(2)); + BValue s = fb.Param("s", p->GetBitsType(2)); + + BValue r_or_s = fb.Or(r, s); + BValue other_value = fb.Literal(UBits(42, 32)); + BValue sel1 = fb.OneHotSelect(fb.BitSlice(r, /*start=*/0, /*width=*/1), {a}); + BValue sel0 = + fb.PrioritySelect(r, {sel1, other_value}, fb.Literal(UBits(0, 32))); + + // Keep r_or_s alive to the return value to avoid replacing r in the + // expression with one (this is not what we're testing). + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Concat({sel0, sel1, r_or_s}))); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT(f->return_value()->operand(0), + m::PrioritySelect(m::Param("r"), + /*cases=*/{m::Param("a"), m::Literal(42)}, + /*default_value=*/m::Literal(0))); + EXPECT_THAT( + f->return_value()->operand(1), + m::OneHotSelect(m::BitSlice(m::Param("r"), /*start=*/0, /*width=*/1), + /*cases=*/{m::Param("a")})); +} + +TEST_F(ConditionalSpecializationPassTest, + ImpliedOneHotSelectorDefaultValueWithOtherUses) { + // !r implies !(r&s). + // + // d d ------------ + // / / \ + // sel1 ---- r&s sel1 ---- r&s \ + // | \ | | + // | ... => ... | + // | | + // sel0 ---- r sel0 ---- r + // | | + // + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32 = p->GetBitsType(32); + BValue a = fb.Param("a", u32); + BValue r = fb.Param("r", p->GetBitsType(2)); + BValue s = fb.Param("s", p->GetBitsType(2)); + + BValue r_and_s = fb.And(r, s); + BValue other_value = fb.Literal(UBits(42, 32)); + BValue sel1 = fb.OneHotSelect(r_and_s, {a, other_value}); + BValue sel0 = fb.PrioritySelect(r, {a, other_value}, + /*default_value=*/sel1); + + // Keep r_and_s alive to the return value to avoid replacing r in the + // expression (this is not what we're testing). + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Concat({sel0, sel1, r_and_s}))); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT(f->return_value()->operand(0), + m::PrioritySelect(m::Param("r"), + /*cases=*/{m::Param("a"), m::Literal(42)}, + /*default_value=*/m::Literal(0))); + EXPECT_THAT(f->return_value()->operand(1), + m::OneHotSelect(m::And(m::Param("r"), m::Param("s")), + /*cases=*/{m::Param("a"), m::Literal(42)})); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedValueThroughAnd) { + // r implies r|s. + // + // r a r a ------------ + // \ / \ / \ + // and and \ + // | \ | | + // | ... => ... | + // | | + // sel0 ---- r sel0 ---- r + // | | + // + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue r = fb.Param("r", u1); + + BValue r_and_a = fb.And(r, a); + BValue sel0 = fb.PrioritySelect(r, {r_and_a}, fb.Literal(UBits(0, 1))); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.Concat({sel0, r_and_a}))); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT(f->return_value()->operand(0), + m::PrioritySelect(m::Param("r"), + /*cases=*/{m::Param("a")}, + /*default_value=*/m::Literal(0))); + EXPECT_THAT(f->return_value()->operand(1), + m::And(m::Param("r"), m::Param("a"))); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedValueThroughOr) { + // + // r a r a ----------- + // \ / \ / \ + // or or \ + // | \ | | + // | ... => ... | + // | | + // sel0 ---- r sel0 ---- r + // | | + // + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue r = fb.Param("r", u1); + + BValue r_or_a = fb.Or(r, a); + BValue sel0 = fb.PrioritySelect(r, {fb.Literal(UBits(0, 1))}, r_or_a); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.Concat({sel0, r_or_a}))); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT(f->return_value()->operand(0), + m::PrioritySelect(m::Param("r"), + /*cases=*/{m::Literal(0)}, + /*default_value=*/m::Param("a"))); + EXPECT_THAT(f->return_value()->operand(1), + m::Or(m::Param("r"), m::Param("a"))); +} + TEST_F(ConditionalSpecializationPassTest, SendNoChangeLiteralPred) { XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr p, ParsePackageNoVerify(R"(