Skip to content

Commit

Permalink
[xls][mlir] Add extern_sproc and extern_eproc ops
Browse files Browse the repository at this point in the history
These ops allow to reference procedures external to the current module. Support has
been added in -elaborate-procs to elaborate from extern_sproc to extern_eproc.

No support has been added for linking modules and "resolving" the externs.

PiperOrigin-RevId: 703196503
  • Loading branch information
James Molloy authored and copybara-github committed Dec 5, 2024
1 parent fcff834 commit 245ca1f
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 10 deletions.
46 changes: 45 additions & 1 deletion xls/contrib/mlir/IR/assembly_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "xls/contrib/mlir/IR/assembly_format.h"

#include <cassert>
#include <string>

#include "llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/include/llvm/Support/LogicalResult.h"
Expand Down Expand Up @@ -190,7 +191,7 @@ ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
}
if (failed(parser.parseOptionalRParen())) {
if (failed(parser.parseCommaSeparatedList([&]() {
FlatSymbolRefAttr global, local;
Attribute global, local;
if (parser.parseAttribute(local) || parser.parseKeyword("as") ||
parser.parseAttribute(global)) {
return failure();
Expand All @@ -210,4 +211,47 @@ ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
return success();
}

void printChannelNamesAndTypes(mlir::AsmPrinter& p, Operation*,
ArrayAttr channelNames, ArrayAttr channelTypes) {
p << "(";
llvm::interleaveComma(llvm::zip(channelNames, channelTypes), p.getStream(),
[&](auto nameType) {
auto name = cast<StringAttr>(std::get<0>(nameType));
p << name.getValue() << ": ";
p.printAttribute(std::get<1>(nameType));
});
p << ")";
}
ParseResult parseChannelNamesAndTypes(mlir::AsmParser& parser,
ArrayAttr& channelNames,
ArrayAttr& channelTypes) {
SmallVector<Attribute> names;
SmallVector<Attribute> types;

if (parser.parseLParen()) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
if (failed(parser.parseCommaSeparatedList([&]() {
std::string name;
TypeAttr type;
if (parser.parseKeywordOrString(&name) || parser.parseColon() ||
parser.parseAttribute(type)) {
return failure();
}
names.push_back(StringAttr::get(parser.getContext(), name));
types.push_back(type);
return success();
}))) {
return failure();
}
if (failed(parser.parseRParen())) {
return failure();
}
}
channelNames = ArrayAttr::get(parser.getContext(), names);
channelTypes = ArrayAttr::get(parser.getContext(), types);
return success();
}

} // namespace mlir::xls
5 changes: 5 additions & 0 deletions xls/contrib/mlir/IR/assembly_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ void printZippedSymbols(mlir::AsmPrinter& p, Operation* op,
ArrayAttr globalRefs, ArrayAttr localRefs);
ParseResult parseZippedSymbols(mlir::AsmParser& parser, ArrayAttr& globalRefs,
ArrayAttr& localRefs);
void printChannelNamesAndTypes(mlir::AsmPrinter& p, Operation* op,
ArrayAttr channelNames, ArrayAttr channelTypes);
ParseResult parseChannelNamesAndTypes(mlir::AsmParser& parser,
ArrayAttr& channelNames,
ArrayAttr& channelTypes);
} // namespace mlir::xls

#endif // GDM_HW_MLIR_XLS_IR_ASSEMBLY_FORMAT_H_
37 changes: 30 additions & 7 deletions xls/contrib/mlir/IR/xls_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,18 +582,41 @@ SprocOp SpawnOp::resolveCallee(SymbolTableCollection* symbolTable) {
getCallee());
}

ExternSprocOp SpawnOp::resolveExternCallee(SymbolTableCollection* symbolTable) {
if (symbolTable) {
return symbolTable->lookupNearestSymbolFrom<ExternSprocOp>(getOperation(),
getCallee());
}
return SymbolTable::lookupNearestSymbolFrom<ExternSprocOp>(getOperation(),
getCallee());
}

namespace {
template <typename T>
LogicalResult verifySpawnOpSymbolUses(SpawnOp op, T callee) {
if (callee.getChannelArgumentTypes().size() != op.getChannels().size()) {
return op.emitOpError()
<< "callee expects " << callee.getChannelArgumentTypes().size()
<< " channels but spawn has " << op.getChannels().size()
<< " arguments";
}
return success();
}
} // namespace

LogicalResult SpawnOp::verifySymbolUses(SymbolTableCollection& symbolTable) {
SprocOp callee = resolveCallee(&symbolTable);
Operation* callee =
symbolTable.lookupNearestSymbolFrom(getOperation(), getCallee());
if (!callee) {
return emitOpError() << "callee not found: " << getCallee();
}
if (callee.getChannelArguments().size() != getChannels().size()) {
return emitOpError() << "callee expects "
<< callee.getChannelArguments().size()
<< " channels but spawn has " << getChannels().size()
<< " arguments";
if (auto sproc = dyn_cast<SprocOp>(callee)) {
return verifySpawnOpSymbolUses(*this, sproc);
}
return success();
if (auto extern_sproc = dyn_cast<ExternSprocOp>(callee)) {
return verifySpawnOpSymbolUses(*this, extern_sproc);
}
return emitOpError() << "callee is not a SprocOp or ExternSprocOp";
}

namespace {
Expand Down
67 changes: 67 additions & 0 deletions xls/contrib/mlir/IR/xls_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,30 @@ def Xls_InstantiateEprocOp : Xls_Op<"instantiate_eproc", [DeclareOpInterfaceMeth
}];
}

def Xls_InstantiateExternEprocOp : Xls_Op<"instantiate_extern_eproc", []> {
let summary = "Binds an externally defined eproc";
let description = [{
Binds channels to an eproc that is defined externally to this module.

This functions similarly to the `xls.instantiate_eproc` op, except that:
1) The eproc definition is not available in this module and so it is
referred to by string, not a symbol.
2) Instead of binding global channels to local channels, it binds global
channels to the boundary channel names of the target eproc.

The target eproc is referred to by an opaque string. The interpretation of
this string is left to the user.
}];
let arguments = (ins
StrAttr:$eproc_name,
FlatSymbolRefArrayAttr:$global_channels,
StrArrayAttr:$boundary_channel_names
);
let assemblyFormat = [{
$eproc_name custom<ZippedSymbols>($global_channels, $boundary_channel_names) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// Structured procs
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1584,6 +1608,10 @@ def Xls_SprocOp : Xls_Op<"sproc", [
}
return index;
}

::mlir::TypeRange getChannelArgumentTypes() {
return getChannelArguments().getTypes();
}
}];
}

