diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 2efb6021d0..9818295c49 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -422,6 +422,7 @@ cc_library( "//xls/data_structures:union_find", "//xls/ir", "//xls/ir:channel", + "//xls/ir:elaboration", "//xls/ir:node_util", "//xls/ir:op", ], @@ -2081,11 +2082,14 @@ cc_test( "//xls/common:xls_gunit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/common/status:status_macros", "//xls/ir", "//xls/ir:bits", + "//xls/ir:channel", "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/ir:value", ], ) diff --git a/xls/passes/dfe_pass.cc b/xls/passes/dfe_pass.cc index 7299b08836..160f74bec6 100644 --- a/xls/passes/dfe_pass.cc +++ b/xls/passes/dfe_pass.cc @@ -18,19 +18,20 @@ #include #include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "xls/common/logging/logging.h" #include "xls/common/status/status_macros.h" #include "xls/data_structures/union_find.h" #include "xls/ir/block.h" #include "xls/ir/channel.h" +#include "xls/ir/elaboration.h" #include "xls/ir/function_base.h" +#include "xls/ir/instantiation.h" #include "xls/ir/node_util.h" #include "xls/ir/op.h" #include "xls/ir/package.h" @@ -79,63 +80,104 @@ void MarkReachedFunctions(FunctionBase* func, } } } -} // namespace -// Starting from the return_value(s), DFS over all nodes. Unvisited -// nodes, or parameters, are dead. -absl::StatusOr DeadFunctionEliminationPass::RunInternal( - Package* p, const OptimizationPassOptions& options, - PassResults* results) const { - std::optional top = p->GetTop(); - if (!top.has_value()) { - return false; +// Data structure describing the liveness of global constructs in a package. +struct FunctionBaseLiveness { + // The live roots of the package. This does not include FunctionBases which + // are live because they are called/instantiated from other FunctionBases. + std::vector live_roots; + + // Set of the live global channels. Only set for old-style procs. + absl::flat_hash_set live_global_channels; +}; + +absl::StatusOr LivenessFromTopProc(Proc* top) { + if (top->is_new_style_proc()) { + XLS_ASSIGN_OR_RETURN(Elaboration elab, Elaboration::Elaborate(top)); + return FunctionBaseLiveness{.live_roots = std::vector( + elab.procs().begin(), elab.procs().end()), + .live_global_channels = {}}; } - // Mapping from proc->channel, where channel is a representative value - // for all the channel names in the UnionFind. - absl::flat_hash_map representative_channels; + Package* p = top->package(); + + // Mapping from proc to channel, where channel is a representative value for + // all the channel names in the UnionFind. If the proc uses no channels then + // the value will be nullopt. + absl::flat_hash_map> + representative_channels; representative_channels.reserve(p->procs().size()); // Channels in the same proc will be union'd. - UnionFind channel_union; - for (std::unique_ptr& proc : p->procs()) { - std::optional representative_proc_channel; + UnionFind channel_union; + for (const std::unique_ptr& proc : p->procs()) { + std::optional representative_proc_channel; for (Node* node : proc->nodes()) { if (IsChannelNode(node)) { - std::string channel; - if (node->Is()) { - channel = node->As()->channel_name(); - } else if (node->Is()) { - channel = node->As()->channel_name(); - } else { - return absl::NotFoundError(absl::StrFormat( - "No channel associated with node %s", node->GetName())); - } - channel_union.Insert(channel); + XLS_ASSIGN_OR_RETURN(Channel * channel, GetChannelUsedByNode(node)); + channel_union.Insert(channel->name()); if (representative_proc_channel.has_value()) { - channel_union.Union(representative_proc_channel.value(), channel); + channel_union.Union(representative_proc_channel.value(), + channel->name()); } else { - representative_proc_channel = channel; - representative_channels.insert({proc.get(), channel}); + representative_proc_channel = channel->name(); } } } + representative_channels[proc.get()] = representative_proc_channel; } - absl::flat_hash_set reached; - MarkReachedFunctions(top.value(), &reached); - std::optional top_proc_representative_channel; - if ((*top)->IsProc()) { - auto itr = representative_channels.find(top.value()->AsProcOrDie()); - if (itr != representative_channels.end()) { - top_proc_representative_channel = channel_union.Find(itr->second); - for (auto [proc, representative_channel] : representative_channels) { - if (channel_union.Find(representative_channel) == - *top_proc_representative_channel) { - MarkReachedFunctions(proc, &reached); - } + FunctionBaseLiveness liveness; + + // Add procs to the live set if they are connnected to `top` via channels. + for (const std::unique_ptr& proc : p->procs()) { + if (proc.get() == top) { + liveness.live_roots.push_back(proc.get()); + continue; + } + if (representative_channels.at(top).has_value() && + representative_channels.at(proc.get()) && + channel_union.Find(representative_channels.at(top).value()) == + channel_union.Find( + representative_channels.at(proc.get()).value())) { + liveness.live_roots.push_back(proc.get()); + } + } + + // Add channels to the live set if they are connnected to `top`. + if (representative_channels.at(top).has_value()) { + for (Channel* channel : p->channels()) { + if (channel_union.Find(channel->name()) == + channel_union.Find(representative_channels.at(top).value())) { + liveness.live_global_channels.insert(channel); } } } + return liveness; +} + +} // namespace + +// Starting from the return_value(s), DFS over all nodes. Unvisited +// nodes, or parameters, are dead. +absl::StatusOr DeadFunctionEliminationPass::RunInternal( + Package* p, const OptimizationPassOptions& options, + PassResults* results) const { + std::optional top = p->GetTop(); + if (!top.has_value()) { + return false; + } + + FunctionBaseLiveness liveness; + if ((*top)->IsProc()) { + XLS_ASSIGN_OR_RETURN(liveness, LivenessFromTopProc((*top)->AsProcOrDie())); + } else { + liveness.live_roots = {*top}; + } + + absl::flat_hash_set reached; + for (FunctionBase* fb : liveness.live_roots) { + MarkReachedFunctions(fb, &reached); + } // Accumulate a list of FunctionBases to unlink. bool changed = false; @@ -147,19 +189,15 @@ absl::StatusOr DeadFunctionEliminationPass::RunInternal( } } - // Find any channels which are only used by now-removed procs. - std::vector channels_to_remove; + // Remove dead channels. + std::vector channels_to_remove; channels_to_remove.reserve(p->channels().size()); for (Channel* channel : p->channels()) { - if (!top_proc_representative_channel.has_value() || - channel_union.Find(std::string{channel->name()}) != - *top_proc_representative_channel) { - channels_to_remove.push_back(std::string{channel->name()}); + if (!liveness.live_global_channels.contains(channel)) { + channels_to_remove.push_back(channel); } } - // Now remove any channels which are only used by now-removed procs. - for (const std::string& channel_name : channels_to_remove) { - XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name)); + for (Channel* channel : channels_to_remove) { XLS_VLOG(2) << "Removing channel: " << channel->name(); XLS_RETURN_IF_ERROR(p->RemoveChannel(channel)); changed = true; diff --git a/xls/passes/dfe_pass.h b/xls/passes/dfe_pass.h index fc71697cfe..3cebfd9494 100644 --- a/xls/passes/dfe_pass.h +++ b/xls/passes/dfe_pass.h @@ -23,9 +23,8 @@ namespace xls { -// class DeadCodeEliminationPass iterates up from a functions result -// nodes and marks all visited node. After that, all unvisited nodes -// are considered dead. +// This pass removes unreachable procs/blocks/functions from the package. The +// pass requires `top` be set in order remove any constructs. class DeadFunctionEliminationPass : public OptimizationPass { public: explicit DeadFunctionEliminationPass() diff --git a/xls/passes/dfe_pass_test.cc b/xls/passes/dfe_pass_test.cc index ff15b8278a..cb40b9ac5e 100644 --- a/xls/passes/dfe_pass_test.cc +++ b/xls/passes/dfe_pass_test.cc @@ -16,18 +16,22 @@ #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/status/statusor.h" #include "xls/common/status/matchers.h" +#include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" +#include "xls/ir/channel.h" #include "xls/ir/function.h" #include "xls/ir/function_base.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/package.h" +#include "xls/ir/value.h" #include "xls/passes/optimization_pass.h" #include "xls/passes/pass_base.h" @@ -55,6 +59,21 @@ class DeadFunctionEliminationPassTest : public IrTestBase { fb.Param("arg", p->GetBitsType(32)); return fb.Build(); } + + absl::StatusOr CreateNewStyleAccumProc(std::string_view proc_name, + Package* package) { + TokenlessProcBuilder pb(NewStyleProc(), proc_name, "tkn", package); + BValue accum = pb.StateElement("accum", Value(UBits(0, 32))); + XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * in_channel, + pb.AddInputChannel("in_ch", package->GetBitsType(32))); + BValue input = pb.Receive(in_channel); + BValue next_accum = pb.Add(accum, input); + XLS_ASSIGN_OR_RETURN( + SendChannelReference * out_channel, + pb.AddOutputChannel("out_ch", package->GetBitsType(32))); + pb.Send(out_channel, next_accum); + return pb.Build({next_accum}); + } }; TEST_F(DeadFunctionEliminationPassTest, NoDeadFunctions) { @@ -430,5 +449,66 @@ proc test_proc3(tkn: token, state:(), init={()}) { EXPECT_THAT(p->channels(), IsEmpty()); } +TEST_F(DeadFunctionEliminationPassTest, SingleNewStyleProc) { + auto p = CreatePackage(); + XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc", p.get()).status()); + + EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc"))); + + EXPECT_THAT(Run(p.get()), IsOkAndHolds(false)); + + EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc"))); +} + +TEST_F(DeadFunctionEliminationPassTest, MultipleNewStyleProcs) { + auto p = CreatePackage(); + XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc1", p.get()).status()); + XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc2", p.get()).status()); + XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status()); + XLS_ASSERT_OK(p->SetTopByName("my_proc2")); + + EXPECT_THAT(p->GetFunctionBases(), + UnorderedElementsAre(m::Proc("my_proc1"), m::Proc("my_proc2"), + m::Proc("my_proc3"))); + + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + + EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc2"))); +} + +TEST_F(DeadFunctionEliminationPassTest, NewStyleProcWithInstantiations) { + auto p = CreatePackage(); + XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc1, + CreateNewStyleAccumProc("my_proc1", p.get())); + XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc2, + CreateNewStyleAccumProc("my_proc2", p.get())); + XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status()); + + TokenlessProcBuilder pb(NewStyleProc(), "top_proc", "tkn", p.get()); + XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences the_channel, + pb.AddChannel("the_channel", p->GetBitsType(32))); + XLS_ASSERT_OK( + pb.InstantiateProc("inst0", my_proc1, + std::vector{the_channel.receive_ref, + the_channel.send_ref})); + XLS_ASSERT_OK( + pb.InstantiateProc("inst1", my_proc2, + std::vector{the_channel.receive_ref, + the_channel.send_ref})); + XLS_ASSERT_OK(pb.Build({}).status()); + + XLS_ASSERT_OK(p->SetTopByName("top_proc")); + + EXPECT_THAT(p->GetFunctionBases(), + UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"), + m::Proc("my_proc2"), m::Proc("my_proc3"))); + + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + + EXPECT_THAT(p->GetFunctionBases(), + UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"), + m::Proc("my_proc2"))); +} + } // namespace } // namespace xls