Skip to content

Commit 245ca1f

Browse files
James Molloycopybara-github
authored andcommitted
[xls][mlir] Add extern_sproc and extern_eproc ops
These ops allow to reference procedures external to the current module. Support has been added in -elaborate-procs to elaborate from extern_sproc to extern_eproc. No support has been added for linking modules and "resolving" the externs. PiperOrigin-RevId: 703196503
1 parent fcff834 commit 245ca1f

File tree

7 files changed

+215
-10
lines changed

7 files changed

+215
-10
lines changed

xls/contrib/mlir/IR/assembly_format.cc

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "xls/contrib/mlir/IR/assembly_format.h"
1616

1717
#include <cassert>
18+
#include <string>
1819

1920
#include "llvm/include/llvm/ADT/STLExtras.h"
2021
#include "llvm/include/llvm/Support/LogicalResult.h"
@@ -190,7 +191,7 @@ ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
190191
}
191192
if (failed(parser.parseOptionalRParen())) {
192193
if (failed(parser.parseCommaSeparatedList([&]() {
193-
FlatSymbolRefAttr global, local;
194+
Attribute global, local;
194195
if (parser.parseAttribute(local) || parser.parseKeyword("as") ||
195196
parser.parseAttribute(global)) {
196197
return failure();
@@ -210,4 +211,47 @@ ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
210211
return success();
211212
}
212213

214+
void printChannelNamesAndTypes(mlir::AsmPrinter& p, Operation*,
215+
ArrayAttr channelNames, ArrayAttr channelTypes) {
216+
p << "(";
217+
llvm::interleaveComma(llvm::zip(channelNames, channelTypes), p.getStream(),
218+
[&](auto nameType) {
219+
auto name = cast<StringAttr>(std::get<0>(nameType));
220+
p << name.getValue() << ": ";
221+
p.printAttribute(std::get<1>(nameType));
222+
});
223+
p << ")";
224+
}
225+
ParseResult parseChannelNamesAndTypes(mlir::AsmParser& parser,
226+
ArrayAttr& channelNames,
227+
ArrayAttr& channelTypes) {
228+
SmallVector<Attribute> names;
229+
SmallVector<Attribute> types;
230+
231+
if (parser.parseLParen()) {
232+
return failure();
233+
}
234+
if (failed(parser.parseOptionalRParen())) {
235+
if (failed(parser.parseCommaSeparatedList([&]() {
236+
std::string name;
237+
TypeAttr type;
238+
if (parser.parseKeywordOrString(&name) || parser.parseColon() ||
239+
parser.parseAttribute(type)) {
240+
return failure();
241+
}
242+
names.push_back(StringAttr::get(parser.getContext(), name));
243+
types.push_back(type);
244+
return success();
245+
}))) {
246+
return failure();
247+
}
248+
if (failed(parser.parseRParen())) {
249+
return failure();
250+
}
251+
}
252+
channelNames = ArrayAttr::get(parser.getContext(), names);
253+
channelTypes = ArrayAttr::get(parser.getContext(), types);
254+
return success();
255+
}
256+
213257
} // namespace mlir::xls

xls/contrib/mlir/IR/assembly_format.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ void printZippedSymbols(mlir::AsmPrinter& p, Operation* op,
9898
ArrayAttr globalRefs, ArrayAttr localRefs);
9999
ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
100100
ArrayAttr& localRefs);
101+
void printChannelNamesAndTypes(mlir::AsmPrinter& p, Operation* op,
102+
ArrayAttr channelNames, ArrayAttr channelTypes);
103+
ParseResult parseChannelNamesAndTypes(mlir::AsmParser& parser,
104+
ArrayAttr& channelNames,
105+
ArrayAttr& channelTypes);
101106
} // namespace mlir::xls
102107

103108
#endif // GDM_HW_MLIR_XLS_IR_ASSEMBLY_FORMAT_H_

xls/contrib/mlir/IR/xls_ops.cc

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -582,18 +582,41 @@ SprocOp SpawnOp::resolveCallee(SymbolTableCollection* symbolTable) {
582582
getCallee());
583583
}
584584

585+
ExternSprocOp SpawnOp::resolveExternCallee(SymbolTableCollection* symbolTable) {
586+
if (symbolTable) {
587+
return symbolTable->lookupNearestSymbolFrom<ExternSprocOp>(getOperation(),
588+
getCallee());
589+
}
590+
return SymbolTable::lookupNearestSymbolFrom<ExternSprocOp>(getOperation(),
591+
getCallee());
592+
}
593+
594+
namespace {
595+
template <typename T>
596+
LogicalResult verifySpawnOpSymbolUses(SpawnOp op, T callee) {
597+
if (callee.getChannelArgumentTypes().size() != op.getChannels().size()) {
598+
return op.emitOpError()
599+
<< "callee expects " << callee.getChannelArgumentTypes().size()
600+
<< " channels but spawn has " << op.getChannels().size()
601+
<< " arguments";
602+
}
603+
return success();
604+
}
605+
} // namespace
606+
585607
LogicalResult SpawnOp::verifySymbolUses(SymbolTableCollection& symbolTable) {
586-
SprocOp callee = resolveCallee(&symbolTable);
608+
Operation* callee =
609+
symbolTable.lookupNearestSymbolFrom(getOperation(), getCallee());
587610
if (!callee) {
588611
return emitOpError() << "callee not found: " << getCallee();
589612
}
590-
if (callee.getChannelArguments().size() != getChannels().size()) {
591-
return emitOpError() << "callee expects "
592-
<< callee.getChannelArguments().size()
593-
<< " channels but spawn has " << getChannels().size()
594-
<< " arguments";
613+
if (auto sproc = dyn_cast<SprocOp>(callee)) {
614+
return verifySpawnOpSymbolUses(*this, sproc);
595615
}
596-
return success();
616+
if (auto extern_sproc = dyn_cast<ExternSprocOp>(callee)) {
617+
return verifySpawnOpSymbolUses(*this, extern_sproc);
618+
}
619+
return emitOpError() << "callee is not a SprocOp or ExternSprocOp";
597620
}
598621

599622
namespace {

xls/contrib/mlir/IR/xls_ops.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,30 @@ def Xls_InstantiateEprocOp : Xls_Op<"instantiate_eproc", [DeclareOpInterfaceMeth
14081408
}];
14091409
}
14101410

1411+
def Xls_InstantiateExternEprocOp : Xls_Op<"instantiate_extern_eproc", []> {
1412+
let summary = "Binds an externally defined eproc";
1413+
let description = [{
1414+
Binds channels to an eproc that is defined externally to this module.
1415+
1416+
This functions similarly to the `xls.instantiate_eproc` op, except that:
1417+
1) The eproc definition is not available in this module and so it is
1418+
referred to by string, not a symbol.
1419+
2) Instead of binding global channels to local channels, it binds global
1420+
channels to the boundary channel names of the target eproc.
1421+
1422+
The target eproc is referred to by an opaque string. The interpretation of
1423+
this string is left to the user.
1424+
}];
1425+
let arguments = (ins
1426+
StrAttr:$eproc_name,
1427+
FlatSymbolRefArrayAttr:$global_channels,
1428+
StrArrayAttr:$boundary_channel_names
1429+
);
1430+
let assemblyFormat = [{
1431+
$eproc_name custom<ZippedSymbols>($global_channels, $boundary_channel_names) attr-dict
1432+
}];
1433+
}
1434+
14111435
//===----------------------------------------------------------------------===//
14121436
// Structured procs
14131437
//===----------------------------------------------------------------------===//
@@ -1584,6 +1608,10 @@ def Xls_SprocOp : Xls_Op<"sproc", [
15841608
}
15851609
return index;
15861610
}
1611+
1612+
::mlir::TypeRange getChannelArgumentTypes() {
1613+
return getChannelArguments().getTypes();
1614+
}
15871615
}];
15881616
}
15891617

