Skip to content

Commit

Permalink
Fix TokenDependencyPass's handling of multiple data dependencies
Browse files Browse the repository at this point in the history
We now only drop redundant token-dependency changes to downstream nodes, rather than all changes to nodes downstream of other nodes that required changes.

PiperOrigin-RevId: 597355357
  • Loading branch information
ericastor authored and copybara-github committed Jan 10, 2024
1 parent 594ed4d commit 59b6d06
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
5 changes: 5 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,19 @@ cc_library(
":token_provenance_analysis",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//xls/common/logging",
"//xls/common/logging:log_lines",
"//xls/common/status:status_macros",
"//xls/data_structures:transitive_closure",
"//xls/ir",
"//xls/ir:node_util",
"//xls/ir:op",
"//xls/ir:source_location",
"//xls/ir:type",
],
)

Expand Down
39 changes: 31 additions & 8 deletions xls/passes/token_dependency_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "xls/common/logging/log_lines.h"
#include "xls/common/logging/logging.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/transitive_closure.h"
#include "xls/ir/function_base.h"
#include "xls/ir/node_util.h"
#include "xls/ir/op.h"
#include "xls/ir/source_location.h"
#include "xls/ir/type.h"
#include "xls/passes/optimization_pass.h"
#include "xls/passes/pass_base.h"
#include "xls/passes/token_provenance_analysis.h"
Expand Down Expand Up @@ -148,15 +153,33 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(
XLS_VLOG(3) << "IO to receive:";
XLS_VLOG_LINES(3, relation_to_string(io_to_receive));

// A relation similar to `io_to_receive`, except that only the earliest
// effectful nodes are included. For example, if `io_to_receive` contains
// three keys `A`, `B`, and `C`, and `C` is transitively token-dependent
// on `B`, then `minimal_io_to_receive` will only contain `A` and `B`.
// A relation similar to `io_to_receive`, except that receives are only
// included at the earliest points where they have an effect. For example, if
// `C` is token-dependent on both `A` and `B`, and `io_to_receive` contains
// all of `A`, `B`, and `C`, with
//
// - `io_to_receive[A]` containing `recv1`,
// - `io_to_receive[B]` containing `recv2`,
// - `io_to_receive[C]` containing `recv1`, `recv2`, and `recv3`,
//
// then `minimal_io_to_receive[C]` will only include `recv3`.
NodeRelation minimal_io_to_receive = io_to_receive;
for (const auto& [io, receive] : io_to_receive) {
for (const auto& [io, receives] : io_to_receive) {
for (Node* downstream_of_io : token_deps_closure.at(io)) {
if (downstream_of_io != io) {
minimal_io_to_receive.erase(downstream_of_io);
if (downstream_of_io == io) {
continue;
}

auto it = minimal_io_to_receive.find(downstream_of_io);
if (it == minimal_io_to_receive.end()) {
continue;
}
absl::flat_hash_set<Node*>& downstream_receives = it->second;
for (Node* receive : receives) {
downstream_receives.erase(receive);
}
if (downstream_receives.empty()) {
minimal_io_to_receive.erase(it);
}
}
}
Expand All @@ -165,7 +188,7 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(

bool changed = false;

// Before touching the IR create a determistic sort of the keys of the
// Before touching the IR create a deterministic sort of the keys of the
// relation.
std::vector<Node*> minimal_io_to_receive_keys;
minimal_io_to_receive_keys.reserve(minimal_io_to_receive.size());
Expand Down
38 changes: 38 additions & 0 deletions xls/passes/token_dependency_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "xls/passes/token_dependency_pass.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -141,6 +143,42 @@ TEST_F(TokenDependencyPassTest, DependentSends) {
m::TupleIndex()));
}

TEST_F(TokenDependencyPassTest, DependentSendsMultipleReceives) {
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> p, ParsePackage(R"(
package test_module
chan test_channel(
bits[32], id=0, kind=streaming, ops=send_receive,
flow_control=ready_valid, metadata="""""")
top proc main(__token: token, __state: (), init={()}) {
receive.1: (token, bits[32]) = receive(__token, channel=test_channel)
tuple_index.2: token = tuple_index(receive.1, index=0)
tuple_index.3: bits[32] = tuple_index(receive.1, index=1)
receive.4: (token, bits[32]) = receive(__token, channel=test_channel)
tuple_index.5: token = tuple_index(receive.4, index=0)
tuple_index.6: bits[32] = tuple_index(receive.4, index=1)
send.7: token = send(__token, tuple_index.3, channel=test_channel)
add.8: bits[32] = add(tuple_index.3, tuple_index.6)
send.9: token = send(send.7, add.8, channel=test_channel)
after_all.10: token = after_all(send.7, send.9, tuple_index.2, tuple_index.5)
tuple.11: () = tuple()
next (after_all.10, tuple.11)
}
)"));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc());
EXPECT_THAT(Run(proc), IsOkAndHolds(true));
EXPECT_THAT(
proc->NextToken(),
m::AfterAll(
m::Send(
m::AfterAll(m::TupleIndex(m::Receive(), 0), proc->TokenParam()),
m::TupleIndex(m::Receive(), 1)),
m::Send(m::AfterAll(m::TupleIndex(m::Receive(), 0), m::Send()),
m::Add()),
m::TupleIndex(m::Receive(), 0), m::TupleIndex(m::Receive(), 0)));
}

TEST_F(TokenDependencyPassTest, SideEffectingNontokenOps) {
// Regression test for https://github.com/google/xls/issues/776
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> p, ParsePackage(R"(
Expand Down

0 comments on commit 59b6d06

Please sign in to comment.