Expand All @@ -1602,6 +1630,45 @@ def Xls_SpawnOp : Xls_Op<"spawn", [
}];
let extraClassDeclaration = [{
SprocOp resolveCallee(::mlir::SymbolTableCollection* symbolTable = nullptr);
ExternSprocOp resolveExternCallee(::mlir::SymbolTableCollection* symbolTable = nullptr);
}];
}

def Xls_ExternSprocOp : Xls_Op<"extern_sproc", [
Symbol,
CallableOpInterface
]> {
let summary = "extern sproc";
let description = [{
Declares an sproc that is external to the current module. The sproc is
spawnable by sprocs in the current module.

The `boundary_channel_names` attribute is used to name each argument or
result channel. These correspond to the `boundary_channel_names` on the
target sproc (wherever it is defined). The `channel_types` attribute is used
to specify the types of the channels.

An `spawn` of a `extern_sproc` is lowered to an `instantiate_extern_eproc`
op.
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
StrArrayAttr:$boundary_channel_names,
TypeArrayAttr:$channel_argument_types
);
let assemblyFormat = [{
$sym_name custom<ChannelNamesAndTypes>($boundary_channel_names, $channel_argument_types) attr-dict
}];
let extraClassDeclaration = [{
::mlir::Region* getCallableRegion() {
return nullptr;
}
::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
return {};
}
::llvm::ArrayRef<::mlir::Type> getResultTypes() {
return {};
}
}];
}

Expand Down
8 changes: 8 additions & 0 deletions xls/contrib/mlir/testdata/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,14 @@ func.func @trace_cond(%arg0: i32, %tkn: !xls.token, %cond: i1) -> !xls.token {
return %0 : !xls.token
}

// CHECK-LABEL: xls.instantiate_extern_eproc "external" ("arg0" as @c1, "result0" as @c2)
xls.chan @c1 : i32
xls.chan @c2 : i32
xls.instantiate_extern_eproc "external" ("arg0" as @c1, "result0" as @c2)

// CHECK-LABEL: xls.extern_sproc @external_sproc (arg0: !xls.schan<i32, in>, result0: !xls.schan<i32, out>)
xls.extern_sproc @external_sproc (arg0: !xls.schan<i32, in>, result0: !xls.schan<i32, out>)

// -----