@@ -1602,6 +1630,45 @@ def Xls_SpawnOp : Xls_Op<"spawn", [
16021630
}];
16031631
let extraClassDeclaration = [{
16041632
SprocOp resolveCallee(::mlir::SymbolTableCollection* symbolTable = nullptr);
1633+
ExternSprocOp resolveExternCallee(::mlir::SymbolTableCollection* symbolTable = nullptr);
1634+
}];
1635+
}
1636+
1637+
def Xls_ExternSprocOp : Xls_Op<"extern_sproc", [
1638+
Symbol,
1639+
CallableOpInterface
1640+
]> {
1641+
let summary = "extern sproc";
1642+
let description = [{
1643+
Declares an sproc that is external to the current module. The sproc is
1644+
spawnable by sprocs in the current module.
1645+
1646+
The `boundary_channel_names` attribute is used to name each argument or
1647+
result channel. These correspond to the `boundary_channel_names` on the
1648+
target sproc (wherever it is defined). The `channel_types` attribute is used
1649+
to specify the types of the channels.
1650+
1651+
An `spawn` of a `extern_sproc` is lowered to an `instantiate_extern_eproc`
1652+
op.
1653+
}];
1654+
let arguments = (ins
1655+
SymbolNameAttr:$sym_name,
1656+
StrArrayAttr:$boundary_channel_names,
1657+
TypeArrayAttr:$channel_argument_types
1658+
);
1659+
let assemblyFormat = [{
1660+
$sym_name custom<ChannelNamesAndTypes>($boundary_channel_names, $channel_argument_types) attr-dict
1661+
}];
1662+
let extraClassDeclaration = [{
1663+
::mlir::Region* getCallableRegion() {
1664+
return nullptr;
1665+
}
1666+
::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
1667+
return {};
1668+
}
1669+
::llvm::ArrayRef<::mlir::Type> getResultTypes() {
1670+
return {};
1671+
}
16051672
}];
16061673
}
16071674

