Skip to content

Commit

Permalink
Support instantiations in block interpreter.
Browse files Browse the repository at this point in the history
Now that we have block elaboration and visitors on those elaborations, we implement support for instantiations in the block interpreter. The idea is that there's a top-level elaborated block visitor that iterates through the elaborated nodes, implements inter-block ops, and then delegates the other ops to a per-block IR interpreter.

These changes also extend to the block evaluator, which needs to elaborate in some places, update how registers are initialized, etc.

PiperOrigin-RevId: 625088110
  • Loading branch information
grebe authored and copybara-github committed Apr 15, 2024
1 parent 5e6bad1 commit 93244c0
Show file tree
Hide file tree
Showing 16 changed files with 806 additions and 92 deletions.
10 changes: 10 additions & 0 deletions xls/interpreter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//xls/codegen:module_signature_cc_proto",
Expand All @@ -47,6 +48,7 @@ cc_library(
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:block_elaboration",
"//xls/ir:events",
"//xls/ir:keyword_args",
"//xls/ir:register",
Expand All @@ -67,13 +69,16 @@ cc_library(
"@com_google_absl//absl/random:distributions",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//xls/codegen:module_signature_cc_proto",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:block_elaboration",
"//xls/ir:elaboration",
"//xls/ir:events",
"//xls/ir:register",
"//xls/ir:value",
Expand Down Expand Up @@ -169,12 +174,17 @@ cc_library(
":block_evaluator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//xls/codegen:module_signature_cc_proto",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:elaboration",
"//xls/ir:format_preference",
"//xls/ir:function_builder",
"//xls/ir:ir_test_base",
Expand Down
71 changes: 56 additions & 15 deletions xls/interpreter/block_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
#include "absl/random/distributions.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
#include "xls/ir/block.h"
#include "xls/ir/block_elaboration.h"
#include "xls/ir/elaboration.h"
#include "xls/ir/events.h"
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
Expand Down Expand Up @@ -179,16 +182,24 @@ absl::StatusOr<std::vector<absl::flat_hash_map<std::string, Value>>>
BlockEvaluator::EvaluateSequentialBlock(
Block* block,
absl::Span<const absl::flat_hash_map<std::string, Value>> inputs) const {
XLS_ASSIGN_OR_RETURN(BlockElaboration elaboration,
BlockElaboration::Elaborate(block));
// Initial register state is zero for all registers.
absl::flat_hash_map<std::string, Value> reg_state;
for (Register* reg : block->GetRegisters()) {
reg_state[reg->name()] = ZeroOfType(reg->type());
for (BlockInstance* inst : elaboration.instances()) {
if (!inst->block().has_value()) {
continue;
}
for (Register* reg : inst->block().value()->GetRegisters()) {
reg_state[absl::StrCat(inst->RegisterPrefix(), reg->name())] =
ZeroOfType(reg->type());
}
}

std::vector<absl::flat_hash_map<std::string, Value>> outputs;
for (const absl::flat_hash_map<std::string, Value>& input_set : inputs) {
XLS_ASSIGN_OR_RETURN(BlockRunResult result,
EvaluateBlock(input_set, reg_state, block));
EvaluateBlock(input_set, reg_state, elaboration));
outputs.push_back(std::move(result.outputs));
reg_state = std::move(result.reg_state);
}
Expand Down Expand Up @@ -394,10 +405,21 @@ BlockEvaluator::EvaluateChannelizedSequentialBlock(
std::minstd_rand random_engine;
random_engine.seed(seed);

XLS_ASSIGN_OR_RETURN(BlockElaboration elaboration,
BlockElaboration::Elaborate(block));

// Initial register state is zero for all registers.
absl::flat_hash_map<std::string, Value> reg_state;
for (Register* reg : block->GetRegisters()) {
reg_state[reg->name()] = ZeroOfType(reg->type());
for (BlockInstance* inst : elaboration.instances()) {
// Instance isn't a BlockInstantiation, must be e.g. a FIFO or FFI
// instantiation. No registers to initialize.
if (!inst->block().has_value()) {
continue;
}
for (Register* reg : inst->block().value()->GetRegisters()) {
reg_state[absl::StrCat(inst->RegisterPrefix(), reg->name())] =
ZeroOfType(reg->type());
}
}

int64_t max_cycle_count = inputs.size();
Expand Down Expand Up @@ -427,7 +449,7 @@ BlockEvaluator::EvaluateChannelizedSequentialBlock(

// Block results
XLS_ASSIGN_OR_RETURN(BlockRunResult result,
EvaluateBlock(input_set, reg_state, block));
EvaluateBlock(input_set, reg_state, elaboration));

// Sources get ready
for (ChannelSource& src : channel_sources) {
Expand Down Expand Up @@ -503,9 +525,10 @@ BlockEvaluator::EvaluateChannelizedSequentialBlockWithUint64(
namespace {
class BaseBlockContinuation final : public BlockContinuation {
public:
BaseBlockContinuation(Block* block, BlockRunResult&& initial_result,
BaseBlockContinuation(BlockElaboration&& block,
BlockRunResult&& initial_result,
const BlockEvaluator& evaluator)
: block_(block),
: elaboration_(std::move(block)),
last_result_(std::move(initial_result)),
evaluator_(evaluator) {}

Expand All @@ -525,7 +548,7 @@ class BaseBlockContinuation final : public BlockContinuation {
const absl::flat_hash_map<std::string, Value>& inputs) final {
XLS_ASSIGN_OR_RETURN(
last_result_,
evaluator_.EvaluateBlock(inputs, last_result_.reg_state, block_));
evaluator_.EvaluateBlock(inputs, last_result_.reg_state, elaboration_));
return absl::OkStatus();
}

Expand All @@ -540,7 +563,7 @@ class BaseBlockContinuation final : public BlockContinuation {
}

private:
Block* block_;
BlockElaboration elaboration_;
BlockRunResult last_result_;
const BlockEvaluator& evaluator_;
};
Expand All @@ -550,18 +573,36 @@ absl::StatusOr<std::unique_ptr<BlockContinuation>>
BlockEvaluator::NewContinuation(
Block* block,
const absl::flat_hash_map<std::string, Value>& initial_registers) const {
return std::make_unique<BaseBlockContinuation>(
block, BlockRunResult{.reg_state = initial_registers}, *this);
XLS_ASSIGN_OR_RETURN(BlockElaboration elaboration,
BlockElaboration::Elaborate(block));
return NewContinuation(std::move(elaboration), initial_registers);
}

absl::StatusOr<std::unique_ptr<BlockContinuation>>
BlockEvaluator::NewContinuation(Block* block) const {
XLS_ASSIGN_OR_RETURN(BlockElaboration elaboration,
BlockElaboration::Elaborate(block));
absl::flat_hash_map<std::string, Value> regs;
regs.reserve(block->GetRegisters().size());
for (const auto reg : block->GetRegisters()) {
regs[reg->name()] = ZeroOfType(reg->type());
for (BlockInstance* inst : elaboration.instances()) {
if (!inst->block().has_value()) {
continue;
}
for (const auto reg : inst->block().value()->GetRegisters()) {
regs[absl::StrCat(inst->RegisterPrefix(), reg->name())] =
ZeroOfType(reg->type());
}
}
return NewContinuation(block, regs);
return NewContinuation(std::move(elaboration), regs);
}

absl::StatusOr<std::unique_ptr<BlockContinuation>>
BlockEvaluator::NewContinuation(
BlockElaboration&& elaboration,
const absl::flat_hash_map<std::string, Value>& initial_registers) const {
return std::make_unique<BaseBlockContinuation>(
std::move(elaboration), BlockRunResult{.reg_state = initial_registers},
*this);
}

} // namespace xls
12 changes: 8 additions & 4 deletions xls/interpreter/block_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "absl/types/span.h"
#include "xls/codegen/module_signature.pb.h"
#include "xls/ir/block.h"
#include "xls/ir/block_elaboration.h"
#include "xls/ir/events.h"
#include "xls/ir/value.h"

Expand Down Expand Up @@ -211,10 +212,9 @@ class BlockEvaluator {
// Create a new block continuation with all registers initialized to the given
// values. This continuation can be used to feed input values in
// cycle-by-cycle.
virtual absl::StatusOr<std::unique_ptr<BlockContinuation>> NewContinuation(
absl::StatusOr<std::unique_ptr<BlockContinuation>> NewContinuation(
Block* block,
const absl::flat_hash_map<std::string, Value>& initial_registers)
const;
const absl::flat_hash_map<std::string, Value>& initial_registers) const;

// Create a new block continuation with all registers initialized to zero
// values. This continuation can be used to feed input values in
Expand All @@ -228,7 +228,7 @@ class BlockEvaluator {
virtual absl::StatusOr<BlockRunResult> EvaluateBlock(
const absl::flat_hash_map<std::string, Value>& inputs,
const absl::flat_hash_map<std::string, Value>& registers,
Block* block) const = 0;
const BlockElaboration& elaboration) const = 0;

// The name of this evaluator for debug purposes.
std::string_view name() const { return name_; }
Expand Down Expand Up @@ -362,6 +362,10 @@ class BlockEvaluator {
}

protected:
virtual absl::StatusOr<std::unique_ptr<BlockContinuation>> NewContinuation(
BlockElaboration&& elaboration,
const absl::flat_hash_map<std::string, Value>& initial_registers) const;

std::string_view name_;
};

Expand Down
Loading

0 comments on commit 93244c0

Please sign in to comment.