// expected-error@+1 {{yielded state type does not match carried state type ('tuple<i7>' vs 'tuple<i32>'}}
Expand Down
26 changes: 25 additions & 1 deletion xls/contrib/mlir/testdata/proc_elaboration.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xls/contrib/mlir/xls_opt -elaborate-procs %s 2>&1 | FileCheck %s
// RUN: xls/contrib/mlir/xls_opt -elaborate-procs -split-input-file %s 2>&1 | FileCheck %s
// CHECK: xls.chan @req : i32
// CHECK-NEXT: xls.chan @resp : i32
// CHECK-NEXT: xls.chan @rom1_req : i32
Expand Down Expand Up @@ -89,3 +89,27 @@ xls.sproc @rom(%req: !xls.schan<i32, in>, %resp: !xls.schan<i32, out>) top attri
xls.yield %state : i32
}
}

// -----

// CHECK: xls.chan @req : i32
// CHECK-NEXT: xls.chan @resp : i32
// CHECK-NEXT: xls.instantiate_extern_eproc "external" ("req" as @req, "resp" as @resp)
xls.extern_sproc @external(req: !xls.schan<i32, in>, resp: !xls.schan<i32, out>)

// CHECK: xls.chan @main_arg0 : i32
// CHECK-NEXT: xls.chan @main_arg1 : i32
// CHECK: xls.instantiate_eproc @main_0 (@main_arg0 as @req, @main_arg1 as @resp)
xls.sproc @main() top {
spawns {
%req_out, %req_in = xls.schan<i32>("req")
%resp_out, %resp_in = xls.schan<i32>("resp")
xls.spawn @external(%req_in, %resp_out) : !xls.schan<i32, in>, !xls.schan<i32, out>
xls.yield %req_out, %resp_in : !xls.schan<i32, out>, !xls.schan<i32, in>
}
next (%req: !xls.schan<i32, out>, %resp: !xls.schan<i32, in>, %state: i32) zeroinitializer {
%tok = xls.after_all : !xls.token
%tok2 = xls.ssend %tok, %state, %req : (!xls.token, i32, !xls.schan<i32, out>) -> !xls.token
xls.yield %state : i32
}
}
36 changes: 35 additions & 1 deletion xls/contrib/mlir/transforms/proc_elaboration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ class ElaborationContext
builder.getArrayAttr(localSymbols));
}

void instantiateExternEproc(ExternSprocOp externSproc,
ArrayRef<value_type> globalChannels) {
auto flatchan = [](value_type chan) -> Attribute {
return FlatSymbolRefAttr::get(chan.getSymNameAttr());
};
SmallVector<Attribute> globalSymbols =
llvm::map_to_vector(globalChannels, flatchan);
StringAttr symName = externSproc.getSymNameAttr();
builder.create<InstantiateExternEprocOp>(
externSproc.getLoc(), symName, builder.getArrayAttr(globalSymbols),
externSproc.getBoundaryChannelNames());
}

SymbolTable& getSymbolTable() { return symbolTable; }

private:
Expand Down Expand Up @@ -189,6 +202,10 @@ class ElaborationInterpreter
}

absl::Status Interpret(SpawnOp op, ElaborationContext& ctx) {
ExternSprocOp externSproc = op.resolveExternCallee();
if (externSproc) {
return InterpretSpawnOfExtern(op, externSproc, ctx);
}
SprocOp sproc = op.resolveCallee();
if (!sproc) {
return absl::InvalidArgumentError("failed to resolve callee");
Expand All @@ -211,6 +228,19 @@ class ElaborationInterpreter
return absl::OkStatus();
}

absl::Status InterpretSpawnOfExtern(SpawnOp op, ExternSprocOp externSproc,
ElaborationContext& ctx) {
XLS_ASSIGN_OR_RETURN(auto arguments, ctx.Get(op.getChannels()));
if (arguments.size() != externSproc.getChannelArgumentTypes().size()) {
return absl::InternalError(absl::StrFormat(
"Call to %s requires %d arguments but got %d",
op.getCallee().getLeafReference().str(),
externSproc.getChannelArgumentTypes().size(), arguments.size()));
}
ctx.instantiateExternEproc(externSproc, arguments);
return absl::OkStatus();
}

absl::Status Interpret(SprocOp op, ElaborationContext& ctx,
ArrayRef<ChanOp> boundaryChannels = {}) {
XLS_ASSIGN_OR_RETURN(auto results,
Expand Down Expand Up @@ -259,7 +289,11 @@ void ProcElaborationPass::runOnOperation() {
sproc.emitError() << "failed to elaborate: " << result.message();
}
}
module.walk([&](SprocOp sproc) { sproc.erase(); });
module.walk([&](Operation* op) {
if (isa<SprocOp, ExternSprocOp>(op)) {
op->erase();
}
});
}

} // namespace mlir::xls

0 comments on commit 245ca1f

Please sign in to comment.