xls/contrib/mlir/testdata/ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,14 @@ func.func @trace_cond(%arg0: i32, %tkn: !xls.token, %cond: i1) -> !xls.token {
499499
return %0 : !xls.token
500500
}
501501

502+
// CHECK-LABEL: xls.instantiate_extern_eproc "external" ("arg0" as @c1, "result0" as @c2)
503+
xls.chan @c1 : i32
504+
xls.chan @c2 : i32
505+
xls.instantiate_extern_eproc "external" ("arg0" as @c1, "result0" as @c2)
506+
507+
// CHECK-LABEL: xls.extern_sproc @external_sproc (arg0: !xls.schan<i32, in>, result0: !xls.schan<i32, out>)
508+
xls.extern_sproc @external_sproc (arg0: !xls.schan<i32, in>, result0: !xls.schan<i32, out>)
509+
502510
// -----
503511

504512
// expected-error@+1 {{yielded state type does not match carried state type ('tuple<i7>' vs 'tuple<i32>'}}

xls/contrib/mlir/testdata/proc_elaboration.mlir

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: xls/contrib/mlir/xls_opt -elaborate-procs %s 2>&1 | FileCheck %s
1+
// RUN: xls/contrib/mlir/xls_opt -elaborate-procs -split-input-file %s 2>&1 | FileCheck %s
22
// CHECK: xls.chan @req : i32
33
// CHECK-NEXT: xls.chan @resp : i32
44
// CHECK-NEXT: xls.chan @rom1_req : i32
@@ -89,3 +89,27 @@ xls.sproc @rom(%req: !xls.schan<i32, in>, %resp: !xls.schan<i32, out>) top attri
8989
xls.yield %state : i32
9090
}
9191
}
92+
93+
// -----
94+
95+
// CHECK: xls.chan @req : i32
96+
// CHECK-NEXT: xls.chan @resp : i32
97+
// CHECK-NEXT: xls.instantiate_extern_eproc "external" ("req" as @req, "resp" as @resp)
98+
xls.extern_sproc @external(req: !xls.schan<i32, in>, resp: !xls.schan<i32, out>)
99+
100+
// CHECK: xls.chan @main_arg0 : i32
101+
// CHECK-NEXT: xls.chan @main_arg1 : i32
102+
// CHECK: xls.instantiate_eproc @main_0 (@main_arg0 as @req, @main_arg1 as @resp)
103+
xls.sproc @main() top {
104+
spawns {
105+
%req_out, %req_in = xls.schan<i32>("req")
106+
%resp_out, %resp_in = xls.schan<i32>("resp")
107+
xls.spawn @external(%req_in, %resp_out) : !xls.schan<i32, in>, !xls.schan<i32, out>
108+
xls.yield %req_out, %resp_in : !xls.schan<i32, out>, !xls.schan<i32, in>
109+
}
110+
next (%req: !xls.schan<i32, out>, %resp: !xls.schan<i32, in>, %state: i32) zeroinitializer {
111+
%tok = xls.after_all : !xls.token
112+
%tok2 = xls.ssend %tok, %state, %req : (!xls.token, i32, !xls.schan<i32, out>) -> !xls.token
113+
xls.yield %state : i32
114+
}
115+
}

