Skip to content

Commit 27c3660

Browse files
meheffernancopybara-github
authored andcommitted
Add new-style proc support to the interpreter and jit.
A previous change switched the interpreter and jit over to using an elaboration (instances) instead of the IR directly. However, only old-style procs were supported. This change adds supports for new-style procs. Requires passing instance-specific information to the JITted function. This replaces the unused "user data" argument. This also required a couple fixes to elaboration which wasn't properly distinguishing between channel names and channel reference names. PiperOrigin-RevId: 597678974
1 parent 95f2d70 commit 27c3660

30 files changed

+900
-388
lines changed

xls/interpreter/BUILD

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ cc_library(
386386
":proc_evaluator",
387387
":proc_runtime",
388388
"@com_google_absl//absl/container:flat_hash_map",
389-
"@com_google_absl//absl/memory",
390389
"@com_google_absl//absl/status:statusor",
391390
"@com_google_absl//absl/strings:str_format",
392391
"//xls/common/logging",
@@ -401,13 +400,20 @@ cc_test(
401400
name = "serial_proc_runtime_test",
402401
srcs = ["serial_proc_runtime_test.cc"],
403402
deps = [
403+
":channel_queue",
404404
":interpreter_proc_runtime",
405+
":proc_evaluator",
405406
":proc_interpreter",
407+
":proc_runtime",
406408
":proc_runtime_test_base",
407409
":serial_proc_runtime",
410+
"@com_google_absl//absl/status:statusor",
408411
"//xls/common:xls_gunit",
409412
"//xls/common:xls_gunit_main",
413+
"//xls/common/status:status_macros",
410414
"//xls/ir",
415+
"//xls/ir:elaboration",
416+
"//xls/ir:value",
411417
"//xls/jit:jit_channel_queue",
412418
"//xls/jit:jit_proc_runtime",
413419
"//xls/jit:proc_jit",
@@ -422,14 +428,16 @@ cc_library(
422428
deps = [
423429
":channel_queue",
424430
":proc_runtime",
431+
"@com_google_absl//absl/container:flat_hash_map",
425432
"@com_google_absl//absl/status",
426433
"@com_google_absl//absl/status:statusor",
427-
"@com_google_absl//absl/strings",
428434
"//xls/common/status:matchers",
435+
"//xls/common/status:status_macros",
429436
"//xls/ir",
430437
"//xls/ir:bits",
431438
"//xls/ir:channel",
432439
"//xls/ir:channel_ops",
440+
"//xls/ir:elaboration",
433441
"//xls/ir:function_builder",
434442
"//xls/ir:ir_parser",
435443
"//xls/ir:ir_test_base",

xls/interpreter/interpreter_proc_runtime.cc

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,33 @@
2929
#include "xls/ir/value.h"
3030

3131
namespace xls {
32+
namespace {
3233

33-
absl::StatusOr<std::unique_ptr<SerialProcRuntime>>
34-
CreateInterpreterSerialProcRuntime(Package* package) {
35-
// TODO(https://github.com/google/xls/issues/869): Support new-style procs.
36-
XLS_ASSIGN_OR_RETURN(Elaboration elaboration,
37-
Elaboration::ElaborateOldStylePackage(package));
38-
34+
absl::StatusOr<std::unique_ptr<SerialProcRuntime>> CreateRuntime(
35+
Elaboration elaboration) {
3936
// Create a queue manager for the queues. This factory verifies that there an
4037
// receive only queue for every receive only channel.
4138
XLS_ASSIGN_OR_RETURN(std::unique_ptr<ChannelQueueManager> queue_manager,
4239
ChannelQueueManager::Create(std::move(elaboration)));
4340

4441
// Create a ProcInterpreter for each Proc.
4542
std::vector<std::unique_ptr<ProcEvaluator>> proc_interpreters;
46-
for (auto& proc : package->procs()) {
43+
for (Proc* proc : queue_manager->elaboration().procs()) {
4744
proc_interpreters.push_back(
48-
std::make_unique<ProcInterpreter>(proc.get(), queue_manager.get()));
45+
std::make_unique<ProcInterpreter>(proc, queue_manager.get()));
4946
}
5047

5148
// Create a runtime.
52-
XLS_ASSIGN_OR_RETURN(
53-
std::unique_ptr<SerialProcRuntime> proc_runtime,
54-
SerialProcRuntime::Create(package, std::move(proc_interpreters),
55-
std::move(queue_manager)));
49+
XLS_ASSIGN_OR_RETURN(std::unique_ptr<SerialProcRuntime> proc_runtime,
50+
SerialProcRuntime::Create(std::move(proc_interpreters),
51+
std::move(queue_manager)));
5652

57-
// Inject initial values into channels.
58-
for (Channel* channel : package->channels()) {
59-
ChannelQueue& queue = proc_runtime->queue_manager().GetQueue(channel);
53+
// Inject initial values into channel queues.
54+
for (ChannelInstance* channel_instance :
55+
proc_runtime->elaboration().channel_instances()) {
56+
Channel* channel = channel_instance->channel;
57+
ChannelQueue& queue =
58+
proc_runtime->queue_manager().GetQueue(channel_instance);
6059
for (const Value& value : channel->initial_values()) {
6160
XLS_RETURN_IF_ERROR(queue.Write(value));
6261
}
@@ -65,4 +64,19 @@ CreateInterpreterSerialProcRuntime(Package* package) {
6564
return std::move(proc_runtime);
6665
}
6766

67+
} // namespace
68+
69+
absl::StatusOr<std::unique_ptr<SerialProcRuntime>>
70+
CreateInterpreterSerialProcRuntime(Package* package) {
71+
XLS_ASSIGN_OR_RETURN(Elaboration elaboration,
72+
Elaboration::ElaborateOldStylePackage(package));
73+
return CreateRuntime(std::move(elaboration));
74+
}
75+
76+
absl::StatusOr<std::unique_ptr<SerialProcRuntime>>
77+
CreateInterpreterSerialProcRuntime(Proc* top) {
78+
XLS_ASSIGN_OR_RETURN(Elaboration elaboration, Elaboration::Elaborate(top));
79+
return CreateRuntime(std::move(elaboration));
80+
}
81+
6882
} // namespace xls

xls/interpreter/interpreter_proc_runtime.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,16 @@
2323

2424
namespace xls {
2525

26-
// Create a SerialProcRuntime composed of ProcInterpreters.
26+
// Create a SerialProcRuntime composed of ProcInterpreters. Supports old-style
27+
// procs.
2728
absl::StatusOr<std::unique_ptr<SerialProcRuntime>>
2829
CreateInterpreterSerialProcRuntime(Package* package);
2930

31+
// Create a SerialProcRuntime composed of ProcInterpreters. Constructed from the
32+
// elaboration of the given proc. Supports new-style procs.
33+
absl::StatusOr<std::unique_ptr<SerialProcRuntime>>
34+
CreateInterpreterSerialProcRuntime(Proc* top);
35+
3036
} // namespace xls
3137

3238
#endif // XLS_INTERPRETER_INTERPRETER_PROC_RUNTIME_H_

xls/interpreter/proc_interpreter.cc

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cstdint>
1818
#include <memory>
1919
#include <optional>
20+
#include <string_view>
2021
#include <utility>
2122
#include <vector>
2223

@@ -47,23 +48,25 @@ namespace {
4748
class ProcIrInterpreter : public IrInterpreter {
4849
public:
4950
// Constructor args:
51+
// proc_instance: the instance of the proc which is being interpreted.
5052
// state: is the value to use for the proc state in the tick being
5153
// interpreted.
5254
// node_values: map from Node to Value for already computed values in this
5355
// tick of the proc. Used for continuations.
5456
// events: events object to record events in (e.g, traces).
5557
// queue_manager: manager for channel queues.
56-
ProcIrInterpreter(absl::Span<const Value> state,
58+
ProcIrInterpreter(ProcInstance* proc_instance, absl::Span<const Value> state,
5759
absl::flat_hash_map<Node*, Value>* node_values,
5860
InterpreterEvents* events,
5961
ChannelQueueManager* queue_manager)
6062
: IrInterpreter(node_values, events),
63+
proc_instance_(proc_instance),
6164
state_(state.begin(), state.end()),
6265
queue_manager_(queue_manager) {}
6366

6467
absl::Status HandleReceive(Receive* receive) override {
65-
XLS_ASSIGN_OR_RETURN(ChannelQueue * queue, queue_manager_->GetQueueByName(
66-
receive->channel_name()));
68+
XLS_ASSIGN_OR_RETURN(ChannelQueue * queue,
69+
GetChannelQueue(receive->channel_name()));
6770

6871
if (receive->predicate().has_value()) {
6972
const Bits& pred = ResolveAsBits(receive->predicate().value());
@@ -97,7 +100,7 @@ class ProcIrInterpreter : public IrInterpreter {
97100

98101
absl::Status HandleSend(Send* send) override {
99102
XLS_ASSIGN_OR_RETURN(ChannelQueue * queue,
100-
queue_manager_->GetQueueByName(send->channel_name()));
103+
GetChannelQueue(send->channel_name()));
101104
if (send->predicate().has_value()) {
102105
const Bits& pred = ResolveAsBits(send->predicate().value());
103106
if (pred.IsZero()) {
@@ -139,6 +142,21 @@ class ProcIrInterpreter : public IrInterpreter {
139142
}
140143

141144
private:
145+
// Get the channel queue for the channel or channel reference of the given
146+
// name.
147+
absl::StatusOr<ChannelQueue*> GetChannelQueue(std::string_view name) {
148+
if (proc_instance_->path().has_value()) {
149+
// New-style proc-scoped channel.
150+
XLS_ASSIGN_OR_RETURN(ChannelInstance * channel_instance,
151+
queue_manager_->elaboration().GetChannelInstance(
152+
name, *proc_instance_->path()));
153+
return &queue_manager_->GetQueue(channel_instance);
154+
}
155+
// Old-style global channel.
156+
return queue_manager_->GetQueueByName(name);
157+
}
158+
159+
ProcInstance* proc_instance_;
142160
std::vector<Value> state_;
143161
ChannelQueueManager* queue_manager_;
144162

@@ -167,8 +185,9 @@ absl::StatusOr<TickResult> ProcInterpreter::Tick(
167185
XLS_RET_CHECK_NE(cont, nullptr) << "ProcInterpreter requires a continuation "
168186
"of type ProcInterpreterContinuation";
169187

170-
ProcIrInterpreter ir_interpreter(cont->GetState(), &cont->GetNodeValues(),
171-
&cont->GetEvents(), queue_manager_);
188+
ProcIrInterpreter ir_interpreter(cont->proc_instance(), cont->GetState(),
189+
&cont->GetNodeValues(), &cont->GetEvents(),
190+
queue_manager_);
172191

173192
// Resume execution at the node indicated in the continuation
174193
// (NodeExecutionIndex).

xls/interpreter/proc_runtime.cc

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#include "xls/interpreter/proc_runtime.h"
1616

17-
#include <algorithm>
1817
#include <cstdint>
1918
#include <memory>
2019
#include <optional>
@@ -42,11 +41,9 @@
4241
namespace xls {
4342

4443
ProcRuntime::ProcRuntime(
45-
Package* package,
4644
absl::flat_hash_map<Proc*, std::unique_ptr<ProcEvaluator>>&& evaluators,
4745
std::unique_ptr<ChannelQueueManager>&& queue_manager)
48-
: package_(package),
49-
queue_manager_(std::move(queue_manager)),
46+
: queue_manager_(std::move(queue_manager)),
5047
evaluators_(std::move(evaluators)) {
5148
for (ProcInstance* instance : elaboration().proc_instances()) {
5249
std::unique_ptr<ProcContinuation> continuation =
@@ -72,43 +69,57 @@ absl::Status ProcRuntime::Tick() {
7269
}
7370

7471
absl::StatusOr<int64_t> ProcRuntime::TickUntilOutput(
75-
absl::flat_hash_map<Channel*, int64_t> output_counts,
72+
const absl::flat_hash_map<Channel*, int64_t>& output_counts,
73+
std::optional<int64_t> max_ticks) {
74+
absl::flat_hash_map<ChannelInstance*, int64_t> instance_output_counts;
75+
for (const auto& [channel, count] : output_counts) {
76+
XLS_ASSIGN_OR_RETURN(ChannelInstance * channel_instance,
77+
elaboration().GetUniqueInstance(channel));
78+
instance_output_counts[channel_instance] = count;
79+
}
80+
return TickUntilOutput(instance_output_counts, max_ticks);
81+
}
82+
83+
absl::StatusOr<int64_t> ProcRuntime::TickUntilOutput(
84+
const absl::flat_hash_map<ChannelInstance*, int64_t>& output_counts,
7685
std::optional<int64_t> max_ticks) {
7786
XLS_VLOG(3) << absl::StreamFormat("TickUntilOutput on package %s",
78-
package_->name());
87+
package()->name());
7988
// Create a deterministically sorted vector of the output channels for
8089
// deterministic behavior and error messages.
81-
std::vector<Channel*> output_channels;
82-
for (auto [channel, _] : output_counts) {
83-
output_channels.push_back(channel);
90+
std::vector<ChannelInstance*> output_channels;
91+
for (ChannelInstance* channel_instance : elaboration().channel_instances()) {
92+
if (output_counts.contains(channel_instance)) {
93+
output_channels.push_back(channel_instance);
94+
}
8495
}
85-
std::sort(output_channels.begin(), output_channels.end(),
86-
[](Channel* a, Channel* b) { return a->name() < b->name(); });
8796

8897
if (XLS_VLOG_IS_ON(3)) {
89-
XLS_VLOG(3) << "Expected outputs produced for each channel:";
90-
for (Channel* channel : output_channels) {
91-
XLS_VLOG(3) << absl::StreamFormat(" %s : %d", channel->name(),
92-
output_counts.at(channel));
98+
XLS_VLOG(3) << "Expected outputs produced for each channel instance:";
99+
for (ChannelInstance* channel_instance : output_channels) {
100+
XLS_VLOG(3) << absl::StreamFormat(" %s : %d",
101+
channel_instance->ToString(),
102+
output_counts.at(channel_instance));
93103
}
94104
}
95105

96-
for (Channel* channel : output_channels) {
106+
for (ChannelInstance* channel_instance : output_channels) {
107+
Channel* channel = channel_instance->channel;
97108
if (channel->supported_ops() != ChannelOps::kSendOnly) {
98109
return absl::InvalidArgumentError(absl::StrFormat(
99110
"Channel `%s` is not a send-only channel", channel->name()));
100111
}
101112
if (channel->kind() == ChannelKind::kSingleValue &&
102-
output_counts.at(channel) > 1) {
113+
output_counts.at(channel_instance) > 1) {
103114
return absl::InvalidArgumentError(
104115
absl::StrFormat("Channel `%s` is single-value, expected number of "
105116
"elements must be one or less, is %d",
106-
channel->name(), output_counts.at(channel)));
117+
channel->name(), output_counts.at(channel_instance)));
107118
}
108119
}
109120
int64_t ticks = 0;
110121
auto needs_more_output = [&]() {
111-
for (Channel* ch : output_channels) {
122+
for (ChannelInstance* ch : output_channels) {
112123
if (queue_manager().GetQueue(ch).GetSize() < output_counts.at(ch)) {
113124
return true;
114125
}
@@ -131,7 +142,7 @@ absl::StatusOr<int64_t> ProcRuntime::TickUntilOutput(
131142
absl::StatusOr<int64_t> ProcRuntime::TickUntilBlocked(
132143
std::optional<int64_t> max_ticks) {
133144
XLS_VLOG(3) << absl::StreamFormat("TickUntilBlocked on package %s",
134-
package_->name());
145+
package()->name());
135146
int64_t ticks = 0;
136147
while (!max_ticks.has_value() || ticks < max_ticks.value()) {
137148
XLS_ASSIGN_OR_RETURN(NetworkTickResult result, TickInternal());

xls/interpreter/proc_runtime.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ namespace xls {
3737
class ProcRuntime {
3838
public:
3939
ProcRuntime(
40-
Package* package,
4140
absl::flat_hash_map<Proc*, std::unique_ptr<ProcEvaluator>>&& evaluators,
4241
std::unique_ptr<ChannelQueueManager>&& queue_manager);
4342

@@ -53,14 +52,17 @@ class ProcRuntime {
5352
// error if no progress can be made due to a deadlock.
5453
absl::Status Tick();
5554

56-
// Tick the proc network until some output channels have produced at least a
57-
// specified number of outputs as indicated by `output_counts`.
58-
// `output_counts` must only contain output channels and need not contain all
59-
// output channels. Returns the number of ticks executed before the conditions
60-
// were met. `max_ticks` is the maximum number of ticks of the proc network
61-
// before returning an error.
55+
// Tick the proc network until some output channels (channel instances) have
56+
// produced at least a specified number of outputs as indicated by
57+
// `output_counts`. `output_counts` must only contain output channels and need
58+
// not contain all output channels. Returns the number of ticks executed
59+
// before the conditions were met. `max_ticks` is the maximum number of ticks
60+
// of the proc network before returning an error.
61+
absl::StatusOr<int64_t> TickUntilOutput(
62+
const absl::flat_hash_map<Channel*, int64_t>& output_counts,
63+
std::optional<int64_t> max_ticks = std::nullopt);
6264
absl::StatusOr<int64_t> TickUntilOutput(
63-
absl::flat_hash_map<Channel*, int64_t> output_counts,
65+
const absl::flat_hash_map<ChannelInstance*, int64_t>& output_counts,
6466
std::optional<int64_t> max_ticks = std::nullopt);
6567

6668
// Tick until all procs with IO (send or receive nodes) are blocked on receive
@@ -111,6 +113,8 @@ class ProcRuntime {
111113
return queue_manager_->elaboration();
112114
}
113115

116+
Package* package() const { return elaboration().package(); }
117+
114118
protected:
115119
// Execute (up to) a single iteration of every proc in the package.
116120
struct NetworkTickResult {
@@ -124,7 +128,6 @@ class ProcRuntime {
124128
};
125129
virtual absl::StatusOr<NetworkTickResult> TickInternal() = 0;
126130

127-
Package* package_;
128131
std::unique_ptr<ChannelQueueManager> queue_manager_;
129132
absl::flat_hash_map<Proc*, std::unique_ptr<ProcEvaluator>> evaluators_;
130133
absl::flat_hash_map<ProcInstance*, std::unique_ptr<ProcContinuation>>

0 commit comments

Comments
 (0)