Skip to content

Commit

Permalink
[XLS] Finish the IR parser's handling of predicated state reads
Browse files Browse the repository at this point in the history
Previously dropped the predicate operand, due to an incomplete implementation. Now properly uses an optional keyword argument for this operand.

PiperOrigin-RevId: 719458997
  • Loading branch information
ericastor authored and copybara-github committed Jan 24, 2025
1 parent 3466ae5 commit 572e484
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 4 deletions.
1 change: 1 addition & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ cc_test(
":ir",
":ir_parser",
":op",
":state_element",
":type",
":value",
"//xls/common:casts",
Expand Down
8 changes: 7 additions & 1 deletion xls/ir/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,9 @@ absl::StatusOr<BValue> Parser::ParseNode(
case Op::kStateRead: {
IdentifierString* state_name =
arg_parser.AddKeywordArg<IdentifierString>("state_element");
XLS_ASSIGN_OR_RETURN(operands, arg_parser.Run(ArgParser::kVariadic));
std::optional<BValue>* predicate =
arg_parser.AddOptionalKeywordArg<BValue>("predicate");
XLS_ASSIGN_OR_RETURN(operands, arg_parser.Run(/*arity=*/0));
auto it = name_to_value->find(state_name->value);
if (it == name_to_value->end()) {
return absl::InvalidArgumentError(
Expand All @@ -802,6 +804,10 @@ absl::StatusOr<BValue> Parser::ParseNode(
state_name->value, op_token.pos().ToHumanString()));
}
bvalue = it->second;
if (predicate->has_value()) {
XLS_RETURN_IF_ERROR(
bvalue.node()->As<StateRead>()->SetPredicate((*predicate)->node()));
}
break;
}
case Op::kNext: {
Expand Down
4 changes: 4 additions & 0 deletions xls/ir/ir_parser_round_trip_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ TEST(IrParserRoundTripTest, ParseProcWithExplicitNext) {
ParsePackageAndCheckDump(TestName());
}

TEST(IrParserRoundTripTest, ParseProcWithPredicatedStateRead) {
ParsePackageAndCheckDump(TestName());
}

TEST(IrParserRoundTripTest, ParseNewStyleProc) {
ParsePackageAndCheckDump(TestName());
}
Expand Down
46 changes: 46 additions & 0 deletions xls/ir/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
#include "xls/ir/package.h"
#include "xls/ir/state_element.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"

Expand Down Expand Up @@ -579,6 +580,51 @@ proc foo(tok1: token, x: bits[32], tok2: token, y: (), z: bits[32], init={token,
EXPECT_EQ(proc->GetStateElementCount(), 5);
}

TEST(IrParserTest, ProcWithExplicitStateRead) {
std::string program = R"(
package test
proc foo( x: bits[32], y: (), z: bits[32], init={42, (), 123}) {
x: bits[32] = state_read(state_element=x)
sum: bits[32] = add(x, z)
next (x, y, sum)
}
)";
XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->GetStateElementCount(), 3);
XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElement("x"));
EXPECT_THAT(proc->GetStateRead(x)->predicate(), std::nullopt);
}

TEST(IrParserTest, ProcWithPredicatedStateRead) {
std::string program = R"(
package test
proc foo( x: bits[32], y: bits[1], z: bits[32], init={42, 1, 123}) {
x: bits[32] = state_read(state_element=x, predicate=y)
z: bits[32] = state_read(state_element=z)
sum: bits[32] = add(x, z)
next (x, y, sum)
}
)";
XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->GetStateElementCount(), 3);

XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElement("x"));
std::optional<Node*> x_predicate = proc->GetStateRead(x)->predicate();
ASSERT_TRUE(x_predicate.has_value());
ASSERT_EQ((*x_predicate)->op(), Op::kStateRead);
EXPECT_EQ((*x_predicate)->As<StateRead>()->state_element()->name(), "y");

XLS_ASSERT_OK_AND_ASSIGN(StateElement * y, proc->GetStateElement("y"));
ASSERT_FALSE(proc->GetStateRead(y)->predicate().has_value());

XLS_ASSERT_OK_AND_ASSIGN(StateElement * z, proc->GetStateElement("z"));
ASSERT_FALSE(proc->GetStateRead(z)->predicate().has_value());
}

TEST(IrParserTest, ParseSendReceiveChannel) {
Package p("my_package");
XLS_ASSERT_OK_AND_ASSIGN(Channel * ch,
Expand Down
12 changes: 9 additions & 3 deletions xls/ir/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,16 @@ std::string Node::ToStringInternal(bool include_operand_types) const {
case Op::kParam:
args.push_back(absl::StrFormat("name=%s", GetName()));
break;
case Op::kStateRead:
args.push_back(absl::StrFormat("state_element=%s",
As<StateRead>()->state_element()->name()));
case Op::kStateRead: {
const StateRead* state_read = As<StateRead>();
args = {absl::StrFormat("state_element=%s",
state_read->state_element()->name())};
if (state_read->predicate().has_value()) {
args.push_back(absl::StrFormat("predicate=%s",
(*state_read->predicate())->GetName()));
}
break;
}
case Op::kNext: {
const Next* next = As<Next>();
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package test

chan ch(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=none, strictness=proven_mutually_exclusive, metadata="""""")

proc my_proc(my_token: token, my_state: bits[32], my_predicate: bits[1], init={token, 42, 1}) {
my_predicate: bits[1] = state_read(state_element=my_predicate, id=18)
my_token: token = state_read(state_element=my_token, id=9)
my_state: bits[32] = state_read(state_element=my_state, predicate=my_predicate, id=10)
send.1: token = send(my_token, my_state, channel=ch, id=1)
literal.2: bits[1] = literal(value=1, id=2)
receive.3: (token, bits[32]) = receive(send.1, predicate=literal.2, channel=ch, id=3)
tuple_index.4: token = tuple_index(receive.3, index=0, id=4)
next_value.5: () = next_value(param=my_token, value=tuple_index.4, id=5)
next_value.6: () = next_value(param=my_state, value=my_state, id=6)
next_value.7: () = next_value(param=my_predicate, value=my_predicate, id=7)
}

0 comments on commit 572e484

Please sign in to comment.