From 3466ae5bb3a02fc4bddee46306a55021852860e2 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Fri, 24 Jan 2025 15:42:59 -0800 Subject: [PATCH] [XLS] Fix proc-state-legalization for state-reads that influence next_values If a next-value (state write)'s predicate depends on the read of that same state element, we would previously add the predicate to the state read - introducing a cycle in the graph, and causing codegen to fail with an error. To fix this, we automatically use a variant of the write's predicate, removing all dependency on the state read's value, but guaranteeing that the new predicate is true whenever the original expression would have been true. PiperOrigin-RevId: 719457555 --- xls/ir/node_util.cc | 100 ++++++++++++ xls/ir/node_util.h | 15 ++ xls/ir/node_util_test.cc | 140 +++++++++++++++++ .../proc_state_legalization_pass.cc | 126 +++++++++++++-- .../proc_state_legalization_pass_test.cc | 145 +++++++++++++++--- 5 files changed, 491 insertions(+), 35 deletions(-) diff --git a/xls/ir/node_util.cc b/xls/ir/node_util.cc index 92c53f47d4..5bde422429 100644 --- a/xls/ir/node_util.cc +++ b/xls/ir/node_util.cc @@ -41,6 +41,7 @@ #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/bits.h" #include "xls/ir/channel.h" +#include "xls/ir/dfs_visitor.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" @@ -825,4 +826,103 @@ bool AreAllLiteral(absl::Span nodes) { return absl::c_all_of(nodes, [](Node* i) -> bool { return IsLiteral(i); }); } +namespace { + +class NodeSearch : public DfsVisitorWithDefault { + public: + explicit NodeSearch(Node* target) : target_(target) {} + + absl::Status DefaultHandler(Node* node) override { + if (node == target_) { + // We've found our target, and can cancel the search. (This causes + // Node::Accept to return early, since we've already accomplished our + // goal.) + return absl::CancelledError(); + } + return absl::OkStatus(); + } + + private: + Node* target_; +}; + +} // namespace + +bool IsAncestorOf(Node* a, Node* b) { + CHECK_NE(a, nullptr); + CHECK_NE(b, nullptr); + + if (a->function_base() != b->function_base()) { + return false; + } + if (a == b) { + return false; + } + + NodeSearch visitor(a); + absl::Status visitor_status = b->Accept(&visitor); + CHECK(visitor_status.ok() || absl::IsCancelled(visitor_status)); + return visitor.IsVisited(a); +} + +absl::StatusOr RemoveNodeFromBooleanExpression(Node* to_remove, + Node* expression, + bool favored_outcome) { + XLS_RET_CHECK(expression->GetType()->IsBits()); + XLS_RET_CHECK_EQ(expression->GetType()->AsBitsOrDie()->bit_count(), 1) + << expression->ToString(); + + if (expression == to_remove) { + return expression->function_base()->MakeNode( + expression->loc(), Value(UBits((favored_outcome ? 1 : 0), 1))); + } + + if (expression->op() == Op::kNot) { + XLS_ASSIGN_OR_RETURN( + Node * new_operand, + RemoveNodeFromBooleanExpression(to_remove, expression->operand(0), + /*favored_outcome=*/!favored_outcome)); + if (new_operand == expression->operand(0)) { + // No change was necessary; apparently `to_remove` was not present. + return expression; + } else { + return expression->function_base()->MakeNode(expression->loc(), + new_operand, Op::kNot); + } + } + + if (expression->OpIn({Op::kAnd, Op::kOr, Op::kNand, Op::kNor})) { + const bool favored_operand = + favored_outcome ^ expression->OpIn({Op::kNand, Op::kNor}); + + std::vector new_operands; + new_operands.reserve(expression->operands().size()); + bool changed = false; + for (Node* operand : expression->operands()) { + XLS_ASSIGN_OR_RETURN( + Node * new_operand, + RemoveNodeFromBooleanExpression(to_remove, operand, favored_operand)); + new_operands.push_back(new_operand); + if (new_operand != operand) { + changed = true; + } + } + if (changed) { + return expression->function_base()->MakeNode( + expression->loc(), new_operands, expression->op()); + } else { + // No change was necessary; apparently `to_remove` was not present. + return expression; + } + } + + if (IsAncestorOf(to_remove, expression)) { + // We're unable to remove `to_remove` from `expression` directly, but it is + // an ancestor; just replace the entire expression with a literal. + return expression->function_base()->MakeNode( + expression->loc(), Value(UBits((favored_outcome ? 1 : 0), 1))); + } + return expression; +} + } // namespace xls diff --git a/xls/ir/node_util.h b/xls/ir/node_util.h index e9e24f84c0..1f51b374ba 100644 --- a/xls/ir/node_util.h +++ b/xls/ir/node_util.h @@ -417,6 +417,21 @@ inline absl::StatusOr UnsignedUpperBoundLiteral(Node* v, // Check if all nodes are literals bool AreAllLiteral(absl::Span nodes); +// Returns whether `a` is an ancestor of `b`; i.e., whether `b` could possibly +// be affected by a change to `a`. Returns false if `a` and `b` are the same +// node. +bool IsAncestorOf(Node* a, Node* b); + +// Removes the given node from the given boolean expression, returning the +// result. We guarantee that the result no longer depends on `to_remove`, and +// that whenever `old_expression == favored_outcome`, `new_expression == +// favored_outcome`. Note that `new_expression` may be `favored_outcome` in more +// cases than `old_expression`; if necessary, `new_expression` may always equal +// `favored_outcome`. +absl::StatusOr RemoveNodeFromBooleanExpression(Node* to_remove, + Node* expression, + bool favored_outcome); + } // namespace xls #endif // XLS_IR_NODE_UTIL_H_ diff --git a/xls/ir/node_util_test.cc b/xls/ir/node_util_test.cc index 1be3edb8b5..4353d7c14f 100644 --- a/xls/ir/node_util_test.cc +++ b/xls/ir/node_util_test.cc @@ -656,5 +656,145 @@ TEST_F(NodeUtilTest, UnsignedBoundByLiterals) { m::ULt(m::Param("foo2"), m::Literal(UBits(10, 10))), {m::Literal(UBits(10, 10))}, m::Param("foo2"))); } + +TEST_F(NodeUtilTest, IsAncestorOf) { + std::unique_ptr p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue a = fb.Param("a", p->GetBitsType(100)); + BValue b = fb.Param("b", p->GetBitsType(100)); + BValue c = fb.Param("c", p->GetBitsType(100)); + BValue d = fb.Param("d", p->GetBitsType(100)); + BValue a_or_b = fb.Or({a, b}); + BValue a_or_b_or_d = fb.Or({a_or_b, d}); + BValue c_plus_d = fb.Add(c, d); + BValue and_a_b_c = fb.And({a, b, c}); + BValue big_or = fb.Or({and_a_b_c, c_plus_d}); + + EXPECT_FALSE(IsAncestorOf(a.node(), b.node())); + EXPECT_FALSE(IsAncestorOf(c.node(), c.node())); + EXPECT_FALSE(IsAncestorOf(d.node(), c.node())); + + EXPECT_TRUE(IsAncestorOf(a.node(), a_or_b.node())); + EXPECT_TRUE(IsAncestorOf(b.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(c.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(d.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), a_or_b.node())); + EXPECT_FALSE(IsAncestorOf(big_or.node(), a_or_b.node())); + + EXPECT_TRUE(IsAncestorOf(a.node(), a_or_b_or_d.node())); + EXPECT_TRUE(IsAncestorOf(b.node(), a_or_b_or_d.node())); + EXPECT_FALSE(IsAncestorOf(c.node(), a_or_b_or_d.node())); + EXPECT_TRUE(IsAncestorOf(d.node(), a_or_b_or_d.node())); + EXPECT_TRUE(IsAncestorOf(a_or_b.node(), a_or_b_or_d.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), a_or_b_or_d.node())); + EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), a_or_b_or_d.node())); + EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), a_or_b_or_d.node())); + EXPECT_FALSE(IsAncestorOf(big_or.node(), a_or_b_or_d.node())); + + EXPECT_FALSE(IsAncestorOf(a.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(b.node(), c_plus_d.node())); + EXPECT_TRUE(IsAncestorOf(c.node(), c_plus_d.node())); + EXPECT_TRUE(IsAncestorOf(d.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), c_plus_d.node())); + EXPECT_FALSE(IsAncestorOf(big_or.node(), c_plus_d.node())); + + EXPECT_TRUE(IsAncestorOf(a.node(), and_a_b_c.node())); + EXPECT_TRUE(IsAncestorOf(b.node(), and_a_b_c.node())); + EXPECT_TRUE(IsAncestorOf(c.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(d.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), and_a_b_c.node())); + EXPECT_FALSE(IsAncestorOf(big_or.node(), and_a_b_c.node())); + + EXPECT_TRUE(IsAncestorOf(a.node(), big_or.node())); + EXPECT_TRUE(IsAncestorOf(b.node(), big_or.node())); + EXPECT_TRUE(IsAncestorOf(c.node(), big_or.node())); + EXPECT_TRUE(IsAncestorOf(d.node(), big_or.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b.node(), big_or.node())); + EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), big_or.node())); + EXPECT_TRUE(IsAncestorOf(and_a_b_c.node(), big_or.node())); + EXPECT_TRUE(IsAncestorOf(c_plus_d.node(), big_or.node())); + EXPECT_FALSE(IsAncestorOf(big_or.node(), big_or.node())); +} + +TEST_F(NodeUtilTest, RemoveNodeFromBooleanExpression) { + std::unique_ptr p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue a = fb.Param("a", p->GetBitsType(1)); + BValue b = fb.Param("b", p->GetBitsType(1)); + BValue c = fb.Param("c", p->GetBitsType(1)); + BValue d = fb.Param("d", p->GetBitsType(1)); + BValue a_or_b = fb.Or({a, b}); + BValue a_or_b_or_d = fb.Or({a_or_b, d}); + BValue c_plus_d = fb.Add(c, d); + BValue and_a_b_c = fb.And({a, b, c}); + BValue big_nor = fb.Nor({and_a_b_c, c_plus_d}); + + EXPECT_THAT(RemoveNodeFromBooleanExpression(a.node(), a_or_b.node(), true), + IsOkAndHolds(m::Or(m::Literal(1), m::Param("b")))); + EXPECT_THAT(RemoveNodeFromBooleanExpression(a.node(), a_or_b.node(), false), + IsOkAndHolds(m::Or(m::Literal(0), m::Param("b")))); + EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), a_or_b.node(), true), + IsOkAndHolds(a_or_b.node())); + EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), a_or_b.node(), false), + IsOkAndHolds(a_or_b.node())); + + EXPECT_THAT( + RemoveNodeFromBooleanExpression(b.node(), a_or_b_or_d.node(), true), + IsOkAndHolds(m::Or(m::Or(m::Param("a"), m::Literal(1)), m::Param("d")))); + EXPECT_THAT( + RemoveNodeFromBooleanExpression(b.node(), a_or_b_or_d.node(), false), + IsOkAndHolds(m::Or(m::Or(m::Param("a"), m::Literal(0)), m::Param("d")))); + + EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), c_plus_d.node(), true), + IsOkAndHolds(m::Literal(1))); + EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), c_plus_d.node(), false), + IsOkAndHolds(m::Literal(0))); + + EXPECT_THAT( + RemoveNodeFromBooleanExpression(a.node(), big_nor.node(), true), + IsOkAndHolds(m::Nor(m::And(m::Literal(0), m::Param("b"), m::Param("c")), + m::Add(m::Param("c"), m::Param("d"))))); + EXPECT_THAT( + RemoveNodeFromBooleanExpression(a.node(), big_nor.node(), false), + IsOkAndHolds(m::Nor(m::And(m::Literal(1), m::Param("b"), m::Param("c")), + m::Add(m::Param("c"), m::Param("d"))))); + + EXPECT_THAT( + RemoveNodeFromBooleanExpression(b.node(), big_nor.node(), true), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Literal(0), m::Param("c")), + m::Add(m::Param("c"), m::Param("d"))))); + EXPECT_THAT( + RemoveNodeFromBooleanExpression(b.node(), big_nor.node(), false), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Literal(1), m::Param("c")), + m::Add(m::Param("c"), m::Param("d"))))); + + EXPECT_THAT( + RemoveNodeFromBooleanExpression(c.node(), big_nor.node(), true), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Literal(0)), + m::Literal(0)))); + EXPECT_THAT( + RemoveNodeFromBooleanExpression(c.node(), big_nor.node(), false), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Literal(1)), + m::Literal(1)))); + + EXPECT_THAT( + RemoveNodeFromBooleanExpression(d.node(), big_nor.node(), true), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Param("c")), + m::Literal(0)))); + EXPECT_THAT( + RemoveNodeFromBooleanExpression(d.node(), big_nor.node(), false), + IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Param("c")), + m::Literal(1)))); +} + } // namespace } // namespace xls diff --git a/xls/scheduling/proc_state_legalization_pass.cc b/xls/scheduling/proc_state_legalization_pass.cc index 6a8a2caeb7..01fe26e100 100644 --- a/xls/scheduling/proc_state_legalization_pass.cc +++ b/xls/scheduling/proc_state_legalization_pass.cc @@ -68,7 +68,7 @@ absl::StatusOr LegalizeStateReadPredicate( predicates_set.reserve(next_values.size()); for (Next* next : next_values) { if (next->state_read() == next->value()) { - // This is a no-op next_value; we can narrow it to the case where the + // This is a no-op next_value; we will narrow it to the case where the // state read is active instead. continue; } @@ -78,8 +78,16 @@ absl::StatusOr LegalizeStateReadPredicate( return true; } - predicates.push_back(*next->predicate()); - predicates_set.insert(*next->predicate()); + // We can't just add the next-value predicate to the state read predicate + // directly, because it may create a cycle; the state read may be part of + // what decides whether the next-value is active. To fix this, we remove the + // state read from the predicate (erring on the side of reading too often). + XLS_ASSIGN_OR_RETURN( + Node * safe_predicate, + RemoveNodeFromBooleanExpression(state_read, *next->predicate(), + /*favored_outcome=*/true)); + predicates.push_back(safe_predicate); + predicates_set.insert(safe_predicate); } if (predicates.empty()) { // There are no non-trivial next_value nodes; any predicate is fine. @@ -143,13 +151,24 @@ absl::StatusOr LegalizeNoOpNextPredicate( Node* new_predicate = *state_read->predicate(); if (next->predicate().has_value()) { Node* old_predicate = *next->predicate(); - if (old_predicate == *state_read->predicate() || - (old_predicate->op() == Op::kAnd && - absl::c_contains(old_predicate->operands(), - *state_read->predicate()))) { - // Already restricted to the case where the state read is active. + + // Check if we're already trivially restricted to the case where the state + // read is active; i.e., taking A = `old_predicate` and B = `new_predicate`, + // we're already fine if we can verify that A -> B. + if (old_predicate == new_predicate) { + return false; + } + if (old_predicate->op() == Op::kAnd && + absl::c_contains(old_predicate->operands(), new_predicate)) { + // A && X -> A, so we're already restricted. + return false; + } + if (new_predicate->op() == Op::kOr && + absl::c_contains(new_predicate->operands(), old_predicate)) { + // A -> A || X, so we're already restricted. return false; } + std::vector predicates; if (old_predicate->op() == Op::kAnd) { predicates.reserve(1 + old_predicate->operands().size()); @@ -254,9 +273,9 @@ absl::StatusOr AddMutualExclusionAsserts( bool changed = false; for (StateElement* state_element : proc->StateElements()) { - XLS_ASSIGN_OR_RETURN(bool asserts_added, AddMutualExclusionAssert( - proc, state_element, options)); - if (asserts_added) { + XLS_ASSIGN_OR_RETURN(bool assert_added, AddMutualExclusionAssert( + proc, state_element, options)); + if (assert_added) { VLOG(4) << "Added mutual exclusion assert for state element: " << state_element->name(); changed = true; @@ -266,6 +285,81 @@ absl::StatusOr AddMutualExclusionAsserts( return changed; } +absl::StatusOr AddWriteWithoutReadAsserts( + Proc* proc, StateElement* state_element, + const SchedulingPassOptions& options) { + StateRead* state_read = proc->GetStateRead(state_element); + if (!state_read->predicate().has_value()) { + return false; + } + + const absl::btree_set& next_values = + proc->next_values(proc->GetStateRead(state_element)); + if (next_values.empty()) { + return false; + } + + std::vector predicate_list; + for (Next* next : next_values) { + XLS_RET_CHECK(next->predicate().has_value()); + XLS_ASSIGN_OR_RETURN( + Node * next_not_triggered, + proc->MakeNodeWithName( + SourceInfo(), *next->predicate(), Op::kNot, + absl::StrCat("__", state_element->name(), "__next_", next->id(), + "_not_triggered"))); + XLS_ASSIGN_OR_RETURN( + Node * no_write_without_read, + proc->MakeNodeWithName( + SourceInfo(), + absl::MakeConstSpan({*state_read->predicate(), next_not_triggered}), + Op::kOr, + absl::StrCat("__", state_element->name(), "__no_next_", next->id(), + "_without_read"))); + std::string label = absl::StrCat("__", state_element->name(), "__next_", + next->id(), "_without_read_assert"); + if (proc->HasNode(label)) { + return absl::InternalError(absl::StrFormat( + "Write-without-read assert already exists for next_value node '%s'; " + "was this pass run twice? assert label: %s", + next->GetName(), label)); + } + + XLS_ASSIGN_OR_RETURN(Node * tkn, + proc->MakeNode(SourceInfo(), Value::Token())); + XLS_RETURN_IF_ERROR( + proc->MakeNodeWithName( + SourceInfo(), tkn, + /*condition=*/no_write_without_read, + /*message=*/ + absl::StrCat(next->GetName(), + " fired while read disabled for state element: ", + state_element->name()), + /*label=*/label, + /*original_label=*/std::nullopt, + /*name=*/label) + .status()); + } + return true; +} + +absl::StatusOr AddWriteWithoutReadAsserts( + Proc* proc, const SchedulingPassOptions& options) { + bool changed = false; + + for (StateElement* state_element : proc->StateElements()) { + XLS_ASSIGN_OR_RETURN(bool assert_added, AddWriteWithoutReadAsserts( + proc, state_element, options)); + if (assert_added) { + VLOG(4) << "Added write-without-read assert for state element: " + << state_element->name(); + changed = true; + } + } + + return changed; +} + absl::StatusOr AddDefaultNextValue(Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { @@ -425,9 +519,15 @@ absl::StatusOr ProcStateLegalizationPass::RunOnFunctionBaseInternal( changed = true; } - XLS_ASSIGN_OR_RETURN(bool asserts_added, + XLS_ASSIGN_OR_RETURN(bool mutex_asserts_added, AddMutualExclusionAsserts(proc, options)); - if (asserts_added) { + if (mutex_asserts_added) { + changed = true; + } + + XLS_ASSIGN_OR_RETURN(bool write_without_read_asserts_added, + AddWriteWithoutReadAsserts(proc, options)); + if (write_without_read_asserts_added) { changed = true; } diff --git a/xls/scheduling/proc_state_legalization_pass_test.cc b/xls/scheduling/proc_state_legalization_pass_test.cc index 6c3278fd82..90bc1b2904 100644 --- a/xls/scheduling/proc_state_legalization_pass_test.cc +++ b/xls/scheduling/proc_state_legalization_pass_test.cc @@ -357,6 +357,7 @@ TEST_F(ProcStateLegalizationPassTest, ProcWithPredicatedStateRead) { pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + ScopedRecordIr sri(p.get()); ASSERT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_EQ(proc->GetStateRead(*proc->GetStateElement("x"))->predicate(), @@ -380,7 +381,87 @@ TEST_F(ProcStateLegalizationPassTest, ProcWithPredicatedStateRead) { m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))))))); - EXPECT_THAT(proc->nodes(), Each(Not(m::Assert()))); + std::vector asserts; + absl::c_copy_if(proc->nodes(), std::back_inserter(asserts), + [](Node* node) { return node->Is(); }); + EXPECT_THAT( + asserts, + UnorderedElementsAre(m::Assert( + _, m::Or(m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), + m::Literal(0)), + m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0))), + m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0))))))); +} + +TEST_F(ProcStateLegalizationPassTest, + ProcWithPredicatedStateReadAndPotentialCycle) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + BValue x = pb.StateElement("x", Value(UBits(0, 32))); + BValue x_even = + pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32))); + BValue x_multiple_of_3 = + pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32))); + BValue y = pb.StateElement("y", Value(UBits(0, 32)), + /*read_predicate=*/x_even); + BValue y_even = + pb.Eq(pb.UMod(y, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32))); + pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32)))); + pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), + pb.And(x_multiple_of_3, y_even)); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + + ScopedRecordIr sri(p.get()); + ASSERT_THAT(Run(proc), IsOkAndHolds(true)); + + EXPECT_EQ(proc->GetStateRead(*proc->GetStateElement("x"))->predicate(), + std::nullopt); + EXPECT_THAT( + proc->GetStateRead(*proc->GetStateElement("y"))->predicate(), + Optional(m::Or( + m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)), + m::And( + m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)), + m::Literal(1))))); + EXPECT_THAT( + proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))), + UnorderedElementsAre( + m::Next(m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)), + m::And(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)), + m::Eq(m::UMod(m::StateRead("y"), m::Literal(2)), + m::Literal(0)))), + m::Next( + m::StateRead("y"), m::StateRead("y"), + m::And( + m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), + m::Literal(0)), + m::And(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)), + m::Literal(1))), + m::Not(m::And(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)), + m::Eq(m::UMod(m::StateRead("y"), m::Literal(2)), + m::Literal(0)))))))); + + std::vector asserts; + absl::c_copy_if(proc->nodes(), std::back_inserter(asserts), + [](Node* node) { return node->Is(); }); + EXPECT_THAT( + asserts, + UnorderedElementsAre(m::Assert( + _, + m::Or(m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), + m::Literal(0)), + m::And(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)), + m::Literal(1))), + m::Not(m::And(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)), + m::Eq(m::UMod(m::StateRead("y"), m::Literal(2)), + m::Literal(0)))))))); } TEST_F(ProcStateLegalizationPassTest, @@ -399,6 +480,7 @@ TEST_F(ProcStateLegalizationPassTest, pb.Next(y, y, x_not_multiple_of_3); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + ScopedRecordIr sri(p.get()); ASSERT_THAT(Run(proc), IsOkAndHolds(true)); const testing::Matcher expected_read_predicate = m::Or( @@ -413,19 +495,26 @@ TEST_F(ProcStateLegalizationPassTest, m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)), m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))), m::Next(m::StateRead("y"), m::StateRead("y"), - m::And(expected_read_predicate, - m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), - m::Literal(0))))))); + m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)))))); std::vector asserts; absl::c_copy_if(proc->nodes(), std::back_inserter(asserts), [](Node* node) { return node->Is(); }); - EXPECT_THAT(asserts, - UnorderedElementsAre(m::Assert( - _, m::Eq(m::Concat(x_multiple_of_3.node(), - m::And(expected_read_predicate, - x_not_multiple_of_3.node())), - m::BitSlice(m::OneHot(m::Concat())))))); + EXPECT_THAT( + asserts, + UnorderedElementsAre( + m::Assert(_, m::Eq(m::Concat(x_multiple_of_3.node(), + x_not_multiple_of_3.node()), + m::BitSlice(m::OneHot(m::Concat())))), + m::Assert( + _, m::Or(expected_read_predicate, + m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0))))), + m::Assert(_, m::Or(expected_read_predicate, + m::Not(m::Not(m::Eq( + m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)))))))); } TEST_F(ProcStateLegalizationPassTest, @@ -445,13 +534,20 @@ TEST_F(ProcStateLegalizationPassTest, pb.Next(y, y, x_not_multiple_of_3); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + ScopedRecordIr sri(p.get()); ASSERT_THAT(Run(proc), IsOkAndHolds(true)); - const testing::Matcher expected_read_predicate = + const ::testing::Matcher match_x_even_or_threeven = m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)), m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))); EXPECT_THAT(proc->GetStateRead(*proc->GetStateElement("y"))->predicate(), - Optional(expected_read_predicate)); + Optional(match_x_even_or_threeven)); + + const ::testing::Matcher + match_x_even_or_threeven_and_not_threeven = + m::And(match_x_even_or_threeven, + m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0)))); EXPECT_THAT( proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))), UnorderedElementsAre( @@ -459,20 +555,25 @@ TEST_F(ProcStateLegalizationPassTest, m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)), m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))), m::Next(m::StateRead("y"), m::StateRead("y"), - m::And(expected_read_predicate, - m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), - m::Literal(0))))))) - << p->DumpIr(); + match_x_even_or_threeven_and_not_threeven))); std::vector asserts; absl::c_copy_if(proc->nodes(), std::back_inserter(asserts), [](Node* node) { return node->Is(); }); - EXPECT_THAT(asserts, - UnorderedElementsAre(m::Assert( - _, m::Eq(m::Concat(x_multiple_of_3.node(), - m::And(expected_read_predicate, - x_not_multiple_of_3.node())), - m::BitSlice(m::OneHot(m::Concat())))))); + EXPECT_THAT( + asserts, + UnorderedElementsAre( + m::Assert(_, + m::Eq(m::Concat(x_multiple_of_3.node(), + match_x_even_or_threeven_and_not_threeven), + m::BitSlice(m::OneHot(m::Concat())))), + m::Assert( + _, m::Or(match_x_even_or_threeven, + m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), + m::Literal(0))))), + m::Assert(_, + m::Or(match_x_even_or_threeven, + m::Not(match_x_even_or_threeven_and_not_threeven))))); } } // namespace