From d41fd8628a8287ad6ed2156a2c33fa76a277e607 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Mon, 30 Dec 2024 15:57:28 -0800 Subject: [PATCH] [XLS] Allow conditional specialization to recognize basic implications By recognizing certain special cases (equality, NOT, etc.), we can lift some conditions to apply to earlier nodes, letting conditional specialization infer more about the context in which each node is used. PiperOrigin-RevId: 710811273 --- xls/passes/BUILD | 1 + xls/passes/conditional_specialization_pass.cc | 179 +++++++++++++++--- .../conditional_specialization_pass_test.cc | 80 ++++++++ 3 files changed, 229 insertions(+), 31 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index c3dd4a4e14..b01cf149c8 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1764,6 +1764,7 @@ cc_library( "//xls/common:module_initializer", "//xls/common/status:ret_check", "//xls/common/status:status_macros", + "//xls/data_structures:leaf_type_tree", "//xls/ir", "//xls/ir:bits", "//xls/ir:bits_ops", diff --git a/xls/passes/conditional_specialization_pass.cc b/xls/passes/conditional_specialization_pass.cc index bbf6d74fb4..d62e55eddc 100644 --- a/xls/passes/conditional_specialization_pass.cc +++ b/xls/passes/conditional_specialization_pass.cc @@ -40,6 +40,7 @@ #include "xls/common/module_initializer.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/bits.h" #include "xls/ir/bits_ops.h" #include "xls/ir/function_base.h" @@ -187,6 +188,110 @@ class ConditionSet { CHECK_LE(conditions_.size(), kMaxConditions); } + void AddImpliedConditions(const Condition& condition, + QueryEngine& query_engine) { + AddCondition(condition); + + if (condition.node->op() == Op::kNot && + !ternary_ops::AllUnknown(condition.value) && + !condition.node->operand(0)->Is()) { + Node* operand = condition.node->operand(0); + + VLOG(4) << "Lifting a known negated value: not(" << operand->GetName() + << ") == " << xls::ToString(condition.value); + + TernaryVector negated = condition.value; + for (int64_t i = 0; i < negated.size(); ++i) { + if (negated[i] == TernaryValue::kKnownOne) { + negated[i] = TernaryValue::kKnownZero; + } else if (negated[i] == TernaryValue::kKnownZero) { + negated[i] = TernaryValue::kKnownOne; + } + } + AddCondition(Condition{.node = operand, .value = negated}); + } + + if (condition.node->op() == Op::kOr && + absl::c_any_of(condition.value, [](TernaryValue v) { + return v == TernaryValue::kKnownZero; + })) { + VLOG(4) << "Lifting known bits through an OR; or(" + << absl::StrJoin(condition.node->operands(), ", ", + [](std::string* out, Node* node) { + absl::StrAppend(out, node->GetName()); + }) + << ") == " << xls::ToString(condition.value); + TernaryVector lifted = condition.value; + for (int64_t i = 0; i < lifted.size(); ++i) { + if (lifted[i] == TernaryValue::kKnownOne) { + lifted[i] = TernaryValue::kUnknown; + } + } + for (Node* operand : condition.node->operands()) { + if (operand->Is()) { + continue; + } + AddImpliedConditions(Condition{.node = operand, .value = lifted}, + query_engine); + } + } + + if (condition.node->op() == Op::kAnd && + absl::c_any_of(condition.value, [](TernaryValue v) { + return v == TernaryValue::kKnownOne; + })) { + VLOG(4) << "Lifting known bits through an AND; and(" + << absl::StrJoin(condition.node->operands(), ", ", + [](std::string* out, Node* node) { + absl::StrAppend(out, node->GetName()); + }) + << ") == " << xls::ToString(condition.value); + TernaryVector lifted = condition.value; + for (int64_t i = 0; i < lifted.size(); ++i) { + if (lifted[i] == TernaryValue::kKnownZero) { + lifted[i] = TernaryValue::kUnknown; + } + } + for (Node* operand : condition.node->operands()) { + if (operand->Is()) { + continue; + } + AddImpliedConditions(Condition{.node = operand, .value = lifted}, + query_engine); + } + } + + if ((condition.node->op() == Op::kEq && + ternary_ops::IsKnownOne(condition.value)) || + (condition.node->op() == Op::kNe && + ternary_ops::IsKnownZero(condition.value))) { + Node* lhs = condition.node->operand(0); + Node* rhs = condition.node->operand(1); + + VLOG(4) << "Converting a known equality to direct conditions: " + << lhs->GetName() << " == " << rhs->GetName(); + + if (std::optional> lhs_ternary = + query_engine.GetTernary(lhs); + !rhs->Is() && rhs->GetType()->IsBits() && + lhs_ternary.has_value() && + !ternary_ops::AllUnknown(lhs_ternary->Get({}))) { + AddImpliedConditions( + Condition{.node = rhs, .value = lhs_ternary->Get({})}, + query_engine); + } + if (std::optional> rhs_ternary = + query_engine.GetTernary(rhs); + !lhs->Is() && lhs->GetType()->IsBits() && + rhs_ternary.has_value() && + !ternary_ops::AllUnknown(rhs_ternary->Get({}))) { + AddImpliedConditions( + Condition{.node = lhs, .value = rhs_ternary->Get({})}, + query_engine); + } + } + } + absl::Span conditions() const { return conditions_; } std::string ToString() const { @@ -571,11 +676,13 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( ConditionSet edge_set = set; // If this case is selected, we know the selector is exactly // `case_no`. - edge_set.AddCondition(Condition{ - .node = select->selector(), - .value = ternary_ops::BitsToTernary( - UBits(case_no, select->selector()->BitCountOrDie())), - }); + edge_set.AddImpliedConditions( + Condition{ + .node = select->selector(), + .value = ternary_ops::BitsToTernary( + UBits(case_no, select->selector()->BitCountOrDie())), + }, + query_engine); condition_map.SetEdgeConditionSet(node, case_no + 1, std::move(edge_set)); } @@ -591,10 +698,12 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( TernaryVector selector_value(select->selector()->BitCountOrDie(), TernaryValue::kUnknown); selector_value[case_no] = TernaryValue::kKnownOne; - edge_set.AddCondition(Condition{ - .node = select->selector(), - .value = selector_value, - }); + edge_set.AddImpliedConditions( + Condition{ + .node = select->selector(), + .value = selector_value, + }, + query_engine); condition_map.SetEdgeConditionSet(node, case_no + 1, std::move(edge_set)); } @@ -612,22 +721,26 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( known_bits.SetRange(0, case_no + 1); Bits known_bits_values = Bits::PowerOfTwo(case_no, select->selector()->BitCountOrDie()); - edge_set.AddCondition(Condition{ - .node = select->selector(), - .value = - ternary_ops::FromKnownBits(known_bits, known_bits_values), - }); + edge_set.AddImpliedConditions( + Condition{ + .node = select->selector(), + .value = + ternary_ops::FromKnownBits(known_bits, known_bits_values), + }, + query_engine); condition_map.SetEdgeConditionSet(node, case_no + 1, std::move(edge_set)); } ConditionSet edge_set = set; // If the default value is selected, we know all the bits of the // selector are zero. - edge_set.AddCondition(Condition{ - .node = select->selector(), - .value = TernaryVector(select->selector()->BitCountOrDie(), - TernaryValue::kKnownZero), - }); + edge_set.AddImpliedConditions( + Condition{ + .node = select->selector(), + .value = TernaryVector(select->selector()->BitCountOrDie(), + TernaryValue::kKnownZero), + }, + query_engine); condition_map.SetEdgeConditionSet(node, select->cases().size() + 1, std::move(edge_set)); } @@ -649,14 +762,16 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( // ArrayUpdate is a no-op if any index is out of range; as such, it only // cares about the update value if all indices are in range. - edge_set.AddCondition(Condition{ - .node = index, - .value = - TernaryVector(index->BitCountOrDie(), TernaryValue::kUnknown), - .range = IntervalSet::Of( - {Interval::RightOpen(UBits(0, index->BitCountOrDie()), - UBits(array_type->AsArrayOrDie()->size(), - index->BitCountOrDie()))})}); + edge_set.AddImpliedConditions( + Condition{ + .node = index, + .value = TernaryVector(index->BitCountOrDie(), + TernaryValue::kUnknown), + .range = IntervalSet::Of({Interval::RightOpen( + UBits(0, index->BitCountOrDie()), + UBits(array_size, index->BitCountOrDie()))}), + }, + query_engine); array_type = array_type->AsArrayOrDie()->element_type(); } @@ -680,8 +795,9 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( node->GetName(), predicate->GetName(), send->data()->GetName()); ConditionSet edge_set = set; - edge_set.AddCondition( - Condition{.node = predicate, .value = {TernaryValue::kKnownOne}}); + edge_set.AddImpliedConditions( + Condition{.node = predicate, .value = {TernaryValue::kKnownOne}}, + query_engine); condition_map.SetEdgeConditionSet(node, Send::kDataOperand, std::move(edge_set)); } @@ -703,8 +819,9 @@ absl::StatusOr ConditionalSpecializationPass::RunOnFunctionBaseInternal( node->GetName(), predicate->GetName(), next->value()->GetName()); ConditionSet edge_set = set; - edge_set.AddCondition( - Condition{.node = predicate, .value = {TernaryValue::kKnownOne}}); + edge_set.AddImpliedConditions( + Condition{.node = predicate, .value = {TernaryValue::kKnownOne}}, + query_engine); condition_map.SetEdgeConditionSet(node, Next::kValueOperand, std::move(edge_set)); } diff --git a/xls/passes/conditional_specialization_pass_test.cc b/xls/passes/conditional_specialization_pass_test.cc index ac320dff97..647f5b0359 100644 --- a/xls/passes/conditional_specialization_pass_test.cc +++ b/xls/passes/conditional_specialization_pass_test.cc @@ -1072,5 +1072,85 @@ TEST_F(ConditionalSpecializationPassTest, NextValueChange) { m::StateRead("value2"), m::Eq()))); } +TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughNot) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue b = fb.Param("b", u1); + BValue s = fb.Not(a); + BValue result = fb.Select(s, {a, b}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result)); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Select(m::Not(m::Param("a")), {m::Literal(1), m::Param("b")})); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughEq) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue b = fb.Param("b", u1); + BValue s = fb.Eq(b, fb.Literal(UBits(1, 1))); + BValue result = fb.Select(s, {a, b}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result)); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Select(m::Eq(m::Param("b"), m::Literal(1)), + {m::Param("a"), m::Literal(1)})); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughNe) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue b = fb.Param("b", u1); + BValue s = fb.Ne(a, fb.Literal(UBits(1, 1))); + BValue result = fb.Select(s, {a, b}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result)); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Select(m::Ne(m::Param("a"), m::Literal(1)), + {m::Literal(1), m::Param("b")})); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughOr) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue b = fb.Param("b", u1); + BValue s = fb.Or(a, b); + BValue result = fb.Select(s, {a, b}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result)); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Select(m::Or(m::Param("a"), m::Param("b")), + {m::Literal(0), m::Param("b")})); +} + +TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughAnd) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u1 = p->GetBitsType(1); + BValue a = fb.Param("a", u1); + BValue b = fb.Param("b", u1); + BValue s = fb.And(a, b); + BValue result = fb.Select(s, {a, b}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result)); + + solvers::z3::ScopedVerifyEquivalence sve{f}; + EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Select(m::And(m::Param("a"), m::Param("b")), + {m::Param("a"), m::Literal(1)})); +} + } // namespace } // namespace xls