xls/contrib/mlir/transforms/proc_elaboration.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ class ElaborationContext
150150
builder.getArrayAttr(localSymbols));
151151
}
152152

153+
void instantiateExternEproc(ExternSprocOp externSproc,
154+
ArrayRef<value_type> globalChannels) {
155+
auto flatchan = [](value_type chan) -> Attribute {
156+
return FlatSymbolRefAttr::get(chan.getSymNameAttr());
157+
};
158+
SmallVector<Attribute> globalSymbols =
159+
llvm::map_to_vector(globalChannels, flatchan);
160+
StringAttr symName = externSproc.getSymNameAttr();
161+
builder.create<InstantiateExternEprocOp>(
162+
externSproc.getLoc(), symName, builder.getArrayAttr(globalSymbols),
163+
externSproc.getBoundaryChannelNames());
164+
}
165+
153166
SymbolTable& getSymbolTable() { return symbolTable; }
154167

155168
private:
@@ -189,6 +202,10 @@ class ElaborationInterpreter
189202
}
190203

191204
absl::Status Interpret(SpawnOp op, ElaborationContext& ctx) {
205+
ExternSprocOp externSproc = op.resolveExternCallee();
206+
if (externSproc) {
207+
return InterpretSpawnOfExtern(op, externSproc, ctx);
208+
}
192209
SprocOp sproc = op.resolveCallee();
193210
if (!sproc) {
194211
return absl::InvalidArgumentError("failed to resolve callee");
@@ -211,6 +228,19 @@ class ElaborationInterpreter
211228
return absl::OkStatus();
212229
}
213230

231+
absl::Status InterpretSpawnOfExtern(SpawnOp op, ExternSprocOp externSproc,
232+
ElaborationContext& ctx) {
233+
XLS_ASSIGN_OR_RETURN(auto arguments, ctx.Get(op.getChannels()));
234+
if (arguments.size() != externSproc.getChannelArgumentTypes().size()) {
235+
return absl::InternalError(absl::StrFormat(
236+
"Call to %s requires %d arguments but got %d",
237+
op.getCallee().getLeafReference().str(),
238+
externSproc.getChannelArgumentTypes().size(), arguments.size()));
239+
}
240+
ctx.instantiateExternEproc(externSproc, arguments);
241+
return absl::OkStatus();
242+
}
243+
214244
absl::Status Interpret(SprocOp op, ElaborationContext& ctx,
215245
ArrayRef<ChanOp> boundaryChannels = {}) {
216246
XLS_ASSIGN_OR_RETURN(auto results,
@@ -259,7 +289,11 @@ void ProcElaborationPass::runOnOperation() {
259289
sproc.emitError() << "failed to elaborate: " << result.message();
260290
}
261291
}
262-
module.walk([&](SprocOp sproc) { sproc.erase(); });
292+
module.walk([&](Operation* op) {
293+
if (isa<SprocOp, ExternSprocOp>(op)) {
294+
op->erase();
295+
}
296+
});
263297
}
264298

265299
} // namespace mlir::xls

0 commit comments

Comments
 (0)